├── DIEN
├── README.md
├── __pycache__
│   ├── data_iterator.cpython-35.pyc
│   ├── model_taobao_allfea.cpython-35.pyc
│   ├── rnn.cpython-35.pyc
│   ├── utils.cpython-35.pyc
│   └── wrap_time.cpython-35.pyc
├── data_iterator.py
├── data_loader.py
├── data_util.sh
├── data_utils.py
├── model_taobao_allfea.py
├── plot
│   └── plot1.png
├── rnn.py
├── tianchi.md
├── train_taobao_processed_allfea.py
├── utils.py
└── wrap_time.py
└── LICENSE
--------------------------------------------------------------------------------
/DIEN/README.md:
--------------------------------------------------------------------------------
1 | The DIEN model is developed from the paper "Deep Interest Evolution Network for Click-Through Rate Prediction" (https://arxiv.org/abs/1809.03672).
2 | In this repo, the model is revised to accommodate the Taobao user behavior history data with some sampling and post-processing.
3 | # Data
4 | We have processed the data and stored it on the Tianchi platform. Follow the instructions below to fetch it:
5 | ```
6 | https://tianchi.aliyun.com/dataset/dataDetail?dataId=81505
7 | ```
8 | After downloading, put the file under DIEN, your model folder, and untar it.
9 |
10 | # Environment
11 | The reference setup uses Nvidia docker.
12 | ```
13 | docker pull nvcr.io/nvidia/tensorflow:19.10-py3
14 | ```
15 | Set up the docker container with:
16 | ```
17 | docker run --gpus all --privileged=true --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --name=tensorflow-nv -it --rm -v yourdirectory/DIEN:/DIEN --network=host nvcr.io/nvidia/tensorflow:19.10-py3
18 | ```
19 | Map your data_dir folder into the container space.
20 |
21 | # Run
22 | The command to run this model is quite simple; see below. **model** can be DIEN_with_neg, DIEN, DIN, etc.
23 | Please check the source code for more model choices. **rand** is any integer used as the random seed.
24 |
25 | `python train_taobao_processed_allfea.py train `
26 | Please make sure your data dir is correctly specified in train_taobao_processed_allfea.py.
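For reference, each training record consumed by data_iterator.py (below) stores the user id plus several behavior-history fields: fields are joined by a "^H" separator, each history is "\x03"-separated, and vocabulary dicts map raw ids to integer ids, with -1 for out-of-vocabulary values. A minimal Python 3 sketch of that parsing follows; the sample line and tiny vocabulary are made up for illustration, and note that the repo code itself uses Python 2 idioms (bare `map()` indexed directly), which this sketch avoids.

```python
# Illustrative only: a tiny vocabulary standing in for item_voc.json.
item_dict = {"itemA": "3", "itemB": "7"}

def map_item(x):
    # Unknown ids map to -1, matching the repo's map_* helpers.
    return int(item_dict.get(x, -1))

# A made-up record: user id, then a "\x03"-separated item history.
line = "user1^HitemA\x03itemC\x03itemB"
fields = line.strip().split("^H")
uid_raw, item_field = fields[0], fields[1]
# Python 3: materialize the mapped history as a list so it can be indexed.
hist_item = [map_item(x) for x in item_field.split("\x03")]
pos_item = hist_item[-1]  # the last behavior is used as the positive target
print(hist_item, pos_item)  # [3, -1, 7] 7
```

The same pattern applies to the cate/shop/node/product/brand history fields, each with its own vocabulary dict.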
27 |
28 | We made 5 runs for each model (each line represents the mean over the 5 runs) to show the AUC of each model for predicting the user click-through rate.
29 | ![AUC comparison of the models](plot/plot1.png)
30 |
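The `gen_neg_hist` helpers in data_iterator.py and data_utils.py below cache randomly generated negative histories keyed by history length: for each length, fresh random histories are generated and stored until `max_catch_num` entries exist, after which one cached history is reused at random. A minimal Python 3 sketch of that caching scheme is shown here; the class and names are illustrative, and the originals additionally map each sampled item to its cate/shop/node/product/brand features.

```python
import random

class NegHistCache:
    """Length-keyed cache of random negative behavior histories."""

    def __init__(self, all_items, max_cache_num=20):
        # Python 3: list() the items so they can be indexed/sampled
        # (the repo indexes dict.keys() directly, a Python 2 idiom).
        self.all_items = list(all_items)
        self.max_cache_num = max_cache_num
        self.cache = {}  # history length -> list of cached histories

    def gen_neg_hist(self, length):
        bucket = self.cache.setdefault(length, [])
        if len(bucket) == self.max_cache_num:
            # Cache full for this length: reuse a cached history.
            return random.choice(bucket)
        # Otherwise sample a fresh random history and cache it.
        hist = [random.choice(self.all_items) for _ in range(length)]
        bucket.append(hist)
        return hist
```

The trade-off is speed versus diversity: once a length's bucket is full, no new negatives are drawn for it, so larger `max_catch_num` values give more varied negative samples at the cost of more sampling work.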
-------------------------------------------------------------------------------- /DIEN/data_iterator.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import json 3 | #import pickle as pkl 4 | import random 5 | import numpy as np 6 | import sys 7 | from functools import wraps 8 | import time 9 | from wrap_time import time_it 10 | import copy 11 | 12 | defaultencoding = 'utf-8' 13 | if sys.getdefaultencoding() != defaultencoding: 14 | reload(sys) 15 | sys.setdefaultencoding(defaultencoding) 16 | 17 | 18 | class DataIterator: 19 | 20 | @time_it(freq=10) 21 | def __init__(self, source, dict_list, 22 | batch_size=128, 23 | maxlen=2000, 24 | skip_empty=False, 25 | sort_by_length=True, 26 | max_batch_size=20, 27 | minlen=None, 28 | parall=False 29 | ): 30 | self.source = open(source, 'r') 31 | self.batch_shuffle = 1 32 | self.neg_sample = 'LastInstance' # 'Random' or 'LastInstance' 33 | #self.user_dict = copy.copy(dict_list[0]) 34 | #self.item_dict = copy.copy(dict_list[1]) 35 | #self.cate_dict = copy.copy(dict_list[2]) 36 | #self.shop_dict = copy.copy(dict_list[3]) 37 | #self.node_dict = copy.copy(dict_list[4]) 38 | #self.product_dict = copy.copy(dict_list[5]) 39 | #self.brand_dict = copy.copy(dict_list[6]) 40 | #self.item_info = copy.copy(dict_list[7]) 41 | self.user_dict, self.item_dict, self.cate_dict, self.shop_dict, self.node_dict, self.product_dict, self.brand_dict, self.item_info = dict_list 42 | # self.user_dict = json.load(open('user_voc.json', 'r')) 43 | # self.item_dict = json.load(open('item_voc.json', 'r')) 44 | # self.cate_dict = json.load(open('cate_voc.json', 'r')) 45 | # self.shop_dict = json.load(open('shop_voc.json', 'r')) 46 | # self.node_dict = json.load(open('node_voc.json', 'r')) 47 | # self.product_dict = json.load(open('product_voc.json', 'r')) 48 | # self.brand_dict = json.load(open('brand_voc.json', 'r')) 49 | # self.item_info = json.load(open('item_info.json', 'r')) 50 | self.all_items = 
self.item_info.keys() 51 | self.num_items = len(self.all_items) 52 | self.batch_size = batch_size 53 | self.maxlen = maxlen 54 | self.minlen = minlen 55 | self.skip_empty = skip_empty 56 | self.sort_by_length = sort_by_length 57 | self.max_catch_num = 20 58 | 59 | self.source_buffer = [] 60 | #self.batch_size = batch_size * max_batch_size 61 | self.batch_size = batch_size 62 | self.neg_hist_catch = {} 63 | self.end_of_data = False 64 | #generate random neg item to initialize last_item information 65 | item_idx = int(random.random()*self.num_items) 66 | neg_item = self.all_items[item_idx] 67 | self.last_cate = self.map_cate(self.item_info[neg_item][0]) 68 | self.last_shop = self.map_shop(self.item_info[neg_item][1]) 69 | self.last_node = self.map_node(self.item_info[neg_item][2]) 70 | self.last_product = self.map_product(self.item_info[neg_item][3]) 71 | self.last_brand = self.map_brand(self.item_info[neg_item][4]) 72 | self.last_item = self.map_item(self.all_items[item_idx])#map origin item to item_id 73 | 74 | def get_id_nums(self): 75 | uid_n = len(self.user_dict.keys()) 76 | item_n = len(self.item_dict.keys()) 77 | cate_n = len(self.cate_dict.keys()) 78 | shop_n = len(self.shop_dict.keys()) 79 | node_n = len(self.node_dict.keys()) 80 | product_n = len(self.product_dict.keys()) 81 | brand_n = len(self.brand_dict.keys()) 82 | return uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n 83 | 84 | def map_item(self, x): 85 | return int(self.item_dict.get(x, -1)) 86 | 87 | def map_user(self, x): 88 | return int(self.user_dict.get(x, -1)) 89 | 90 | def map_cate(self, x): 91 | return int(self.cate_dict.get(x, -1)) 92 | 93 | def map_shop(self, x): 94 | return int(self.shop_dict.get(x, -1)) 95 | 96 | def map_node(self, x): 97 | return int(self.node_dict.get(x, -1)) 98 | 99 | def map_product(self, x): 100 | return int(self.product_dict.get(x, -1)) 101 | 102 | def map_brand(self, x): 103 | return int(self.brand_dict.get(x, -1)) 104 | 105 | def gen_item_block(self, 
item_idx): 106 | neg_item = self.all_items[item_idx] 107 | neg_cate = self.map_cate(self.item_info[neg_item][0]) 108 | neg_shop = self.map_shop(self.item_info[neg_item][1]) 109 | neg_node = self.map_node(self.item_info[neg_item][2]) 110 | neg_product = self.map_product(self.item_info[neg_item][3]) 111 | neg_brand = self.map_brand(self.item_info[neg_item][4]) 112 | neg_item = self.map_item(self.all_items[item_idx])#map origin item to item_id 113 | return neg_item, neg_cate, neg_shop, neg_node, neg_product, neg_brand 114 | 115 | def gen_neg_hist(self, length): 116 | if len(self.neg_hist_catch.get(length, [1])) == self.max_catch_num: 117 | index = int(random.random()*self.max_catch_num) 118 | return self.neg_hist_catch[length][index] 119 | else: 120 | #generate a new neg hist 121 | neg_item_hist = [] 122 | neg_cate_hist = [] 123 | neg_shop_hist = [] 124 | neg_node_hist = [] 125 | neg_product_hist = [] 126 | neg_brand_hist = [] 127 | for i in range(length): 128 | item_idx = int(random.random()*self.num_items) 129 | neg_item = self.all_items[item_idx] 130 | neg_cate = self.map_cate(self.item_info[neg_item][0]) 131 | neg_cate_hist.append(neg_cate) 132 | neg_shop = self.map_shop(self.item_info[neg_item][1]) 133 | neg_shop_hist.append(neg_shop) 134 | neg_node = self.map_node(self.item_info[neg_item][2]) 135 | neg_node_hist.append(neg_node) 136 | neg_product = self.map_product(self.item_info[neg_item][3]) 137 | neg_product_hist.append(neg_product) 138 | neg_brand = self.map_brand(self.item_info[neg_item][4]) 139 | neg_brand_hist.append(neg_brand) 140 | neg_item = self.map_item(self.all_items[item_idx])#map origin item to item_id 141 | neg_item_hist.append(neg_item) 142 | self.neg_hist_catch[length] = self.neg_hist_catch.get(length, []) 143 | self.neg_hist_catch[length].append([neg_item_hist, neg_cate_hist, neg_shop_hist, neg_node_hist, neg_product_hist, neg_brand_hist]) 144 | return [neg_item_hist, neg_cate_hist, neg_shop_hist, neg_node_hist, neg_product_hist, 
neg_brand_hist] 145 | 146 | def fill_ndarray(self, hist): 147 | nd_his = numpy.ones(self.maxlen) * -1 148 | nd_his[:len(hist)] = hist 149 | return nd_his 150 | 151 | 152 | def __iter__(self): 153 | return self 154 | 155 | def reset(self): 156 | self.source.seek(0) 157 | 158 | 159 | @time_it(freq=20) 160 | def next(self): 161 | if self.end_of_data: 162 | self.end_of_data = False 163 | self.reset() 164 | raise StopIteration 165 | 166 | source = [] 167 | target = [] 168 | hist_item_list = [] 169 | hist_cate_list = [] 170 | hist_shop_list = [] 171 | hist_node_list = [] 172 | hist_product_list = [] 173 | hist_brand_list = [] 174 | 175 | neg_hist_item_list = [] 176 | neg_hist_cate_list = [] 177 | neg_hist_shop_list = [] 178 | neg_hist_node_list = [] 179 | neg_hist_product_list = [] 180 | neg_hist_brand_list = [] 181 | count = 0 182 | if len(self.source_buffer) == 0: 183 | for k_ in xrange(self.batch_size): 184 | ss = self.source.readline() 185 | if ss == "" and count < self.batch_size: 186 | self.end_of_data = True 187 | self.source.seek(0) 188 | ss = self.source.readline() 189 | self.source_buffer.append(ss.strip().split("^H")) 190 | count += 1 191 | 192 | if len(self.source_buffer) == 0: 193 | self.end_of_data = False 194 | self.reset() 195 | raise StopIteration 196 | try: 197 | 198 | # actual work here 199 | while True: 200 | 201 | # read from source file and map to word index 202 | try: 203 | ss = self.source_buffer.pop() 204 | except IndexError: 205 | break 206 | uid = self.map_user(ss[0]) 207 | hist_item = map(self.map_item, ss[1].split('\003')) 208 | hist_cate = map(self.map_cate, ss[2].split('\003')) 209 | hist_shop = map(self.map_shop, ss[3].split('\003')) 210 | hist_node = map(self.map_node, ss[4].split('\003')) 211 | hist_product = map(self.map_product, ss[5].split('\003')) 212 | hist_brand = map(self.map_brand, ss[6].split('\003')) 213 | 214 | pos_item = hist_item[-1] 215 | pos_cate = hist_cate[-1] 216 | pos_shop = hist_shop[-1] 217 | pos_node = hist_node[-1] 
218 | pos_product = hist_product[-1] 219 | pos_brand = hist_brand[-1] 220 | if self.neg_sample == 'LastInstance': 221 | #set item of last instance as neg sample 222 | neg_item, neg_cate, neg_shop, neg_node, neg_product, neg_brand = self.last_item, self.last_cate, self.last_shop, self.last_node, self.last_product, self.last_brand 223 | elif self.neg_sample == 'Random': 224 | #generate random neg_item information for neg sample 225 | item_idx = int(random.random()*self.num_items) 226 | while self.map_item(self.all_items[item_idx]) in hist_item: 227 | item_idx = int(random.random()*self.num_items) 228 | neg_item = self.all_items[item_idx] 229 | neg_cate = self.map_cate(self.item_info[neg_item][0]) 230 | neg_shop = self.map_shop(self.item_info[neg_item][1]) 231 | neg_node = self.map_node(self.item_info[neg_item][2]) 232 | neg_product = self.map_product(self.item_info[neg_item][3]) 233 | neg_brand = self.map_brand(self.item_info[neg_item][4]) 234 | neg_item = self.map_item(self.all_items[item_idx])#map origin item to item_id 235 | #gen neg hist 236 | lengthx = len(hist_item[-self.maxlen:]) 237 | random_neg_hist = self.gen_neg_hist(lengthx) 238 | #add positive sample 239 | source.append([uid, pos_item, pos_cate, pos_shop, pos_node, pos_product, pos_brand]) 240 | target.append([1, 0]) 241 | # drop the last item 242 | hist_item_list.append(self.fill_ndarray(hist_item[-(self.maxlen+1):-1])) 243 | hist_cate_list.append(self.fill_ndarray(hist_cate[-(self.maxlen+1):-1])) 244 | hist_shop_list.append(self.fill_ndarray(hist_shop[-(self.maxlen+1):-1])) 245 | hist_node_list.append(self.fill_ndarray(hist_node[-(self.maxlen+1):-1])) 246 | hist_product_list.append(self.fill_ndarray(hist_product[-(self.maxlen+1):-1])) 247 | hist_brand_list.append(self.fill_ndarray(hist_brand[-(self.maxlen+1):-1])) 248 | neg_hist_item_list.append(self.fill_ndarray(random_neg_hist[0])) 249 | neg_hist_cate_list.append(self.fill_ndarray(random_neg_hist[1])) 250 | 
neg_hist_shop_list.append(self.fill_ndarray(random_neg_hist[2])) 251 | neg_hist_node_list.append(self.fill_ndarray(random_neg_hist[3])) 252 | neg_hist_product_list.append(self.fill_ndarray(random_neg_hist[4])) 253 | neg_hist_brand_list.append(self.fill_ndarray(random_neg_hist[5])) 254 | #add negative sample 255 | source.append([uid, neg_item, neg_cate, neg_shop, neg_node, neg_product, neg_brand]) 256 | target.append([0, 1]) 257 | hist_item_list.append(self.fill_ndarray(hist_item[-(self.maxlen+1):-1])) 258 | hist_cate_list.append(self.fill_ndarray(hist_cate[-(self.maxlen+1):-1])) 259 | hist_shop_list.append(self.fill_ndarray(hist_shop[-(self.maxlen+1):-1])) 260 | hist_node_list.append(self.fill_ndarray(hist_node[-(self.maxlen+1):-1])) 261 | hist_product_list.append(self.fill_ndarray(hist_product[-(self.maxlen+1):-1])) 262 | hist_brand_list.append(self.fill_ndarray(hist_brand[-(self.maxlen+1):-1])) 263 | neg_hist_item_list.append(self.fill_ndarray(random_neg_hist[0])) 264 | neg_hist_cate_list.append(self.fill_ndarray(random_neg_hist[1])) 265 | neg_hist_shop_list.append(self.fill_ndarray(random_neg_hist[2])) 266 | neg_hist_node_list.append(self.fill_ndarray(random_neg_hist[3])) 267 | neg_hist_product_list.append(self.fill_ndarray(random_neg_hist[4])) 268 | neg_hist_brand_list.append(self.fill_ndarray(random_neg_hist[5])) 269 | if self.neg_sample == 'LastInstance': 270 | self.last_item, self.last_cate, self.last_shop, self.last_node, self.last_product, self.last_brand = pos_item, pos_cate, pos_shop, pos_node, pos_product, pos_brand 271 | if len(source) >= self.batch_size or len(target) >= self.batch_size: 272 | break 273 | except IOError: 274 | self.end_of_data = True 275 | 276 | # all sentence pairs in maxibatch filtered out because of length 277 | if len(source) == 0 or len(target) == 0: 278 | source, target = self.next() 279 | 280 | 281 | uid_array = np.array(source)[:,0] 282 | item_array = np.array(source)[:,1] 283 | cate_array = np.array(source)[:,2] 284 | 
shop_array = np.array(source)[:,3] 285 | node_array = np.array(source)[:,4] 286 | product_array = np.array(source)[:,5] 287 | brand_array = np.array(source)[:,6] 288 | 289 | target_array = np.array(target) 290 | history_item_array = np.array(hist_item_list) 291 | history_cate_array = np.array(hist_cate_list) 292 | history_shop_array = np.array(hist_shop_list) 293 | history_node_array = np.array(hist_node_list) 294 | history_product_array = np.array(hist_product_list) 295 | history_brand_array = np.array(hist_brand_list) 296 | 297 | neg_history_item_array = np.array(neg_hist_item_list) 298 | neg_history_cate_array = np.array(neg_hist_cate_list) 299 | neg_history_shop_array = np.array(neg_hist_shop_list) 300 | neg_history_node_array = np.array(neg_hist_node_list) 301 | neg_history_product_array = np.array(neg_hist_product_list) 302 | neg_history_brand_array = np.array(neg_hist_brand_list) 303 | #history_neg_item_array = np.array(6eg_item_list) 304 | #history_neg_cate_array = np.array(neg_cate_list) 305 | 306 | history_mask_array = np.greater(history_item_array, 0)*1.0 307 | if self.batch_shuffle: 308 | per = np.random.permutation(uid_array.shape[0]) 309 | uid_array = uid_array[per] 310 | item_array = item_array[per] 311 | cate_array = cate_array[per] 312 | shop_array = shop_array[per] 313 | node_array = node_array[per] 314 | product_array = product_array[per] 315 | brand_array = brand_array[per] 316 | target_array = target_array[per] 317 | history_item_array = history_item_array[per, :] 318 | history_cate_array = history_cate_array[per, :] 319 | history_shop_array = history_shop_array[per, :] 320 | history_node_array = history_node_array[per, :] 321 | history_product_array = history_product_array[per, :] 322 | history_brand_array = history_brand_array[per, :] 323 | 324 | neg_history_item_array = neg_history_item_array[per, :] 325 | neg_history_cate_array = neg_history_cate_array[per, :] 326 | neg_history_shop_array = neg_history_shop_array[per, :] 327 | 
neg_history_node_array = neg_history_node_array[per, :] 328 | neg_history_product_array = neg_history_product_array[per, :] 329 | neg_history_brand_array = neg_history_brand_array[per, :] 330 | 331 | 332 | 333 | 334 | return (uid_array, item_array, cate_array, shop_array, node_array, product_array, brand_array), \ 335 | (target_array, history_item_array, history_cate_array, history_shop_array, history_node_array, history_product_array, history_brand_array, history_mask_array, neg_history_item_array, neg_history_cate_array, neg_history_shop_array, neg_history_node_array, neg_history_product_array, neg_history_brand_array) 336 | 337 | 338 | -------------------------------------------------------------------------------- /DIEN/data_loader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from os import path 3 | import random 4 | import numpy as np 5 | import json 6 | import time 7 | import threading 8 | from collections import deque 9 | 10 | class DataLoader: 11 | 12 | def __init__( 13 | self, 14 | data_path, 15 | data_file, 16 | batch_size, 17 | data_file_num, 18 | sleep_time=1, 19 | max_queue_size = 2 20 | ): 21 | # load data 22 | self.queue = deque() #multiprocessing.Queue(maxsize=max_queue_size) # it may change in future if we decide to split data into many small chunks instead of 4 23 | self.batch_size = batch_size 24 | self.data_path = data_path 25 | self.data_file = data_file 26 | self.data_file_num = data_file_num 27 | self.sleep_time = sleep_time 28 | self.max_queue_size = max_queue_size 29 | self.help_count = 0 30 | 31 | def __iter__(self): 32 | return self 33 | 34 | def data_read(self, start_id, total_thread): 35 | sample_id = start_id 36 | while sample_id < self.data_file_num: 37 | if len(self.queue) >= self.max_queue_size: 38 | time.sleep(1) 39 | continue 40 | processed_data_path = self.data_path + self.data_file + "_" + str(sample_id) + '_processed.npz' 41 | print('Start loading processed data...' 
+ processed_data_path) 42 | st = time.time() 43 | data = np.load(processed_data_path) 44 | source = data['source_array'] 45 | uid_array = np.array(source)[:,0] 46 | item_array = np.array(source)[:,1] 47 | cate_array = np.array(source)[:,2] 48 | shop_array = np.array(source)[:,3] 49 | node_array = np.array(source)[:,4] 50 | product_array = np.array(source)[:,5] 51 | brand_array = np.array(source)[:,6] 52 | 53 | target = data['target_array'] 54 | history_item = data['history_item_array'] 55 | history_cate = data['history_cate_array'] 56 | history_shop = data['history_shop_array'] 57 | history_node = data['history_node_array'] 58 | history_product = data['history_product_array'] 59 | history_brand = data['history_brand_array'] 60 | 61 | neg_history_item = data['neg_history_item_array'] 62 | neg_history_cate = data['neg_history_cate_array'] 63 | neg_history_shop = data['neg_history_shop_array'] 64 | neg_history_node = data['neg_history_node_array'] 65 | neg_history_product = data['neg_history_product_array'] 66 | neg_history_brand = data['neg_history_brand_array'] 67 | print('Finish loading processed data id '+ str(sample_id) + ',Time cost = %.4f' % (time.time()-st)) 68 | data_file = (uid_array,item_array,cate_array,shop_array,node_array,product_array,brand_array,\ 69 | target, history_item,history_cate,history_shop, history_node,history_product,history_brand,\ 70 | neg_history_item,neg_history_cate,neg_history_shop, neg_history_node,neg_history_product,neg_history_brand) 71 | while self.help_count % total_thread != start_id: 72 | time.sleep(1) 73 | print('help_count=', self.help_count) 74 | self.queue.append(data_file) 75 | self.help_count += 1 76 | sample_id = sample_id + total_thread 77 | 78 | 79 | def _batch_data(self, data, data_slice): 80 | uid_array,item_array,cate_array,shop_array,node_array,product_array,brand_array,\ 81 | target, history_item,history_cate,history_shop, history_node,history_product,history_brand,\ 82 | 
neg_history_item,neg_history_cate,neg_history_shop, neg_history_node,neg_history_product,neg_history_brand = data 83 | #print("in _batch_data func") 84 | user_id = uid_array[data_slice] 85 | item_id = item_array[data_slice] 86 | cate_id = cate_array[data_slice] 87 | shop_id = shop_array[data_slice] 88 | node_id = node_array[data_slice] 89 | product_id = product_array[data_slice] 90 | brand_id = brand_array[data_slice] 91 | label = target[data_slice, :] 92 | hist_item = history_item[data_slice, :] 93 | hist_cate = history_cate[data_slice, :] 94 | hist_shop = history_shop[data_slice, :] 95 | hist_node = history_node[data_slice, :] 96 | hist_product = history_product[data_slice, :] 97 | hist_brand = history_brand[data_slice, :] 98 | 99 | hist_mask = np.greater( hist_item, 0) * 1.0 100 | 101 | neg_hist_item = neg_history_item[data_slice, :] 102 | neg_hist_cate = neg_history_cate[data_slice, :] 103 | neg_hist_shop = neg_history_shop[data_slice, :] 104 | neg_hist_node = neg_history_node[data_slice, :] 105 | neg_hist_product = neg_history_product[data_slice, :] 106 | neg_hist_brand = neg_history_brand[data_slice, :] 107 | 108 | return [user_id, item_id, cate_id,shop_id, node_id, product_id, brand_id, 109 | label, hist_item, hist_cate, hist_shop, hist_node, hist_product, hist_brand, 110 | hist_mask, neg_hist_item, neg_hist_cate, neg_hist_shop, neg_hist_node, 111 | neg_hist_product, neg_hist_brand ] 112 | 113 | def next(self): 114 | previous_data_out = [] 115 | data_file_read = 0 116 | batch_id = 0 117 | #print('in next func') 118 | #import pdb; pdb.set_trace() 119 | previous_line = 0 120 | while len(self.queue) < 2: 121 | time.sleep(1) 122 | print ('Now the queue has two data file loaded in!') 123 | while data_file_read < self.data_file_num: 124 | if len(self.queue) == 0: 125 | time.sleep(1) 126 | continue 127 | data = self.queue.popleft() 128 | file_line_num = data[0].shape[0] 129 | start_ind = 0 130 | data_file_read = data_file_read + 1 131 | stime = time.time() 132 | 
#print('start one file,time=', stime) 133 | while start_ind <= file_line_num - self.batch_size: 134 | 135 | if previous_line != 0: 136 | batch_left = self.batch_size - previous_line 137 | else: 138 | batch_left = self.batch_size 139 | data_slice = slice(start_ind, start_ind + batch_left) 140 | # slice the data from the list 141 | data_out = self._batch_data(data, data_slice) #data_out is tuple 142 | 143 | if previous_line != 0: 144 | #attach the data 145 | for i in range(len(data_out)): 146 | data_out[i] = np.concatenate( 147 | [previous_data_out[i], data_out[i]], 148 | axis=0 149 | ) 150 | if self.batch_size != len(data_out[0]): 151 | raise ValueError('batch fetched wrong!') 152 | 153 | start_ind = start_ind + batch_left 154 | 155 | previous_line = 0 156 | #print("start_ind ", start_ind) 157 | yield data_out 158 | if start_ind != file_line_num: 159 | data_slice = slice(start_ind, file_line_num) 160 | previous_data_out = self._batch_data(data, data_slice) 161 | previous_line = file_line_num - start_ind 162 | print("Left batch of size %d" %( previous_line)) 163 | etime = time.time() 164 | print('Consume one file takes time= %.4f' %(etime-stime)) 165 | print ('drop last batch since it is not full batch size') 166 | 167 | 168 | def test(): 169 | data_load = DataLoader('/disk3/w.wei/dien-new/process_data_maxlen100_0225/', 'train_sample', 256, 15) 170 | producer1 = threading.Thread(target=data_load.data_read, args=(0, 3)) 171 | producer2 = threading.Thread(target=data_load.data_read, args=(1, 3)) 172 | producer3 = threading.Thread(target=data_load.data_read, args=(2, 3)) 173 | producer1.start() 174 | producer2.start() 175 | producer3.start() 176 | #data_i = iter(data_load) 177 | #data_o = next(data_i) 178 | #print('print=====',len(data_o)) 179 | num = 0 180 | for data in data_load.next(): 181 | num = num+1 182 | cnt = 1 183 | for i in range(10000): 184 | cnt = cnt * 1.0 185 | if num%1000 == 0: 186 | print('i=',num,',cnt=',cnt) 187 | 188 | producer1.join() 189 | 
producer2.join() 190 | producer3.join() 191 | 192 | if __name__ == '__main__': 193 | test() 194 | -------------------------------------------------------------------------------- /DIEN/data_util.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python data_utils.py --raw-data-file='/disk2/public_dataset/DeepCTRData/data/' 3 | -------------------------------------------------------------------------------- /DIEN/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | # 6 | # Description: preprocess input data for DIEN benchmark 7 | # 8 | # Utility function(s) to download and pre-process public data sets 9 | # - DIEN public data 10 | # .......... 11 | # 12 | 13 | 14 | import sys 15 | # import os 16 | from os import path 17 | # import io 18 | # from io import StringIO 19 | # import collections as coll 20 | import random 21 | import numpy as np 22 | import json 23 | import time 24 | 25 | class preprocessDataset: 26 | 27 | def __init__( 28 | self, 29 | data_path, 30 | maxlen, 31 | ): 32 | self.data_path = data_path 33 | #self.user_dict, self.item_dict, self.cate_dict, self.shop_dict, self.node_dict \ 34 | #, self.product_dict, self.brand_dict, self.item_info = dict_list 35 | self.neg_sample = 'LastInstance' # 'Random' or 'LastInstance' 36 | 37 | 38 | self.source_buffer = [] 39 | self.neg_hist_catch = {} 40 | self.maxlen = maxlen 41 | self.max_catch_num = 20 42 | #self.end_of_data = False 43 | st = time.time() 44 | print("Start loading dict...") 45 | #hack here 46 | data_path = '/disk3/w.wei/dien-new/' 47 | self.item_info = json.load(open(data_path + '/item_info.json', 'r')) 48 | print ('Finish loading item_info.json,length=',len(self.item_info.keys())) 49 | self.user_dict = 
json.load(open(data_path + '/user_voc.json', 'r')) 50 | print ('Finish loading user_voc.json,length=',len(self.user_dict.keys())) 51 | self.item_dict = json.load(open(data_path + '/item_voc.json', 'r')) 52 | print ('Finish loading item_voc.json,length=',len(self.item_dict.keys())) 53 | self.cate_dict = json.load(open(data_path + '/cate_voc.json', 'r')) 54 | print ('Finish loading cate_voc.json,length=',len(self.cate_dict.keys())) 55 | self.shop_dict = json.load(open(data_path + '/shop_voc.json', 'r')) 56 | print ('Finish loading shop_voc.json,length=',len(self.shop_dict.keys())) 57 | self.node_dict = json.load(open(data_path + '/node_voc.json', 'r')) 58 | print ('Finish loading node_voc.json,length=',len(self.node_dict.keys())) 59 | self.product_dict = json.load(open(data_path + '/product_voc.json', 'r')) 60 | print ('Finish loading product_voc.json,length=',len(self.product_dict.keys())) 61 | self.brand_dict = json.load(open(data_path + '/brand_voc.json', 'r')) 62 | print ('Finish loading brand_voc.json,length=',len(self.brand_dict.keys())) 63 | print("Time for load dict=", time.time()-st) 64 | #import pdb; pdb.set_trace() 65 | self.all_items = self.item_info.keys() 66 | self.num_items = len(self.all_items) 67 | #generate random neg item to initialize last_item information 68 | item_idx = int(random.random()*self.num_items) 69 | neg_item = self.all_items[item_idx] 70 | self.last_cate = self.map_cate(self.item_info[neg_item][0]) 71 | self.last_shop = self.map_shop(self.item_info[neg_item][1]) 72 | self.last_node = self.map_node(self.item_info[neg_item][2]) 73 | self.last_product = self.map_product(self.item_info[neg_item][3]) 74 | self.last_brand = self.map_brand(self.item_info[neg_item][4]) 75 | self.last_item = self.map_item(self.all_items[item_idx])#map origin item to item_id 76 | 77 | def get_id_nums(self): 78 | uid_n = len(self.user_dict.keys()) 79 | item_n = len(self.item_dict.keys()) 80 | cate_n = len(self.cate_dict.keys()) 81 | shop_n = 
len(self.shop_dict.keys()) 82 | node_n = len(self.node_dict.keys()) 83 | product_n = len(self.product_dict.keys()) 84 | brand_n = len(self.brand_dict.keys()) 85 | return uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n 86 | 87 | def map_item(self, x): 88 | return int(self.item_dict.get(x, -1)) 89 | 90 | def map_user(self, x): 91 | return int(self.user_dict.get(x, -1)) 92 | 93 | def map_cate(self, x): 94 | return int(self.cate_dict.get(x, -1)) 95 | 96 | def map_shop(self, x): 97 | return int(self.shop_dict.get(x, -1)) 98 | 99 | def map_node(self, x): 100 | return int(self.node_dict.get(x, -1)) 101 | 102 | def map_product(self, x): 103 | return int(self.product_dict.get(x, -1)) 104 | 105 | def map_brand(self, x): 106 | return int(self.brand_dict.get(x, -1)) 107 | 108 | def gen_neg_hist(self, length): 109 | if len(self.neg_hist_catch.get(length, [1])) == self.max_catch_num: 110 | index = int(random.random()*self.max_catch_num) 111 | return self.neg_hist_catch[length][index] 112 | else: 113 | #generate a new neg hist 114 | neg_item_hist = [] 115 | neg_cate_hist = [] 116 | neg_shop_hist = [] 117 | neg_node_hist = [] 118 | neg_product_hist = [] 119 | neg_brand_hist = [] 120 | for i in range(length): 121 | item_idx = int(random.random()*self.num_items) 122 | neg_item = self.all_items[item_idx] 123 | neg_cate = self.map_cate(self.item_info[neg_item][0]) 124 | neg_cate_hist.append(neg_cate) 125 | neg_shop = self.map_shop(self.item_info[neg_item][1]) 126 | neg_shop_hist.append(neg_shop) 127 | neg_node = self.map_node(self.item_info[neg_item][2]) 128 | neg_node_hist.append(neg_node) 129 | neg_product = self.map_product(self.item_info[neg_item][3]) 130 | neg_product_hist.append(neg_product) 131 | neg_brand = self.map_brand(self.item_info[neg_item][4]) 132 | neg_brand_hist.append(neg_brand) 133 | neg_item = self.map_item(self.all_items[item_idx])#map origin item to item_id 134 | neg_item_hist.append(neg_item) 135 | self.neg_hist_catch[length] = 
self.neg_hist_catch.get(length, []) # do not understand TODO 136 | self.neg_hist_catch[length].append([neg_item_hist, neg_cate_hist, neg_shop_hist,\ 137 | neg_node_hist, neg_product_hist, neg_brand_hist]) 138 | return [neg_item_hist, neg_cate_hist, neg_shop_hist, neg_node_hist, neg_product_hist, neg_brand_hist] 139 | 140 | def gen_neg_hist3(self, length): 141 | if len(self.neg_hist_catch.get(length, [1])) == self.max_catch_num: 142 | index = int(random.random()*self.max_catch_num) 143 | return self.neg_hist_catch[length][index] 144 | else: 145 | #generate a new neg hist 146 | neg_item_hist = [] 147 | neg_cate_hist = [] 148 | neg_shop_hist = [] 149 | neg_node_hist = [] 150 | neg_product_hist = [] 151 | neg_brand_hist = [] 152 | for i in range(length): 153 | item_idx = int(random.random()*self.num_items) 154 | neg_item = self.all_items[item_idx] 155 | 156 | tmp = self.item_info[neg_item] 157 | neg_cate_hist.append(self.map_cate(tmp[0])) 158 | neg_shop_hist.append(self.map_shop(tmp[1])) 159 | neg_node_hist.append(self.map_node(tmp[2])) 160 | neg_product_hist.append(self.map_product(tmp[3])) 161 | neg_brand_hist.append(self.map_brand(tmp[4])) 162 | neg_item_hist.append(self.map_item(neg_item))#map origin item to item_id 163 | self.neg_hist_catch[length] = self.neg_hist_catch.get(length, []) # do not understand TODO 164 | self.neg_hist_catch[length].append([neg_item_hist, neg_cate_hist, neg_shop_hist,\ 165 | neg_node_hist, neg_product_hist, neg_brand_hist]) 166 | return [neg_item_hist, neg_cate_hist, neg_shop_hist, neg_node_hist, neg_product_hist, neg_brand_hist] 167 | 168 | def gen_neg_hist2(self, length): 169 | if len(self.neg_hist_catch.get(length, [1])) == self.max_catch_num: 170 | index = int(random.random()*self.max_catch_num) 171 | return self.neg_hist_catch[length][index] 172 | else: 173 | #generate a new neg hist 174 | neg_item_hist = [] 175 | neg_cate_hist = [] 176 | neg_shop_hist = [] 177 | neg_node_hist = [] 178 | neg_product_hist = [] 179 | neg_brand_hist = 
[] 180 | for i in range(length): 181 | neg_item = self.all_items[int(random.random()*self.num_items)] # pick one random item per step (the original indexed a Python list with a list of indices, a TypeError) 182 | neg_info = self.item_info[neg_item] 183 | neg_cate = self.map_cate(neg_info[0]) 184 | neg_cate_hist.append(neg_cate) 185 | neg_shop = self.map_shop(neg_info[1]) 186 | neg_shop_hist.append(neg_shop) 187 | neg_node = self.map_node(neg_info[2]) 188 | neg_node_hist.append(neg_node) 189 | neg_product = self.map_product(neg_info[3]) 190 | neg_product_hist.append(neg_product) 191 | neg_brand = self.map_brand(neg_info[4]) 192 | neg_brand_hist.append(neg_brand) 193 | neg_item_id = self.map_item(neg_item)#map original item to item_id 194 | neg_item_hist.append(neg_item_id) 195 | self.neg_hist_catch[length] = self.neg_hist_catch.get(length, []) # ensure a cache list exists for this length (acts like setdefault) 196 | self.neg_hist_catch[length].append([neg_item_hist, neg_cate_hist, neg_shop_hist,\ 197 | neg_node_hist, neg_product_hist, neg_brand_hist]) 198 | return [neg_item_hist, neg_cate_hist, neg_shop_hist, neg_node_hist, neg_product_hist, neg_brand_hist] 199 | 200 | def fill_ndarray(self, hist): 201 | nd_his = np.ones(self.maxlen) * -1 202 | nd_his[:len(hist)] = hist 203 | return nd_his 204 | 205 | # 2nd method: use the previous user's behavior as the current user's negative behavior data 206 | def process2(self, data_files, file_name): 207 | #try: 208 | # ss = self.source_buffer.pop() 209 | #except IndexError: 210 | 211 | 212 | source = [] 213 | target = [] 214 | hist_item_list = [] 215 | hist_cate_list = [] 216 | hist_shop_list = [] 217 | hist_node_list = [] 218 | hist_product_list = [] 219 | hist_brand_list = [] 220 | 221 | neg_hist_item_list = [] 222 | neg_hist_cate_list = [] 223 | neg_hist_shop_list = [] 224 | neg_hist_node_list = [] 225 | neg_hist_product_list = [] 226 | neg_hist_brand_list = [] 227 | 228 | last_neg_hist = self.gen_neg_hist3(self.maxlen) 229 | 230 | st = time.time() 231 | #data_size = 20000 232 | #data_size = 500000 #around 9GB data for maxlength=100 233 | 
data_size = 50000 #around 9GB data for maxlength=1000 234 | line_count = 0 235 | file_count = 0 236 | for file_read in data_files: 237 | data_file = self.data_path + file_read 238 | print('Open file ', data_file) 239 | f = open(data_file, 'r') 240 | tmp_s = f.readline() 241 | while(tmp_s): 242 | line_count = line_count + 1 243 | 244 | ss = tmp_s.strip().split("^H") 245 | uid = self.map_user(ss[0]) 246 | hist_item = list(map(self.map_item, ss[1].split('\003'))) # list() so the history supports indexing/slicing under Python 3 247 | hist_cate = list(map(self.map_cate, ss[2].split('\003'))) 248 | hist_shop = list(map(self.map_shop, ss[3].split('\003'))) 249 | hist_node = list(map(self.map_node, ss[4].split('\003'))) 250 | hist_product = list(map(self.map_product, ss[5].split('\003'))) 251 | hist_brand = list(map(self.map_brand, ss[6].split('\003'))) 252 | 253 | pos_item = hist_item[-1] 254 | pos_cate = hist_cate[-1] 255 | pos_shop = hist_shop[-1] 256 | pos_node = hist_node[-1] 257 | pos_product = hist_product[-1] 258 | pos_brand = hist_brand[-1] 259 | if self.neg_sample == 'LastInstance': 260 | #set item of last instance as neg sample 261 | neg_item, neg_cate, neg_shop, neg_node, neg_product, neg_brand \ 262 | = self.last_item, self.last_cate, self.last_shop, self.last_node, \ 263 | self.last_product, self.last_brand 264 | # neg is the last user's behavior which is not random 265 | random_neg_hist = last_neg_hist 266 | elif self.neg_sample == 'Random': 267 | #generate random neg_item information for neg sample 268 | item_idx = int(random.random()*self.num_items) 269 | while self.map_item(self.all_items[item_idx]) in hist_item: 270 | item_idx = int(random.random()*self.num_items) 271 | neg_item = self.all_items[item_idx] 272 | neg_cate = self.map_cate(self.item_info[neg_item][0]) 273 | neg_shop = self.map_shop(self.item_info[neg_item][1]) 274 | neg_node = self.map_node(self.item_info[neg_item][2]) 275 | neg_product = self.map_product(self.item_info[neg_item][3]) 276 | neg_brand = self.map_brand(self.item_info[neg_item][4]) 277 | neg_item = 
self.map_item(self.all_items[item_idx])#map original item to item_id 278 | #gen neg hist (restored so random_neg_hist is defined on the Random path) 279 | lengthx = len(hist_item[-self.maxlen:]) 280 | random_neg_hist = self.gen_neg_hist3(lengthx) 281 | 282 | 283 | #add positive sample 284 | source.append([uid, pos_item, pos_cate, pos_shop, pos_node, pos_product, pos_brand]) 285 | target.append([1, 0]) 286 | hist_item_list.append(hist_item[-(self.maxlen+1):-1]) 287 | hist_cate_list.append(hist_cate[-(self.maxlen+1):-1]) 288 | hist_shop_list.append(hist_shop[-(self.maxlen+1):-1]) 289 | hist_node_list.append(hist_node[-(self.maxlen+1):-1]) 290 | hist_product_list.append(hist_product[-(self.maxlen+1):-1]) 291 | hist_brand_list.append(hist_brand[-(self.maxlen+1):-1]) 292 | 293 | neg_hist_item_list.append(random_neg_hist[0]) 294 | neg_hist_cate_list.append(random_neg_hist[1]) 295 | neg_hist_shop_list.append(random_neg_hist[2]) 296 | neg_hist_node_list.append(random_neg_hist[3]) 297 | neg_hist_product_list.append(random_neg_hist[4]) 298 | neg_hist_brand_list.append(random_neg_hist[5]) 299 | 300 | #add negative sample, the history is the same! 
301 | source.append([uid, neg_item, neg_cate, neg_shop, neg_node, neg_product, neg_brand]) 302 | target.append([0, 1]) 303 | hist_item_list.append(hist_item[-(self.maxlen+1):-1]) 304 | hist_cate_list.append(hist_cate[-(self.maxlen+1):-1]) 305 | hist_shop_list.append(hist_shop[-(self.maxlen+1):-1]) 306 | hist_node_list.append(hist_node[-(self.maxlen+1):-1]) 307 | hist_product_list.append(hist_product[-(self.maxlen+1):-1]) 308 | hist_brand_list.append(hist_brand[-(self.maxlen+1):-1]) 309 | 310 | neg_hist_item_list.append(random_neg_hist[0]) 311 | neg_hist_cate_list.append(random_neg_hist[1]) 312 | neg_hist_shop_list.append(random_neg_hist[2]) 313 | neg_hist_node_list.append(random_neg_hist[3]) 314 | neg_hist_product_list.append(random_neg_hist[4]) 315 | neg_hist_brand_list.append(random_neg_hist[5]) 316 | if self.neg_sample == 'LastInstance': 317 | self.last_item, self.last_cate, self.last_shop, self.last_node, self.last_product, self.last_brand = pos_item, pos_cate, pos_shop, pos_node, pos_product, pos_brand 318 | last_neg_hist = [hist_item[-(self.maxlen+1):-1], 319 | hist_cate[-(self.maxlen+1):-1], 320 | hist_shop[-(self.maxlen+1):-1], 321 | hist_node[-(self.maxlen+1):-1], 322 | hist_product[-(self.maxlen+1):-1], 323 | hist_brand[-(self.maxlen+1):-1] ] 324 | 325 | # read in next line 326 | tmp_s = f.readline() 327 | 328 | if(line_count % 10000 == 0): 329 | print("Total processed lines = ", line_count, ", spent time = ", time.time() - st) 330 | #print("source =",source_array) 331 | #print("neg history item =",neg_history_item_array) 332 | #break 333 | if(line_count == data_size ): 334 | print("Start to save: total processed lines = ", line_count, ", spent time = ", time.time() - st) 335 | line_count = 0 336 | source_array = np.array(source) 337 | target_array = np.array(target) 338 | history_item_array = np.array(hist_item_list) 339 | history_cate_array = np.array(hist_cate_list) 340 | history_shop_array = np.array(hist_shop_list) 341 | history_node_array = 
np.array(hist_node_list) 342 | history_product_array = np.array(hist_product_list) 343 | history_brand_array = np.array(hist_brand_list) 344 | # 345 | print("test1", len(hist_item_list)) 346 | neg_history_item_array = np.array(neg_hist_item_list) 347 | neg_history_cate_array = np.array(neg_hist_cate_list) 348 | neg_history_shop_array = np.array(neg_hist_shop_list) 349 | neg_history_node_array = np.array(neg_hist_node_list) 350 | neg_history_product_array = np.array(neg_hist_product_list) 351 | neg_history_brand_array = np.array(neg_hist_brand_list) 352 | print("test2", len(neg_hist_item_list)) 353 | print("Finished array transform") 354 | np.savez('./' + file_name + '_' + str(file_count) + '_processed.npz', source_array=source_array,target_array=target_array, 355 | history_item_array=history_item_array,history_cate_array=history_cate_array, 356 | history_shop_array=history_shop_array,history_node_array=history_node_array, 357 | history_product_array=history_product_array,history_brand_array=history_brand_array, 358 | neg_history_item_array=neg_history_item_array,neg_history_cate_array=neg_history_cate_array, 359 | neg_history_shop_array=neg_history_shop_array,neg_history_node_array=neg_history_node_array, 360 | neg_history_product_array=neg_history_product_array,neg_history_brand_array=neg_history_brand_array) 361 | file_count = file_count + 1 362 | #===== empty the var 363 | source[:] = [] 364 | target[:] = [] 365 | hist_item_list[:] = [] 366 | hist_cate_list[:] = [] 367 | hist_shop_list[:] = [] 368 | hist_node_list[:] = [] 369 | hist_product_list[:] = [] 370 | hist_brand_list[:] = [] 371 | 372 | neg_hist_item_list[:] = [] 373 | neg_hist_cate_list[:] = [] 374 | neg_hist_shop_list[:] = [] 375 | neg_hist_node_list[:] = [] 376 | neg_hist_product_list[:] = [] 377 | neg_hist_brand_list[:] = [] 378 | # for residual lines 379 | 380 | if(line_count != 0 ): 381 | print("Residue lines:total processed lines = ", line_count,",spent time = ",time.time() -st) 382 | line_count = 
0 383 | source_array = np.array(source) 384 | target_array = np.array(target) 385 | history_item_array = np.array(hist_item_list) 386 | history_cate_array = np.array(hist_cate_list) 387 | history_shop_array = np.array(hist_shop_list) 388 | history_node_array = np.array(hist_node_list) 389 | history_product_array = np.array(hist_product_list) 390 | history_brand_array = np.array(hist_brand_list) 391 | neg_history_item_array = np.array(neg_hist_item_list) 392 | neg_history_cate_array = np.array(neg_hist_cate_list) 393 | neg_history_shop_array = np.array(neg_hist_shop_list) 394 | neg_history_node_array = np.array(neg_hist_node_list) 395 | neg_history_product_array = np.array(neg_hist_product_list) 396 | neg_history_brand_array = np.array(neg_hist_brand_list) 397 | print("Finished array transform") 398 | np.savez('./' + file_name + '_' + str(file_count) + '_processed.npz', source_array=source_array,target_array=target_array, 399 | history_item_array=history_item_array,history_cate_array=history_cate_array, 400 | history_shop_array=history_shop_array,history_node_array=history_node_array, 401 | history_product_array=history_product_array,history_brand_array=history_brand_array, 402 | neg_history_item_array=neg_history_item_array,neg_history_cate_array=neg_history_cate_array, 403 | neg_history_shop_array=neg_history_shop_array,neg_history_node_array=neg_history_node_array, 404 | neg_history_product_array=neg_history_product_array,neg_history_brand_array=neg_history_brand_array) 405 | file_count = file_count + 1 406 | #===== empty the var 407 | source[:] = [] 408 | target[:] = [] 409 | hist_item_list[:] = [] 410 | hist_cate_list[:] = [] 411 | hist_shop_list[:] = [] 412 | hist_node_list[:] = [] 413 | hist_product_list[:] = [] 414 | hist_brand_list[:] = [] 415 | 416 | neg_hist_item_list[:] = [] 417 | neg_hist_cate_list[:] = [] 418 | neg_hist_shop_list[:] = [] 419 | neg_hist_node_list[:] = [] 420 | neg_hist_product_list[:] = [] 421 | neg_hist_brand_list[:] = [] 422 | 423 | # 
Process the data by reading the user information from the raw_path file 424 | def process(self, data_files, file_name): 425 | #try: 426 | # ss = self.source_buffer.pop() 427 | #except IndexError: 428 | 429 | 430 | source = [] 431 | target = [] 432 | hist_item_list = [] 433 | hist_cate_list = [] 434 | hist_shop_list = [] 435 | hist_node_list = [] 436 | hist_product_list = [] 437 | hist_brand_list = [] 438 | 439 | neg_hist_item_list = [] 440 | neg_hist_cate_list = [] 441 | neg_hist_shop_list = [] 442 | neg_hist_node_list = [] 443 | neg_hist_product_list = [] 444 | neg_hist_brand_list = [] 445 | 446 | st = time.time() 447 | data_size = 500000 #around 9GB data for maxlength=100 448 | #data_size = 250000 #around 9GB data for maxlength=1500 449 | line_count = 0 450 | file_count = 0 451 | for file_read in data_files: 452 | data_file = self.data_path + file_read 453 | print('Open file ', data_file) 454 | f = open(data_file, 'r') 455 | tmp_s = f.readline() 456 | while(tmp_s): 457 | line_count = line_count + 1 458 | 459 | ss = tmp_s.strip().split("^H") 460 | uid = self.map_user(ss[0]) 461 | hist_item = list(map(self.map_item, ss[1].split('\003'))) # list() so the history supports indexing/slicing under Python 3 462 | hist_cate = list(map(self.map_cate, ss[2].split('\003'))) 463 | hist_shop = list(map(self.map_shop, ss[3].split('\003'))) 464 | hist_node = list(map(self.map_node, ss[4].split('\003'))) 465 | hist_product = list(map(self.map_product, ss[5].split('\003'))) 466 | hist_brand = list(map(self.map_brand, ss[6].split('\003'))) 467 | 468 | pos_item = hist_item[-1] 469 | pos_cate = hist_cate[-1] 470 | pos_shop = hist_shop[-1] 471 | pos_node = hist_node[-1] 472 | pos_product = hist_product[-1] 473 | pos_brand = hist_brand[-1] 474 | if self.neg_sample == 'LastInstance': 475 | #set item of last instance as neg sample 476 | neg_item, neg_cate, neg_shop, neg_node, neg_product, neg_brand \ 477 | = self.last_item, self.last_cate, self.last_shop, self.last_node, \ 478 | self.last_product, self.last_brand 479 | elif self.neg_sample == 'Random': 480 | #generate random neg_item 
information for neg sample 481 | item_idx = int(random.random()*self.num_items) 482 | while self.map_item(self.all_items[item_idx]) in hist_item: 483 | item_idx = int(random.random()*self.num_items) 484 | neg_item = self.all_items[item_idx] 485 | neg_cate = self.map_cate(self.item_info[neg_item][0]) 486 | neg_shop = self.map_shop(self.item_info[neg_item][1]) 487 | neg_node = self.map_node(self.item_info[neg_item][2]) 488 | neg_product = self.map_product(self.item_info[neg_item][3]) 489 | neg_brand = self.map_brand(self.item_info[neg_item][4]) 490 | neg_item = self.map_item(self.all_items[item_idx])#map original item to item_id 491 | #gen neg hist 492 | lengthx = len(hist_item[-self.maxlen:]) 493 | random_neg_hist = self.gen_neg_hist3(lengthx) 494 | 495 | #add positive sample 496 | source.append([uid, pos_item, pos_cate, pos_shop, pos_node, pos_product, pos_brand]) 497 | target.append([1, 0]) 498 | hist_item_list.append(hist_item[-(self.maxlen+1):-1]) 499 | hist_cate_list.append(hist_cate[-(self.maxlen+1):-1]) 500 | hist_shop_list.append(hist_shop[-(self.maxlen+1):-1]) 501 | hist_node_list.append(hist_node[-(self.maxlen+1):-1]) 502 | hist_product_list.append(hist_product[-(self.maxlen+1):-1]) 503 | hist_brand_list.append(hist_brand[-(self.maxlen+1):-1]) 504 | 505 | neg_hist_item_list.append(random_neg_hist[0]) 506 | neg_hist_cate_list.append(random_neg_hist[1]) 507 | neg_hist_shop_list.append(random_neg_hist[2]) 508 | neg_hist_node_list.append(random_neg_hist[3]) 509 | neg_hist_product_list.append(random_neg_hist[4]) 510 | neg_hist_brand_list.append(random_neg_hist[5]) 511 | 512 | #add negative sample, the history is the same! 
513 | source.append([uid, neg_item, neg_cate, neg_shop, neg_node, neg_product, neg_brand]) 514 | target.append([0, 1]) 515 | hist_item_list.append(hist_item[-(self.maxlen+1):-1]) 516 | hist_cate_list.append(hist_cate[-(self.maxlen+1):-1]) 517 | hist_shop_list.append(hist_shop[-(self.maxlen+1):-1]) 518 | hist_node_list.append(hist_node[-(self.maxlen+1):-1]) 519 | hist_product_list.append(hist_product[-(self.maxlen+1):-1]) 520 | hist_brand_list.append(hist_brand[-(self.maxlen+1):-1]) 521 | 522 | neg_hist_item_list.append(random_neg_hist[0]) 523 | neg_hist_cate_list.append(random_neg_hist[1]) 524 | neg_hist_shop_list.append(random_neg_hist[2]) 525 | neg_hist_node_list.append(random_neg_hist[3]) 526 | neg_hist_product_list.append(random_neg_hist[4]) 527 | neg_hist_brand_list.append(random_neg_hist[5]) 528 | if self.neg_sample == 'LastInstance': 529 | self.last_item, self.last_cate, self.last_shop, self.last_node, self.last_product, self.last_brand = pos_item, pos_cate, pos_shop, pos_node, pos_product, pos_brand 530 | # separate the source and target into arrays 531 | #uid_array = np.array(source)[:,0] 532 | #item_array = np.array(source)[:,1] 533 | #cate_array = np.array(source)[:,2] 534 | #shop_array = np.array(source)[:,3] 535 | #node_array = np.array(source)[:,4] 536 | #product_array = np.array(source)[:,5] 537 | #brand_array = np.array(source)[:,6] 538 | 539 | 540 | # read in next line 541 | tmp_s = f.readline() 542 | 543 | if(line_count % 10000 == 0): 544 | print("Total processed lines = ", line_count, ", spent time = ", time.time() - st) 545 | if(line_count == data_size): 546 | line_count = 0 547 | #print("Total processed lines = ", line_count,",spent time = ",time.time() -st) 548 | source_array = np.array(source) 549 | target_array = np.array(target) 550 | history_item_array = np.array(hist_item_list) 551 | history_cate_array = np.array(hist_cate_list) 552 | history_shop_array = np.array(hist_shop_list) 553 | history_node_array = np.array(hist_node_list) 554 | 
history_product_array = np.array(hist_product_list) 555 | history_brand_array = np.array(hist_brand_list) 556 | neg_history_item_array = np.array(neg_hist_item_list) 557 | neg_history_cate_array = np.array(neg_hist_cate_list) 558 | neg_history_shop_array = np.array(neg_hist_shop_list) 559 | neg_history_node_array = np.array(neg_hist_node_list) 560 | neg_history_product_array = np.array(neg_hist_product_list) 561 | neg_history_brand_array = np.array(neg_hist_brand_list) 562 | print("Finished array transform") 563 | np.savez('./' + file_name + '_' + str(file_count) + '_processed.npz', source_array=source_array,target_array=target_array, 564 | history_item_array=history_item_array,history_cate_array=history_cate_array, 565 | history_shop_array=history_shop_array,history_node_array=history_node_array, 566 | history_product_array=history_product_array,history_brand_array=history_brand_array, 567 | neg_history_item_array=neg_history_item_array,neg_history_cate_array=neg_history_cate_array, 568 | neg_history_shop_array=neg_history_shop_array,neg_history_node_array=neg_history_node_array, 569 | neg_history_product_array=neg_history_product_array,neg_history_brand_array=neg_history_brand_array) 570 | file_count = file_count + 1 571 | #===== empty the var 572 | source[:] = [] 573 | target[:] = [] 574 | hist_item_list[:] = [] 575 | hist_cate_list[:] = [] 576 | hist_shop_list[:] = [] 577 | hist_node_list[:] = [] 578 | hist_product_list[:] = [] 579 | hist_brand_list[:] = [] 580 | 581 | neg_hist_item_list[:] = [] 582 | neg_hist_cate_list[:] = [] 583 | neg_hist_shop_list[:] = [] 584 | neg_hist_node_list[:] = [] 585 | neg_hist_product_list[:] = [] 586 | neg_hist_brand_list[:] = [] 587 | # for residual lines 588 | 589 | if(line_count != 0 ): 590 | line_count = 0 591 | #print("Total processed lines = ", line_count,",spent time = ",time.time() -st) 592 | source_array = np.array(source) 593 | target_array = np.array(target) 594 | history_item_array = np.array(hist_item_list) 595 | 
history_cate_array = np.array(hist_cate_list) 596 | history_shop_array = np.array(hist_shop_list) 597 | history_node_array = np.array(hist_node_list) 598 | history_product_array = np.array(hist_product_list) 599 | history_brand_array = np.array(hist_brand_list) 600 | neg_history_item_array = np.array(neg_hist_item_list) 601 | neg_history_cate_array = np.array(neg_hist_cate_list) 602 | neg_history_shop_array = np.array(neg_hist_shop_list) 603 | neg_history_node_array = np.array(neg_hist_node_list) 604 | neg_history_product_array = np.array(neg_hist_product_list) 605 | neg_history_brand_array = np.array(neg_hist_brand_list) 606 | print("Finished array transform") 607 | np.savez('./' + file_name + '_' + str(file_count) + '_processed.npz', source_array=source_array,target_array=target_array, 608 | history_item_array=history_item_array,history_cate_array=history_cate_array, 609 | history_shop_array=history_shop_array,history_node_array=history_node_array, 610 | history_product_array=history_product_array,history_brand_array=history_brand_array, 611 | neg_history_item_array=neg_history_item_array,neg_history_cate_array=neg_history_cate_array, 612 | neg_history_shop_array=neg_history_shop_array,neg_history_node_array=neg_history_node_array, 613 | neg_history_product_array=neg_history_product_array,neg_history_brand_array=neg_history_brand_array) 614 | file_count = file_count + 1 615 | #===== empty the var 616 | source[:] = [] 617 | target[:] = [] 618 | hist_item_list[:] = [] 619 | hist_cate_list[:] = [] 620 | hist_shop_list[:] = [] 621 | hist_node_list[:] = [] 622 | hist_product_list[:] = [] 623 | hist_brand_list[:] = [] 624 | 625 | neg_hist_item_list[:] = [] 626 | neg_hist_cate_list[:] = [] 627 | neg_hist_shop_list[:] = [] 628 | neg_hist_node_list[:] = [] 629 | neg_hist_product_list[:] = [] 630 | neg_hist_brand_list[:] = [] 631 | 632 | 633 | def loadDataset( 634 | raw_path 635 | ): 636 | # dataset 637 | #output_filename = "dienDataset_processed" 638 | 639 | # read in the 
dataset 640 | #pos = raw_path.rfind('/') 641 | #data_path = raw_path[:pos] 642 | #data_file = raw_path[pos:] 643 | #npzfile = "." + data_file + "_processed.npz" 644 | 645 | data_path = raw_path 646 | max_length = 1000 647 | # If already processed, just load it 648 | file = preprocessDataset(data_path, max_length) 649 | 650 | #train_data_file = [ 'sample_00','sample_01', 'sample_02', 'sample_03'] 651 | #train_data_file = [ 'sample_03'] 652 | #file.process2( train_data_file, 'train_sample') 653 | 654 | test_data_file = ['test_sample_00', 'test_sample_01', 'test_sample_02', 'test_sample_03'] 655 | file.process2( test_data_file, 'test_sample') 656 | 657 | 658 | if __name__ == "__main__": 659 | ### import packages ### 660 | import argparse 661 | ### parse arguments ### 662 | parser = argparse.ArgumentParser( 663 | description="Preprocess Alibaba dataset" 664 | ) 665 | # model related parameters 666 | parser.add_argument("--raw-data-file", type=str, default="") 667 | args = parser.parse_args() 668 | print("Start process ", args.raw_data_file) 669 | # control randomness 670 | random.seed(0) 671 | 672 | loadDataset( 673 | args.raw_data_file 674 | ) 675 | -------------------------------------------------------------------------------- /DIEN/model_taobao_allfea.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import tensorflow as tf 3 | from utils import * 4 | from tensorflow.python.ops.rnn_cell import GRUCell 5 | from rnn import dynamic_rnn 6 | # import mann_simple_cell as mann_cell 7 | class Model(object): 8 | def __init__(self, uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN, use_negsample=False, Flag="DNN"): 9 | self.model_flag = Flag 10 | self.use_negsample= use_negsample 11 | def get_embeddings_variable(var_name, embedding_shape): 12 | # workaround to return vector of 0 13 | embeddings = tf.get_variable(var_name, embedding_shape, 
trainable=True) 14 | embeddings = tf.concat([ embeddings, [tf.constant([0.] * embedding_shape[1])] ], axis = 0) 15 | return embeddings 16 | 17 | with tf.name_scope('Inputs'): 18 | self.item_id_his_batch_ph = tf.placeholder(tf.int32, [None, None], name='item_id_his_batch_ph') 19 | self.cate_his_batch_ph = tf.placeholder(tf.int32, [None, None], name='cate_his_batch_ph') 20 | self.shop_his_batch_ph = tf.placeholder(tf.int32, [None, None], name='shop_his_batch_ph') 21 | self.node_his_batch_ph = tf.placeholder(tf.int32, [None, None], name='node_his_batch_ph') 22 | self.product_his_batch_ph = tf.placeholder(tf.int32, [None, None], name='product_his_batch_ph') 23 | self.brand_his_batch_ph = tf.placeholder(tf.int32, [None, None], name='brand_his_batch_ph') 24 | self.uid_batch_ph = tf.placeholder(tf.int32, [None, ], name='uid_batch_ph') 25 | self.item_id_batch_ph = tf.placeholder(tf.int32, [None, ], name='item_id_batch_ph') 26 | self.cate_id_batch_ph = tf.placeholder(tf.int32, [None, ], name='cate_id_batch_ph') 27 | self.shop_id_batch_ph = tf.placeholder(tf.int32, [None, ], name='shop_id_batch_ph') 28 | self.node_id_batch_ph = tf.placeholder(tf.int32, [None, ], name='node_id_batch_ph') 29 | self.product_id_batch_ph = tf.placeholder(tf.int32, [None, ], name='product_id_batch_ph') 30 | self.brand_id_batch_ph = tf.placeholder(tf.int32, [None, ], name='brand_id_batch_ph') 31 | self.mask = tf.placeholder(tf.float32, [None, None], name='mask_batch_ph') 32 | self.target_ph = tf.placeholder(tf.float32, [None, 2], name='target_ph') 33 | self.lr = tf.placeholder(tf.float64, []) 34 | 35 | # Embedding layer 36 | with tf.name_scope('Embedding_layer'): 37 | 38 | #self.item_id_embeddings_var = tf.get_variable("item_id_embedding_var", [item_n, EMBEDDING_DIM], trainable=True) 39 | self.item_id_embeddings_var = get_embeddings_variable("item_id_embedding_var", [item_n, EMBEDDING_DIM]) 40 | self.item_id_batch_embedded = tf.nn.embedding_lookup(self.item_id_embeddings_var, self.item_id_batch_ph) 
41 | self.item_id_his_batch_embedded = tf.nn.embedding_lookup(self.item_id_embeddings_var, self.item_id_his_batch_ph) 42 | 43 | self.cate_id_embeddings_var = tf.get_variable("cate_id_embedding_var", [cate_n, EMBEDDING_DIM], trainable=True) 44 | self.cate_id_batch_embedded = tf.nn.embedding_lookup(self.cate_id_embeddings_var, self.cate_id_batch_ph) 45 | self.cate_his_batch_embedded = tf.nn.embedding_lookup(self.cate_id_embeddings_var, self.cate_his_batch_ph) 46 | 47 | #self.shop_id_embeddings_var = tf.get_variable("shop_id_embedding_var", [shop_n, EMBEDDING_DIM], trainable=True) 48 | self.shop_id_embeddings_var = get_embeddings_variable("shop_id_embedding_var", [shop_n, EMBEDDING_DIM]) 49 | self.shop_id_batch_embedded = tf.nn.embedding_lookup(self.shop_id_embeddings_var, self.shop_id_batch_ph) 50 | self.shop_his_batch_embedded = tf.nn.embedding_lookup(self.shop_id_embeddings_var, self.shop_his_batch_ph) 51 | 52 | self.node_id_embeddings_var = tf.get_variable("node_id_embedding_var", [node_n, EMBEDDING_DIM], trainable=True) 53 | self.node_id_batch_embedded = tf.nn.embedding_lookup(self.node_id_embeddings_var, self.node_id_batch_ph) 54 | self.node_his_batch_embedded = tf.nn.embedding_lookup(self.node_id_embeddings_var, self.node_his_batch_ph) 55 | 56 | self.product_id_embeddings_var = tf.get_variable("product_id_embedding_var", [product_n, EMBEDDING_DIM], trainable=True) 57 | self.product_id_batch_embedded = tf.nn.embedding_lookup(self.product_id_embeddings_var, self.product_id_batch_ph) 58 | self.product_his_batch_embedded = tf.nn.embedding_lookup(self.product_id_embeddings_var, self.product_his_batch_ph) 59 | 60 | #self.brand_id_embeddings_var = tf.get_variable("brand_id_embedding_var", [brand_n, EMBEDDING_DIM], trainable=True) 61 | self.brand_id_embeddings_var = get_embeddings_variable("brand_id_embedding_var", [brand_n, EMBEDDING_DIM]) 62 | self.brand_id_batch_embedded = tf.nn.embedding_lookup(self.brand_id_embeddings_var, self.brand_id_batch_ph) 63 | 
self.brand_his_batch_embedded = tf.nn.embedding_lookup(self.brand_id_embeddings_var, self.brand_his_batch_ph) 64 | 65 | if self.use_negsample: 66 | self.item_id_neg_batch_ph = tf.placeholder(tf.int32, [None, None], name='neg_his_batch_ph') 67 | self.cate_neg_batch_ph = tf.placeholder(tf.int32, [None, None], name='neg_cate_his_batch_ph') 68 | self.shop_neg_batch_ph = tf.placeholder(tf.int32, [None, None], name='neg_shop_his_batch_ph') 69 | self.node_neg_batch_ph = tf.placeholder(tf.int32, [None, None], name='neg_node_his_batch_ph') 70 | self.product_neg_batch_ph = tf.placeholder(tf.int32, [None, None], name='neg_product_his_batch_ph') 71 | self.brand_neg_batch_ph = tf.placeholder(tf.int32, [None, None], name='neg_brand_his_batch_ph') 72 | self.neg_item_his_eb = tf.nn.embedding_lookup(self.item_id_embeddings_var, self.item_id_neg_batch_ph) 73 | self.neg_cate_his_eb = tf.nn.embedding_lookup(self.cate_id_embeddings_var, self.cate_neg_batch_ph) 74 | self.neg_shop_his_eb = tf.nn.embedding_lookup(self.shop_id_embeddings_var, self.shop_neg_batch_ph) 75 | self.neg_node_his_eb = tf.nn.embedding_lookup(self.node_id_embeddings_var, self.node_neg_batch_ph) 76 | self.neg_product_his_eb = tf.nn.embedding_lookup(self.product_id_embeddings_var, self.product_neg_batch_ph) 77 | self.neg_brand_his_eb = tf.nn.embedding_lookup(self.brand_id_embeddings_var, self.brand_neg_batch_ph) 78 | self.neg_his_eb = tf.concat([self.neg_item_his_eb,self.neg_cate_his_eb, self.neg_shop_his_eb, self.neg_node_his_eb, self.neg_product_his_eb, self.neg_brand_his_eb], axis=2) * tf.reshape(self.mask,(BATCH_SIZE, SEQ_LEN, 1)) 79 | 80 | self.item_eb = tf.concat([self.item_id_batch_embedded, self.cate_id_batch_embedded, self.shop_id_batch_embedded, self.node_id_batch_embedded, self.product_id_batch_embedded, self.brand_id_batch_embedded], axis=1) 81 | self.item_his_eb = tf.concat([self.item_id_his_batch_embedded,self.cate_his_batch_embedded, self.shop_his_batch_embedded, self.node_his_batch_embedded, 
self.product_his_batch_embedded, self.brand_his_batch_embedded], axis=2) * tf.reshape(self.mask,(BATCH_SIZE, SEQ_LEN, 1)) 82 | #debug if last item of history is leaked 83 | #self.item_his_eb = self.item_his_eb[:,:-1,:] 84 | self.item_his_eb_sum = tf.reduce_sum(self.item_his_eb, 1) 85 | 86 | def build_fcn_net(self, inp, use_dice = False): 87 | bn1 = tf.layers.batch_normalization(inputs=inp, name='bn1') 88 | dnn1 = tf.layers.dense(bn1, 200, activation=None, name='f1') 89 | if use_dice: 90 | dnn1 = dice(dnn1, name='dice_1') 91 | else: 92 | dnn1 = prelu(dnn1, scope='prelu_1') 93 | 94 | dnn2 = tf.layers.dense(dnn1, 80, activation=None, name='f2') 95 | if use_dice: 96 | dnn2 = dice(dnn2, name='dice_2') 97 | else: 98 | dnn2 = prelu(dnn2, scope='prelu_2') 99 | dnn3 = tf.layers.dense(dnn2, 2, activation=None, name='f3') 100 | self.y_hat = tf.nn.softmax(dnn3) + 0.00000001 101 | 102 | with tf.name_scope('Metrics'): 103 | # Cross-entropy loss and optimizer initialization 104 | ctr_loss = - tf.reduce_mean(tf.log(self.y_hat) * self.target_ph) 105 | self.loss = ctr_loss 106 | if self.use_negsample: 107 | self.loss += self.aux_loss 108 | tf.summary.scalar('loss', self.loss) 109 | self.optimizer = tf.train.AdamOptimizer(learning_rate=self.lr).minimize(self.loss) 110 | # Accuracy metric 111 | self.accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.round(self.y_hat), self.target_ph), tf.float32)) 112 | tf.summary.scalar('accuracy', self.accuracy) 113 | 114 | self.merged = tf.summary.merge_all() 115 | 116 | def auxiliary_loss(self, h_states, click_seq, noclick_seq, mask = None, stag = None): 117 | #mask = tf.cast(mask, tf.float32) 118 | click_input_ = tf.concat([h_states, click_seq], -1) 119 | noclick_input_ = tf.concat([h_states, noclick_seq], -1) 120 | click_prop_ = self.auxiliary_net(click_input_, stag = stag)[:, :, 0] 121 | noclick_prop_ = self.auxiliary_net(noclick_input_, stag = stag)[:, :, 0] 122 | 123 | click_loss_ = - tf.reshape(tf.log(click_prop_), [-1, tf.shape(click_seq)[1]]) * 
mask 124 | noclick_loss_ = - tf.reshape(tf.log(1.0 - noclick_prop_), [-1, tf.shape(noclick_seq)[1]]) * mask 125 | 126 | loss_ = tf.reduce_mean(click_loss_ + noclick_loss_) 127 | return loss_ 128 | 129 | def auxiliary_net(self, in_, stag='auxiliary_net'): 130 | bn1 = tf.layers.batch_normalization(inputs=in_, name='bn1' + stag, reuse=tf.AUTO_REUSE) 131 | dnn1 = tf.layers.dense(bn1, 100, activation=None, name='f1' + stag, reuse=tf.AUTO_REUSE) 132 | dnn1 = tf.nn.sigmoid(dnn1) 133 | dnn2 = tf.layers.dense(dnn1, 50, activation=None, name='f2' + stag, reuse=tf.AUTO_REUSE) 134 | dnn2 = tf.nn.sigmoid(dnn2) 135 | dnn3 = tf.layers.dense(dnn2, 2, activation=None, name='f3' + stag, reuse=tf.AUTO_REUSE) 136 | y_hat = tf.nn.softmax(dnn3) + 0.000001 137 | return y_hat 138 | 139 | 140 | def train(self, sess, inps): 141 | if self.use_negsample: 142 | loss, aux_loss, accuracy, _ = sess.run([self.loss, self.aux_loss, self.accuracy, self.optimizer], feed_dict={ 143 | self.uid_batch_ph: inps[0], 144 | self.item_id_batch_ph: inps[1], 145 | self.cate_id_batch_ph: inps[2], 146 | self.shop_id_batch_ph: inps[3], 147 | self.node_id_batch_ph: inps[4], 148 | self.product_id_batch_ph: inps[5], 149 | self.brand_id_batch_ph: inps[6], 150 | self.item_id_his_batch_ph: inps[7], 151 | self.cate_his_batch_ph: inps[8], 152 | self.shop_his_batch_ph: inps[9], 153 | self.node_his_batch_ph: inps[10], 154 | self.product_his_batch_ph: inps[11], 155 | self.brand_his_batch_ph: inps[12], 156 | self.item_id_neg_batch_ph: inps[13], 157 | self.cate_neg_batch_ph: inps[14], 158 | self.shop_neg_batch_ph: inps[15], 159 | self.node_neg_batch_ph: inps[16], 160 | self.product_neg_batch_ph: inps[17], 161 | self.brand_neg_batch_ph: inps[18], 162 | self.mask: inps[19], 163 | self.target_ph: inps[20], 164 | self.lr: inps[21] 165 | }) 166 | else: 167 | loss, aux_loss, accuracy, _ = sess.run([self.loss, self.loss, self.accuracy, self.optimizer], feed_dict={ 168 | self.uid_batch_ph: inps[0], 169 | self.item_id_batch_ph: inps[1], 
170 | self.cate_id_batch_ph: inps[2], 171 | self.shop_id_batch_ph: inps[3], 172 | self.node_id_batch_ph: inps[4], 173 | self.product_id_batch_ph: inps[5], 174 | self.brand_id_batch_ph: inps[6], 175 | self.item_id_his_batch_ph: inps[7], 176 | self.cate_his_batch_ph: inps[8], 177 | self.shop_his_batch_ph: inps[9], 178 | self.node_his_batch_ph: inps[10], 179 | self.product_his_batch_ph: inps[11], 180 | self.brand_his_batch_ph: inps[12], 181 | # self.item_id_neg_batch_ph: inps[13], 182 | # self.cate_neg_batch_ph: inps[14], 183 | # self.shop_neg_batch_ph: inps[15], 184 | # self.node_neg_batch_ph: inps[16], 185 | # self.product_neg_batch_ph: inps[17], 186 | # self.brand_neg_batch_ph: inps[18], 187 | self.mask: inps[19], 188 | self.target_ph: inps[20], 189 | self.lr: inps[21] 190 | }) 191 | 192 | return loss, accuracy, aux_loss, 0, 0 193 | 194 | def calculate(self, sess, inps): 195 | if self.use_negsample: 196 | probs, loss, aux_loss, accuracy = sess.run([self.y_hat, self.loss, self.aux_loss, self.accuracy], feed_dict={ 197 | self.uid_batch_ph: inps[0], 198 | self.item_id_batch_ph: inps[1], 199 | self.cate_id_batch_ph: inps[2], 200 | self.shop_id_batch_ph: inps[3], 201 | self.node_id_batch_ph: inps[4], 202 | self.product_id_batch_ph: inps[5], 203 | self.brand_id_batch_ph: inps[6], 204 | self.item_id_his_batch_ph: inps[7], 205 | self.cate_his_batch_ph: inps[8], 206 | self.shop_his_batch_ph: inps[9], 207 | self.node_his_batch_ph: inps[10], 208 | self.product_his_batch_ph: inps[11], 209 | self.brand_his_batch_ph: inps[12], 210 | self.item_id_neg_batch_ph: inps[13], 211 | self.cate_neg_batch_ph: inps[14], 212 | self.shop_neg_batch_ph: inps[15], 213 | self.node_neg_batch_ph: inps[16], 214 | self.product_neg_batch_ph: inps[17], 215 | self.brand_neg_batch_ph: inps[18], 216 | self.mask: inps[19], 217 | self.target_ph: inps[20], 218 | }) 219 | else: 220 | probs, loss, aux_loss, accuracy = sess.run([self.y_hat, self.loss, self.loss, self.accuracy], feed_dict={ 221 | 
self.uid_batch_ph: inps[0], 222 | self.item_id_batch_ph: inps[1], 223 | self.cate_id_batch_ph: inps[2], 224 | self.shop_id_batch_ph: inps[3], 225 | self.node_id_batch_ph: inps[4], 226 | self.product_id_batch_ph: inps[5], 227 | self.brand_id_batch_ph: inps[6], 228 | self.item_id_his_batch_ph: inps[7], 229 | self.cate_his_batch_ph: inps[8], 230 | self.shop_his_batch_ph: inps[9], 231 | self.node_his_batch_ph: inps[10], 232 | self.product_his_batch_ph: inps[11], 233 | self.brand_his_batch_ph: inps[12], 234 | # self.item_id_neg_batch_ph: inps[13], 235 | # self.cate_neg_batch_ph: inps[14], 236 | # self.shop_neg_batch_ph: inps[15], 237 | # self.node_neg_batch_ph: inps[16], 238 | # self.product_neg_batch_ph: inps[17], 239 | # self.brand_neg_batch_ph: inps[18], 240 | self.mask: inps[19], 241 | self.target_ph: inps[20], 242 | }) 243 | return probs, loss, accuracy, aux_loss 244 | 245 | def save(self, sess, path): 246 | saver = tf.train.Saver() 247 | saver.save(sess, save_path=path) 248 | 249 | def restore(self, sess, path): 250 | saver = tf.train.Saver() 251 | saver.restore(sess, save_path=path) 252 | print('model restored from %s' % path) 253 | 254 | class Model_DNN(Model): 255 | def __init__(self,uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN=256): 256 | super(Model_DNN, self).__init__(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, 257 | BATCH_SIZE, SEQ_LEN, Flag="DNN") 258 | 259 | inp = tf.concat([self.item_eb, self.item_his_eb_sum], 1) 260 | self.build_fcn_net(inp, use_dice=False) 261 | 262 | 263 | class Model_PNN(Model): 264 | def __init__(self,uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN=256): 265 | super(Model_PNN, self).__init__(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, 266 | BATCH_SIZE, SEQ_LEN, Flag="PNN") 267 | 
268 | inp = tf.concat([self.item_eb, self.item_his_eb_sum, self.item_eb * self.item_his_eb_sum], 1) 269 | self.build_fcn_net(inp, use_dice=False) 270 | 271 | 272 | class Model_GRU4REC(Model): 273 | def __init__(self,uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN=256): 274 | super(Model_GRU4REC, self).__init__(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, 275 | BATCH_SIZE, SEQ_LEN, Flag="GRU4REC") 276 | with tf.name_scope('rnn_1'): 277 | self.sequence_length = tf.Variable([SEQ_LEN] * BATCH_SIZE) 278 | rnn_outputs, final_state1 = dynamic_rnn(GRUCell(HIDDEN_SIZE), inputs=self.item_his_eb, 279 | sequence_length=self.sequence_length, dtype=tf.float32, 280 | scope="gru1") 281 | tf.summary.histogram('GRU_outputs', rnn_outputs) 282 | 283 | inp = tf.concat([self.item_eb, self.item_his_eb_sum, final_state1], 1) 284 | self.build_fcn_net(inp, use_dice=False) 285 | 286 | 287 | class Model_DIN(Model): 288 | def __init__(self,uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN=256): 289 | super(Model_DIN, self).__init__(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, 290 | BATCH_SIZE, SEQ_LEN, Flag="DIN") 291 | with tf.name_scope('Attention_layer'): 292 | attention_output = din_attention(self.item_eb, self.item_his_eb, HIDDEN_SIZE, self.mask) 293 | att_fea = tf.reduce_sum(attention_output, 1) 294 | tf.summary.histogram('att_fea', att_fea) 295 | inp = tf.concat([self.item_eb, self.item_his_eb_sum, att_fea], -1) 296 | self.build_fcn_net(inp, use_dice=False) 297 | 298 | 299 | class Model_ARNN(Model): 300 | def __init__(self,uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN=256): 301 | super(Model_ARNN, self).__init__(uid_n, item_n, cate_n, shop_n, node_n, product_n, 
brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, 302 | BATCH_SIZE, SEQ_LEN, Flag="ARNN") 303 | with tf.name_scope('rnn_1'): 304 | self.sequence_length = tf.Variable([SEQ_LEN] * BATCH_SIZE) 305 | rnn_outputs, final_state1 = dynamic_rnn(GRUCell(HIDDEN_SIZE), inputs=self.item_his_eb, 306 | sequence_length=self.sequence_length, dtype=tf.float32, 307 | scope="gru1") 308 | tf.summary.histogram('GRU_outputs', rnn_outputs) 309 | # Attention layer 310 | with tf.name_scope('Attention_layer_1'): 311 | att_gru = din_attention(self.item_eb, rnn_outputs, HIDDEN_SIZE, self.mask) 312 | #att_gru = din_attention(self.item_eb, rnn_outputs, HIDDEN_SIZE, None) 313 | att_gru = tf.reduce_sum(att_gru, 1) 314 | #att_hist = din_attention(self.item_eb, self.item_his_eb, HIDDEN_SIZE, None, stag="att") 315 | #att_hist = tf.reduce_sum(att_hist, 1) 316 | 317 | inp = tf.concat([self.item_eb, self.item_his_eb_sum, final_state1, att_gru], -1) 318 | self.build_fcn_net(inp, use_dice=False) 319 | 320 | class Model_DIEN(Model): 321 | def __init__(self, uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN=400, use_negsample=False): 322 | super(Model_DIEN, self).__init__(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, 323 | BATCH_SIZE, SEQ_LEN, use_negsample, Flag="DIEN") 324 | 325 | with tf.name_scope('rnn_1'): 326 | self.sequence_length = tf.Variable([SEQ_LEN] * BATCH_SIZE) 327 | rnn_outputs, _ = dynamic_rnn(GRUCell(HIDDEN_SIZE), inputs=self.item_his_eb, 328 | sequence_length=self.sequence_length, dtype=tf.float32, 329 | scope="gru1") 330 | tf.summary.histogram('GRU_outputs', rnn_outputs) 331 | 332 | if use_negsample: 333 | aux_loss_1 = self.auxiliary_loss(rnn_outputs[:, :-1, :], self.item_his_eb[:, 1:, :], 334 | self.neg_his_eb[:, 1:, :], self.mask[:, 1:], stag = "bigru_0") 335 | self.aux_loss = aux_loss_1 336 | 337 | # Attention layer 338 | with tf.name_scope('Attention_layer_1'): 
339 | #att_outputs, alphas = din_fcn_attention(self.item_eb, rnn_outputs, HIDDEN_SIZE, self.mask, softmax_stag=1, stag='1_1', mode="LIST", return_alphas=True) 340 | att_outputs, alphas = din_attention(self.item_eb, rnn_outputs, HIDDEN_SIZE, mask=self.mask, mode="LIST", return_alphas=True) 341 | tf.summary.histogram('alpha_outputs', alphas) 342 | 343 | with tf.name_scope('rnn_2'): 344 | rnn_outputs2, final_state2 = dynamic_rnn(VecAttGRUCell(HIDDEN_SIZE), inputs=rnn_outputs, 345 | att_scores = tf.expand_dims(alphas, -1), 346 | sequence_length=self.sequence_length, dtype=tf.float32, 347 | scope="gru2") 348 | tf.summary.histogram('GRU2_Final_State', final_state2) 349 | 350 | inp = tf.concat([self.item_eb, final_state2, self.item_his_eb_sum], 1) 351 | self.build_fcn_net(inp, use_dice=False) 352 | -------------------------------------------------------------------------------- /DIEN/plot/plot1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/bigcomputing/4e308d8338ea9c4cd0fe8c430c7ac21e33d94373/DIEN/plot/plot1.png -------------------------------------------------------------------------------- /DIEN/rnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
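The `auxiliary_loss` / `auxiliary_net` pair in the model file above scores each GRU hidden state against the user's true next behavior (positive) and a negatively sampled item, with the losses masked to the valid sequence positions before averaging. A minimal NumPy sketch of that loss computation, assuming the concatenation with the hidden state and the auxiliary net have already been collapsed into given click probabilities (all array values here are made up for illustration):

```python
import numpy as np

def masked_aux_loss(click_prob, noclick_prob, mask, eps=1e-6):
    """Mean of -log p for true next items and -log(1 - p) for sampled
    negatives, with masked-out positions contributing zero (mirrors the
    mask-multiply followed by reduce_mean in auxiliary_loss)."""
    click_loss = -np.log(click_prob + eps) * mask
    noclick_loss = -np.log(1.0 - noclick_prob + eps) * mask
    return np.mean(click_loss + noclick_loss)

# batch of 2 sequences, 3 auxiliary positions each; the second sequence
# only has 2 valid steps, so its last position is masked out
click_prob = np.array([[0.9, 0.8, 0.7], [0.6, 0.5, 0.5]])
noclick_prob = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.5]])
mask = np.array([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])
loss = masked_aux_loss(click_prob, noclick_prob, mask)
```

Note that, as in the TF code, the mean is taken over all positions including masked ones; the mask only zeroes their contribution.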
14 | # ============================================================================== 15 | 16 | """RNN helpers for TensorFlow models. 17 | 18 | 19 | @@bidirectional_dynamic_rnn 20 | @@dynamic_rnn 21 | @@raw_rnn 22 | @@static_rnn 23 | @@static_state_saving_rnn 24 | @@static_bidirectional_rnn 25 | """ 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | from tensorflow.python.framework import constant_op 31 | from tensorflow.python.framework import dtypes 32 | from tensorflow.python.framework import ops 33 | from tensorflow.python.framework import tensor_shape 34 | from tensorflow.python.ops import array_ops 35 | from tensorflow.python.ops import control_flow_ops 36 | from tensorflow.python.ops import math_ops 37 | from tensorflow.python.ops import rnn_cell_impl 38 | from tensorflow.python.ops import tensor_array_ops 39 | from tensorflow.python.ops import variable_scope as vs 40 | from tensorflow.python.util import nest 41 | 42 | 43 | # pylint: disable=protected-access 44 | _concat = rnn_cell_impl._concat 45 | assert_like_rnncell = rnn_cell_impl.assert_like_rnncell 46 | # pylint: enable=protected-access 47 | 48 | 49 | def _transpose_batch_time(x): 50 | """Transpose the batch and time dimensions of a Tensor. 51 | 52 | Retains as much of the static shape information as possible. 53 | 54 | Args: 55 | x: A tensor of rank 2 or higher. 56 | 57 | Returns: 58 | x transposed along the first two dimensions. 59 | 60 | Raises: 61 | ValueError: if `x` is rank 1 or lower. 
62 | """ 63 | x_static_shape = x.get_shape() 64 | if x_static_shape.ndims is not None and x_static_shape.ndims < 2: 65 | raise ValueError( 66 | "Expected input tensor %s to have rank at least 2, but saw shape: %s" % 67 | (x, x_static_shape)) 68 | x_rank = array_ops.rank(x) 69 | x_t = array_ops.transpose( 70 | x, array_ops.concat( 71 | ([1, 0], math_ops.range(2, x_rank)), axis=0)) 72 | x_t.set_shape( 73 | tensor_shape.TensorShape([ 74 | x_static_shape[1].value, x_static_shape[0].value 75 | ]).concatenate(x_static_shape[2:])) 76 | return x_t 77 | 78 | 79 | def _best_effort_input_batch_size(flat_input): 80 | """Get static input batch size if available, with fallback to the dynamic one. 81 | 82 | Args: 83 | flat_input: An iterable of time major input Tensors of shape [max_time, 84 | batch_size, ...]. All inputs should have compatible batch sizes. 85 | 86 | Returns: 87 | The batch size in Python integer if available, or a scalar Tensor otherwise. 88 | 89 | Raises: 90 | ValueError: if there is any input with an invalid shape. 91 | """ 92 | for input_ in flat_input: 93 | shape = input_.shape 94 | if shape.ndims is None: 95 | continue 96 | if shape.ndims < 2: 97 | raise ValueError( 98 | "Expected input tensor %s to have rank at least 2" % input_) 99 | batch_size = shape[1].value 100 | if batch_size is not None: 101 | return batch_size 102 | # Fallback to the dynamic batch size of the first input. 103 | return array_ops.shape(flat_input[0])[1] 104 | 105 | 106 | def _infer_state_dtype(explicit_dtype, state): 107 | """Infer the dtype of an RNN state. 108 | 109 | Args: 110 | explicit_dtype: explicitly declared dtype or None. 111 | state: RNN's hidden state. Must be a Tensor or a nested iterable containing 112 | Tensors. 113 | 114 | Returns: 115 | dtype: inferred dtype of hidden state. 116 | 117 | Raises: 118 | ValueError: if `state` has heterogeneous dtypes or is empty. 
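`_transpose_batch_time` above swaps the leading batch and time axes while retaining any trailing feature axes. The same shape manipulation can be sketched in NumPy (toy shapes, illustrative only):

```python
import numpy as np

def transpose_batch_time(x):
    """Swap the first two axes of an array of rank >= 2, leaving any
    trailing feature axes in place (NumPy analogue of
    _transpose_batch_time)."""
    if x.ndim < 2:
        raise ValueError("expected rank >= 2, saw rank %d" % x.ndim)
    return np.transpose(x, (1, 0) + tuple(range(2, x.ndim)))

x = np.zeros((4, 7, 3))            # [batch, time, depth]
x_t = transpose_batch_time(x)      # shape (7, 4, 3): [time, batch, depth]
```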
119 | """ 120 | if explicit_dtype is not None: 121 | return explicit_dtype 122 | elif nest.is_sequence(state): 123 | inferred_dtypes = [element.dtype for element in nest.flatten(state)] 124 | if not inferred_dtypes: 125 | raise ValueError("Unable to infer dtype from empty state.") 126 | all_same = all([x == inferred_dtypes[0] for x in inferred_dtypes]) 127 | if not all_same: 128 | raise ValueError( 129 | "State has tensors of different inferred_dtypes. Unable to infer a " 130 | "single representative dtype.") 131 | return inferred_dtypes[0] 132 | else: 133 | return state.dtype 134 | 135 | 136 | # pylint: disable=unused-argument 137 | def _rnn_step( 138 | time, sequence_length, min_sequence_length, max_sequence_length, 139 | zero_output, state, call_cell, state_size, skip_conditionals=False): 140 | """Calculate one step of a dynamic RNN minibatch. 141 | 142 | Returns an (output, state) pair conditioned on the sequence_lengths. 143 | When skip_conditionals=False, the pseudocode is something like: 144 | 145 | if t >= max_sequence_length: 146 | return (zero_output, state) 147 | if t < min_sequence_length: 148 | return call_cell() 149 | 150 | # Selectively output zeros or output, old state or new state depending 151 | # on if we've finished calculating each row. 
152 | new_output, new_state = call_cell() 153 | final_output = np.vstack([ 154 | zero_output if time >= sequence_lengths[r] else new_output_r 155 | for r, new_output_r in enumerate(new_output) 156 | ]) 157 | final_state = np.vstack([ 158 | state[r] if time >= sequence_lengths[r] else new_state_r 159 | for r, new_state_r in enumerate(new_state) 160 | ]) 161 | return (final_output, final_state) 162 | 163 | Args: 164 | time: Python int, the current time step 165 | sequence_length: int32 `Tensor` vector of size [batch_size] 166 | min_sequence_length: int32 `Tensor` scalar, min of sequence_length 167 | max_sequence_length: int32 `Tensor` scalar, max of sequence_length 168 | zero_output: `Tensor` vector of shape [output_size] 169 | state: Either a single `Tensor` matrix of shape `[batch_size, state_size]`, 170 | or a list/tuple of such tensors. 171 | call_cell: lambda returning tuple of (new_output, new_state) where 172 | new_output is a `Tensor` matrix of shape `[batch_size, output_size]`. 173 | new_state is a `Tensor` matrix of shape `[batch_size, state_size]`. 174 | state_size: The `cell.state_size` associated with the state. 175 | skip_conditionals: Python bool, whether to skip using the conditional 176 | calculations. This is useful for `dynamic_rnn`, where the input tensor 177 | matches `max_sequence_length`, and using conditionals just slows 178 | everything down. 179 | 180 | Returns: 181 | A tuple of (`final_output`, `final_state`) as given by the pseudocode above: 182 | final_output is a `Tensor` matrix of shape [batch_size, output_size] 183 | final_state is either a single `Tensor` matrix, or a tuple of such 184 | matrices (matching length and shapes of input `state`). 185 | 186 | Raises: 187 | ValueError: If the cell returns a state tuple whose length does not match 188 | that returned by `state_size`. 
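The pseudocode in the `_rnn_step` docstring above selects, per batch row, between the freshly computed output/state and the carried-through zero output / old state, depending on whether that row's sequence has already ended. A NumPy sketch of that per-row copy-through (`np.where` standing in for `array_ops.where`; shapes are illustrative):

```python
import numpy as np

def copy_one_through(time, sequence_length, old, new):
    """Rows whose sequence has ended (time >= length) keep `old`
    (the zero output, or the previous state); still-active rows take
    the freshly computed `new` value."""
    finished = time >= sequence_length            # [batch] bool
    return np.where(finished[:, None], old, new)  # broadcast over features

seq_len = np.array([3, 1])        # per-row sequence lengths
zero_out = np.zeros((2, 2))       # stand-in for zero_output
new_out = np.ones((2, 2))         # stand-in for call_cell() output
out_t1 = copy_one_through(1, seq_len, zero_out, new_out)
# at t=1, row 0 is still active and emits new_out; row 1 has finished
```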
189 | """ 190 | 191 | # Convert state to a list for ease of use 192 | flat_state = nest.flatten(state) 193 | flat_zero_output = nest.flatten(zero_output) 194 | 195 | def _copy_one_through(output, new_output): 196 | # If the state contains a scalar value we simply pass it through. 197 | if output.shape.ndims == 0: 198 | return new_output 199 | copy_cond = (time >= sequence_length) 200 | with ops.colocate_with(new_output): 201 | return array_ops.where(copy_cond, output, new_output) 202 | 203 | def _copy_some_through(flat_new_output, flat_new_state): 204 | # Use broadcasting select to determine which values should get 205 | # the previous state & zero output, and which values should get 206 | # a calculated state & output. 207 | flat_new_output = [ 208 | _copy_one_through(zero_output, new_output) 209 | for zero_output, new_output in zip(flat_zero_output, flat_new_output)] 210 | flat_new_state = [ 211 | _copy_one_through(state, new_state) 212 | for state, new_state in zip(flat_state, flat_new_state)] 213 | return flat_new_output + flat_new_state 214 | 215 | def _maybe_copy_some_through(): 216 | """Run RNN step. Pass through either no or some past state.""" 217 | new_output, new_state = call_cell() 218 | 219 | nest.assert_same_structure(state, new_state) 220 | 221 | flat_new_state = nest.flatten(new_state) 222 | flat_new_output = nest.flatten(new_output) 223 | return control_flow_ops.cond( 224 | # if t < min_seq_len: calculate and return everything 225 | time < min_sequence_length, lambda: flat_new_output + flat_new_state, 226 | # else copy some of it through 227 | lambda: _copy_some_through(flat_new_output, flat_new_state)) 228 | 229 | # TODO(ebrevdo): skipping these conditionals may cause a slowdown, 230 | # but benefits from removing cond() and its gradient. We should 231 | # profile with and without this switch here. 232 | if skip_conditionals: 233 | # Instead of using conditionals, perform the selective copy at all time 234 | # steps. 
This is faster when max_seq_len is equal to the number of unrolls 235 | # (which is typical for dynamic_rnn). 236 | new_output, new_state = call_cell() 237 | nest.assert_same_structure(state, new_state) 238 | new_state = nest.flatten(new_state) 239 | new_output = nest.flatten(new_output) 240 | final_output_and_state = _copy_some_through(new_output, new_state) 241 | else: 242 | empty_update = lambda: flat_zero_output + flat_state 243 | final_output_and_state = control_flow_ops.cond( 244 | # if t >= max_seq_len: copy all state through, output zeros 245 | time >= max_sequence_length, empty_update, 246 | # otherwise calculation is required: copy some or all of it through 247 | _maybe_copy_some_through) 248 | 249 | if len(final_output_and_state) != len(flat_zero_output) + len(flat_state): 250 | raise ValueError("Internal error: state and output were not concatenated " 251 | "correctly.") 252 | final_output = final_output_and_state[:len(flat_zero_output)] 253 | final_state = final_output_and_state[len(flat_zero_output):] 254 | 255 | for output, flat_output in zip(final_output, flat_zero_output): 256 | output.set_shape(flat_output.get_shape()) 257 | for substate, flat_substate in zip(final_state, flat_state): 258 | substate.set_shape(flat_substate.get_shape()) 259 | 260 | final_output = nest.pack_sequence_as( 261 | structure=zero_output, flat_sequence=final_output) 262 | final_state = nest.pack_sequence_as( 263 | structure=state, flat_sequence=final_state) 264 | 265 | return final_output, final_state 266 | 267 | 268 | def _reverse_seq(input_seq, lengths): 269 | """Reverse a list of Tensors up to specified lengths. 270 | 271 | Args: 272 | input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features) 273 | or nested tuples of tensors. 274 | lengths: A `Tensor` of dimension batch_size, containing lengths for each 275 | sequence in the batch. If "None" is specified, simply reverses 276 | the list. 
277 | 278 | Returns: 279 | time-reversed sequence 280 | """ 281 | if lengths is None: 282 | return list(reversed(input_seq)) 283 | 284 | flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq) 285 | 286 | flat_results = [[] for _ in range(len(input_seq))] 287 | for sequence in zip(*flat_input_seq): 288 | input_shape = tensor_shape.unknown_shape( 289 | ndims=sequence[0].get_shape().ndims) 290 | for input_ in sequence: 291 | input_shape.merge_with(input_.get_shape()) 292 | input_.set_shape(input_shape) 293 | 294 | # Join into (time, batch_size, depth) 295 | s_joined = array_ops.stack(sequence) 296 | 297 | # Reverse along dimension 0 298 | s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1) 299 | # Split again into list 300 | result = array_ops.unstack(s_reversed) 301 | for r, flat_result in zip(result, flat_results): 302 | r.set_shape(input_shape) 303 | flat_result.append(r) 304 | 305 | results = [nest.pack_sequence_as(structure=input_, flat_sequence=flat_result) 306 | for input_, flat_result in zip(input_seq, flat_results)] 307 | return results 308 | 309 | 310 | def bidirectional_dynamic_rnn(cell_fw, cell_bw, inputs, sequence_length=None, 311 | initial_state_fw=None, initial_state_bw=None, 312 | dtype=None, parallel_iterations=None, 313 | swap_memory=False, time_major=False, scope=None): 314 | """Creates a dynamic version of bidirectional recurrent neural network. 315 | 316 | Takes input and builds independent forward and backward RNNs. The input_size 317 | of forward and backward cell must match. The initial state for both directions 318 | is zero by default (but can be set optionally) and no intermediate states are 319 | ever returned -- the network is fully unrolled for the given (passed in) 320 | length(s) of the sequence(s) or completely unrolled if length(s) is not 321 | given. 322 | 323 | Args: 324 | cell_fw: An instance of RNNCell, to be used for forward direction. 
325 | cell_bw: An instance of RNNCell, to be used for backward direction. 326 | inputs: The RNN inputs. 327 | If time_major == False (default), this must be a tensor of shape: 328 | `[batch_size, max_time, ...]`, or a nested tuple of such elements. 329 | If time_major == True, this must be a tensor of shape: 330 | `[max_time, batch_size, ...]`, or a nested tuple of such elements. 331 | sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 332 | containing the actual lengths for each of the sequences in the batch. 333 | If not provided, all batch entries are assumed to be full sequences; and 334 | time reversal is applied from time `0` to `max_time` for each sequence. 335 | initial_state_fw: (optional) An initial state for the forward RNN. 336 | This must be a tensor of appropriate type and shape 337 | `[batch_size, cell_fw.state_size]`. 338 | If `cell_fw.state_size` is a tuple, this should be a tuple of 339 | tensors having shapes `[batch_size, s] for s in cell_fw.state_size`. 340 | initial_state_bw: (optional) Same as for `initial_state_fw`, but using 341 | the corresponding properties of `cell_bw`. 342 | dtype: (optional) The data type for the initial states and expected output. 343 | Required if initial_states are not provided or RNN states have a 344 | heterogeneous dtype. 345 | parallel_iterations: (Default: 32). The number of iterations to run in 346 | parallel. Those operations which do not have any temporal dependency 347 | and can be run in parallel, will be. This parameter trades off 348 | time for space. Values >> 1 use more memory but take less time, 349 | while smaller values use less memory but computations take longer. 350 | swap_memory: Transparently swap the tensors produced in forward inference 351 | but needed for back prop from GPU to CPU. This allows training RNNs 352 | which would typically not fit on a single GPU, with very minimal (or no) 353 | performance penalty. 
354 | time_major: The shape format of the `inputs` and `outputs` Tensors. 355 | If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. 356 | If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 357 | Using `time_major = True` is a bit more efficient because it avoids 358 | transposes at the beginning and end of the RNN calculation. However, 359 | most TensorFlow data is batch-major, so by default this function 360 | accepts input and emits output in batch-major form. 361 | scope: VariableScope for the created subgraph; defaults to 362 | "bidirectional_rnn" 363 | 364 | Returns: 365 | A tuple (outputs, output_states) where: 366 | outputs: A tuple (output_fw, output_bw) containing the forward and 367 | the backward rnn output `Tensor`. 368 | If time_major == False (default), 369 | output_fw will be a `Tensor` shaped: 370 | `[batch_size, max_time, cell_fw.output_size]` 371 | and output_bw will be a `Tensor` shaped: 372 | `[batch_size, max_time, cell_bw.output_size]`. 373 | If time_major == True, 374 | output_fw will be a `Tensor` shaped: 375 | `[max_time, batch_size, cell_fw.output_size]` 376 | and output_bw will be a `Tensor` shaped: 377 | `[max_time, batch_size, cell_bw.output_size]`. 378 | It returns a tuple instead of a single concatenated `Tensor`, unlike 379 | in the `bidirectional_rnn`. If the concatenated one is preferred, 380 | the forward and backward outputs can be concatenated as 381 | `tf.concat(outputs, 2)`. 382 | output_states: A tuple (output_state_fw, output_state_bw) containing 383 | the forward and the backward final states of bidirectional rnn. 384 | 385 | Raises: 386 | TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 
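The backward pass of `bidirectional_dynamic_rnn` feeds the backward cell a per-row reversed copy of the inputs via `reverse_sequence`, then reverses the outputs back. The key semantics — only the first `lengths[r]` steps of each row are reversed, padding stays in place — can be sketched as (toy values, illustrative only):

```python
import numpy as np

def reverse_sequence(x, lengths):
    """Reverse each row of a [batch, time] array along the time axis,
    but only within that row's valid length; trailing padding is left
    untouched (mimics tf.reverse_sequence with seq_dim=1, batch_dim=0)."""
    out = x.copy()
    for r, n in enumerate(lengths):
        out[r, :n] = x[r, :n][::-1]
    return out

x = np.array([[1, 2, 3, 0],     # length 3, one padded step
              [4, 5, 0, 0]])    # length 2, two padded steps
rev = reverse_sequence(x, [3, 2])
# rev == [[3, 2, 1, 0], [5, 4, 0, 0]]
```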
387 | """ 388 | 389 | assert_like_rnncell(cell_fw.name, cell_fw) 390 | assert_like_rnncell(cell_bw.name, cell_bw) 391 | 392 | with vs.variable_scope(scope or "bidirectional_rnn"): 393 | # Forward direction 394 | with vs.variable_scope("fw") as fw_scope: 395 | output_fw, output_state_fw = dynamic_rnn( 396 | cell=cell_fw, inputs=inputs, sequence_length=sequence_length, 397 | initial_state=initial_state_fw, dtype=dtype, 398 | parallel_iterations=parallel_iterations, swap_memory=swap_memory, 399 | time_major=time_major, scope=fw_scope) 400 | 401 | # Backward direction 402 | if not time_major: 403 | time_dim = 1 404 | batch_dim = 0 405 | else: 406 | time_dim = 0 407 | batch_dim = 1 408 | 409 | def _reverse(input_, seq_lengths, seq_dim, batch_dim): 410 | if seq_lengths is not None: 411 | return array_ops.reverse_sequence( 412 | input=input_, seq_lengths=seq_lengths, 413 | seq_dim=seq_dim, batch_dim=batch_dim) 414 | else: 415 | return array_ops.reverse(input_, axis=[seq_dim]) 416 | 417 | with vs.variable_scope("bw") as bw_scope: 418 | inputs_reverse = _reverse( 419 | inputs, seq_lengths=sequence_length, 420 | seq_dim=time_dim, batch_dim=batch_dim) 421 | tmp, output_state_bw = dynamic_rnn( 422 | cell=cell_bw, inputs=inputs_reverse, sequence_length=sequence_length, 423 | initial_state=initial_state_bw, dtype=dtype, 424 | parallel_iterations=parallel_iterations, swap_memory=swap_memory, 425 | time_major=time_major, scope=bw_scope) 426 | 427 | output_bw = _reverse( 428 | tmp, seq_lengths=sequence_length, 429 | seq_dim=time_dim, batch_dim=batch_dim) 430 | 431 | outputs = (output_fw, output_bw) 432 | output_states = (output_state_fw, output_state_bw) 433 | 434 | return (outputs, output_states) 435 | 436 | 437 | def dynamic_rnn(cell, inputs, att_scores=None, sequence_length=None, initial_state=None, 438 | dtype=None, parallel_iterations=None, swap_memory=False, 439 | time_major=False, scope=None): 440 | """Creates a recurrent neural network specified by RNNCell `cell`. 
441 | 442 | Performs fully dynamic unrolling of `inputs`. 443 | 444 | Example: 445 | 446 | ```python 447 | # create a BasicRNNCell 448 | rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) 449 | 450 | # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size] 451 | 452 | # defining initial state 453 | initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32) 454 | 455 | # 'state' is a tensor of shape [batch_size, cell_state_size] 456 | outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data, 457 | initial_state=initial_state, 458 | dtype=tf.float32) 459 | ``` 460 | 461 | ```python 462 | # create 2 LSTMCells 463 | rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]] 464 | 465 | # create a RNN cell composed sequentially of a number of RNNCells 466 | multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers) 467 | 468 | # 'outputs' is a tensor of shape [batch_size, max_time, 256] 469 | # 'state' is a N-tuple where N is the number of LSTMCells containing a 470 | # tf.contrib.rnn.LSTMStateTuple for each cell 471 | outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell, 472 | inputs=data, 473 | dtype=tf.float32) 474 | ``` 475 | 476 | 477 | Args: 478 | cell: An instance of RNNCell. 479 | inputs: The RNN inputs. 480 | If `time_major == False` (default), this must be a `Tensor` of shape: 481 | `[batch_size, max_time, ...]`, or a nested tuple of such 482 | elements. 483 | If `time_major == True`, this must be a `Tensor` of shape: 484 | `[max_time, batch_size, ...]`, or a nested tuple of such 485 | elements. 486 | This may also be a (possibly nested) tuple of Tensors satisfying 487 | this property. The first two dimensions must match across all the inputs, 488 | but otherwise the ranks and other shape components may differ. 489 | In this case, input to `cell` at each time-step will replicate the 490 | structure of these tuples, except for the time dimension (from which the 491 | time is taken). 
492 | The input to `cell` at each time step will be a `Tensor` or (possibly 493 | nested) tuple of Tensors each with dimensions `[batch_size, ...]`. 494 | sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. 495 | Used to copy-through state and zero-out outputs when past a batch 496 | element's sequence length. So it's more for correctness than performance. 497 | initial_state: (optional) An initial state for the RNN. 498 | If `cell.state_size` is an integer, this must be 499 | a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 500 | If `cell.state_size` is a tuple, this should be a tuple of 501 | tensors having shapes `[batch_size, s] for s in cell.state_size`. 502 | dtype: (optional) The data type for the initial state and expected output. 503 | Required if initial_state is not provided or RNN state has a heterogeneous 504 | dtype. 505 | parallel_iterations: (Default: 32). The number of iterations to run in 506 | parallel. Those operations which do not have any temporal dependency 507 | and can be run in parallel, will be. This parameter trades off 508 | time for space. Values >> 1 use more memory but take less time, 509 | while smaller values use less memory but computations take longer. 510 | swap_memory: Transparently swap the tensors produced in forward inference 511 | but needed for back prop from GPU to CPU. This allows training RNNs 512 | which would typically not fit on a single GPU, with very minimal (or no) 513 | performance penalty. 514 | time_major: The shape format of the `inputs` and `outputs` Tensors. 515 | If true, these `Tensors` must be shaped `[max_time, batch_size, depth]`. 516 | If false, these `Tensors` must be shaped `[batch_size, max_time, depth]`. 517 | Using `time_major = True` is a bit more efficient because it avoids 518 | transposes at the beginning and end of the RNN calculation. 
However, 519 | most TensorFlow data is batch-major, so by default this function 520 | accepts input and emits output in batch-major form. 521 | scope: VariableScope for the created subgraph; defaults to "rnn". 522 | 523 | Returns: 524 | A pair (outputs, state) where: 525 | 526 | outputs: The RNN output `Tensor`. 527 | 528 | If time_major == False (default), this will be a `Tensor` shaped: 529 | `[batch_size, max_time, cell.output_size]`. 530 | 531 | If time_major == True, this will be a `Tensor` shaped: 532 | `[max_time, batch_size, cell.output_size]`. 533 | 534 | Note, if `cell.output_size` is a (possibly nested) tuple of integers 535 | or `TensorShape` objects, then `outputs` will be a tuple having the 536 | same structure as `cell.output_size`, containing Tensors having shapes 537 | corresponding to the shape data in `cell.output_size`. 538 | 539 | state: The final state. If `cell.state_size` is an int, this 540 | will be shaped `[batch_size, cell.state_size]`. If it is a 541 | `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 542 | If it is a (possibly nested) tuple of ints or `TensorShape`, this will 543 | be a tuple having the corresponding shapes. If cells are `LSTMCells` 544 | `state` will be a tuple containing a `LSTMStateTuple` for each cell. 545 | 546 | Raises: 547 | TypeError: If `cell` is not an instance of RNNCell. 548 | ValueError: If inputs is None or an empty list. 
549 | """ 550 | assert_like_rnncell(cell.name, cell) 551 | 552 | # By default, time_major==False and inputs are batch-major: shaped 553 | # [batch, time, depth] 554 | # For internal calculations, we transpose to [time, batch, depth] 555 | flat_input = nest.flatten(inputs) 556 | 557 | if not time_major: 558 | # (B,T,D) => (T,B,D) 559 | flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input] 560 | flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input) 561 | 562 | parallel_iterations = parallel_iterations or 32 563 | if sequence_length is not None: 564 | sequence_length = math_ops.to_int32(sequence_length) 565 | if sequence_length.get_shape().ndims not in (None, 1): 566 | raise ValueError( 567 | "sequence_length must be a vector of length batch_size, " 568 | "but saw shape: %s" % sequence_length.get_shape()) 569 | sequence_length = array_ops.identity( # Just to find it in the graph. 570 | sequence_length, name="sequence_length") 571 | 572 | # Create a new scope in which the caching device is either 573 | # determined by the parent scope, or is set to place the cached 574 | # Variable using the same placement as for the rest of the RNN. 
575 | with vs.variable_scope(scope or "rnn") as varscope: 576 | if varscope.caching_device is None: 577 | varscope.set_caching_device(lambda op: op.device) 578 | batch_size = _best_effort_input_batch_size(flat_input) 579 | 580 | if initial_state is not None: 581 | state = initial_state 582 | else: 583 | if not dtype: 584 | raise ValueError("If there is no initial_state, you must give a dtype.") 585 | state = cell.zero_state(batch_size, dtype) 586 | 587 | def _assert_has_shape(x, shape): 588 | x_shape = array_ops.shape(x) 589 | packed_shape = array_ops.stack(shape) 590 | return control_flow_ops.Assert( 591 | math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), 592 | ["Expected shape for Tensor %s is " % x.name, 593 | packed_shape, " but saw shape: ", x_shape]) 594 | 595 | if sequence_length is not None: 596 | # Perform some shape validation 597 | with ops.control_dependencies( 598 | [_assert_has_shape(sequence_length, [batch_size])]): 599 | sequence_length = array_ops.identity( 600 | sequence_length, name="CheckSeqLen") 601 | 602 | inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) 603 | 604 | (outputs, final_state) = _dynamic_rnn_loop( 605 | cell, 606 | inputs, 607 | state, 608 | parallel_iterations=parallel_iterations, 609 | swap_memory=swap_memory, 610 | att_scores = att_scores, 611 | sequence_length=sequence_length, 612 | dtype=dtype) 613 | 614 | # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth]. 615 | # If we are performing batch-major calculations, transpose output back 616 | # to shape [batch, time, depth] 617 | if not time_major: 618 | # (T,B,D) => (B,T,D) 619 | outputs = nest.map_structure(_transpose_batch_time, outputs) 620 | 621 | return (outputs, final_state) 622 | 623 | 624 | def _dynamic_rnn_loop(cell, 625 | inputs, 626 | initial_state, 627 | parallel_iterations, 628 | swap_memory, 629 | att_scores = None, 630 | sequence_length=None, 631 | dtype=None): 632 | """Internal implementation of Dynamic RNN. 
633 | 634 | Args: 635 | cell: An instance of RNNCell. 636 | inputs: A `Tensor` of shape [time, batch_size, input_size], or a nested 637 | tuple of such elements. 638 | initial_state: A `Tensor` of shape `[batch_size, state_size]`, or if 639 | `cell.state_size` is a tuple, then this should be a tuple of 640 | tensors having shapes `[batch_size, s] for s in cell.state_size`. 641 | parallel_iterations: Positive Python int. 642 | swap_memory: A Python boolean. att_scores: (optional) A 3-D `Tensor` of attention scores; the slice `att_scores[:, time, :]` is passed to the cell at each step as an extra `att_score` argument. 643 | sequence_length: (optional) An `int32` `Tensor` of shape [batch_size]. 644 | dtype: (optional) Expected dtype of output. If not specified, inferred from 645 | initial_state. 646 | 647 | Returns: 648 | Tuple `(final_outputs, final_state)`. 649 | final_outputs: 650 | A `Tensor` of shape `[time, batch_size, cell.output_size]`. If 651 | `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape` 652 | objects, then this returns a (possibly nested) tuple of Tensors matching 653 | the corresponding shapes. 654 | final_state: 655 | A `Tensor`, or possibly nested tuple of Tensors, matching in length 656 | and shapes to `initial_state`. 657 | 658 | Raises: 659 | ValueError: If the input depth cannot be inferred via shape inference 660 | from the inputs.
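The DIEN-specific `att_scores` path hands the cell one attention slice per step (`att_scores[:, time, :]` inside `_time_step` below). A pure-Python sketch of that dispatch with a toy attention-weighted running-sum cell (all names illustrative, not this module's API):

```python
def toy_att_rnn(inputs, att_scores=None):
    """inputs: time-major [time][batch] values; att_scores: [batch][time]
    weights. Toy cell: state' = state + att * x, output = state'.
    With att_scores=None this degenerates to a plain running sum,
    mirroring the two call_cell branches in _time_step."""
    time_steps, batch = len(inputs), len(inputs[0])
    state = [0.0] * batch
    outputs = []
    for t in range(time_steps):
        for b in range(batch):
            # Per-step slice, analogous to att_scores[:, time, :].
            att = att_scores[b][t] if att_scores is not None else 1.0
            state[b] = state[b] + att * inputs[t][b]
        outputs.append(list(state))
    return outputs, state

outs, final = toy_att_rnn([[1, 2], [3, 4]],
                          att_scores=[[1.0, 0.5], [1.0, 0.5]])
# outs == [[1.0, 2.0], [2.5, 4.0]], final == [2.5, 4.0]
```

The real cells (e.g. the attention GRUs in this repo) use the score to gate the state update rather than scale the input, but the per-step slicing is the same.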
661 | """ 662 | state = initial_state 663 | assert isinstance(parallel_iterations, int), "parallel_iterations must be int" 664 | 665 | state_size = cell.state_size 666 | 667 | flat_input = nest.flatten(inputs) 668 | flat_output_size = nest.flatten(cell.output_size) 669 | 670 | # Construct an initial output 671 | input_shape = array_ops.shape(flat_input[0]) 672 | time_steps = input_shape[0] 673 | batch_size = _best_effort_input_batch_size(flat_input) 674 | 675 | inputs_got_shape = tuple(input_.get_shape().with_rank_at_least(3) 676 | for input_ in flat_input) 677 | 678 | const_time_steps, const_batch_size = inputs_got_shape[0].as_list()[:2] 679 | 680 | for shape in inputs_got_shape: 681 | if not shape[2:].is_fully_defined(): 682 | raise ValueError( 683 | "Input size (depth of inputs) must be accessible via shape inference," 684 | " but saw value None.") 685 | got_time_steps = shape[0].value 686 | got_batch_size = shape[1].value 687 | if const_time_steps != got_time_steps: 688 | raise ValueError( 689 | "Time steps is not the same for all the elements in the input in a " 690 | "batch.") 691 | if const_batch_size != got_batch_size: 692 | raise ValueError( 693 | "Batch_size is not the same for all the elements in the input.") 694 | 695 | # Prepare dynamic conditional copying of state & output 696 | def _create_zero_arrays(size): 697 | size = _concat(batch_size, size) 698 | return array_ops.zeros( 699 | array_ops.stack(size), _infer_state_dtype(dtype, state)) 700 | 701 | flat_zero_output = tuple(_create_zero_arrays(output) 702 | for output in flat_output_size) 703 | zero_output = nest.pack_sequence_as(structure=cell.output_size, 704 | flat_sequence=flat_zero_output) 705 | 706 | if sequence_length is not None: 707 | min_sequence_length = math_ops.reduce_min(sequence_length) 708 | max_sequence_length = math_ops.reduce_max(sequence_length) 709 | 710 | time = array_ops.constant(0, dtype=dtypes.int32, name="time") 711 | 712 | with ops.name_scope("dynamic_rnn") as scope: 713 | 
base_name = scope 714 | 715 | def _create_ta(name, dtype): 716 | return tensor_array_ops.TensorArray(dtype=dtype, 717 | size=time_steps, 718 | tensor_array_name=base_name + name) 719 | 720 | output_ta = tuple(_create_ta("output_%d" % i, 721 | _infer_state_dtype(dtype, state)) 722 | for i in range(len(flat_output_size))) 723 | input_ta = tuple(_create_ta("input_%d" % i, flat_input[i].dtype) 724 | for i in range(len(flat_input))) 725 | 726 | input_ta = tuple(ta.unstack(input_) 727 | for ta, input_ in zip(input_ta, flat_input)) 728 | 729 | def _time_step(time, output_ta_t, state, att_scores=None): 730 | """Take a time step of the dynamic RNN. 731 | 732 | Args: 733 | time: int32 scalar Tensor. 734 | output_ta_t: List of `TensorArray`s that represent the output. 735 | state: nested tuple of vector tensors that represent the state. att_scores: (optional) 3-D attention-score `Tensor`; the slice for the current step is passed to the cell. 736 | 737 | Returns: 738 | The tuple (time + 1, output_ta_t with updated flow, new_state), plus `att_scores` (unchanged) when it is provided. 739 | """ 740 | 741 | input_t = tuple(ta.read(time) for ta in input_ta) 742 | # Restore some shape information 743 | for input_, shape in zip(input_t, inputs_got_shape): 744 | input_.set_shape(shape[1:]) 745 | 746 | input_t = nest.pack_sequence_as(structure=inputs, flat_sequence=input_t) 747 | if att_scores is not None: 748 | att_score = att_scores[:, time, :] 749 | call_cell = lambda: cell(input_t, state, att_score) 750 | else: 751 | call_cell = lambda: cell(input_t, state) 752 | 753 | if sequence_length is not None: 754 | (output, new_state) = _rnn_step( 755 | time=time, 756 | sequence_length=sequence_length, 757 | min_sequence_length=min_sequence_length, 758 | max_sequence_length=max_sequence_length, 759 | zero_output=zero_output, 760 | state=state, 761 | call_cell=call_cell, 762 | state_size=state_size, 763 | skip_conditionals=True) 764 | else: 765 | (output, new_state) = call_cell() 766 | 767 | # Pack state if using state tuples 768 | output = nest.flatten(output) 769 | 770 | output_ta_t = tuple( 771 | ta.write(time, out) for ta, out in
zip(output_ta_t, output)) 772 | if att_scores is not None: 773 | return (time + 1, output_ta_t, new_state, att_scores) 774 | else: 775 | return (time + 1, output_ta_t, new_state) 776 | 777 | if att_scores is not None: 778 | _, output_final_ta, final_state, _ = control_flow_ops.while_loop( 779 | cond=lambda time, *_: time < time_steps, 780 | body=_time_step, 781 | loop_vars=(time, output_ta, state, att_scores), 782 | parallel_iterations=parallel_iterations, 783 | swap_memory=swap_memory) 784 | else: 785 | _, output_final_ta, final_state = control_flow_ops.while_loop( 786 | cond=lambda time, *_: time < time_steps, 787 | body=_time_step, 788 | loop_vars=(time, output_ta, state), 789 | parallel_iterations=parallel_iterations, 790 | swap_memory=swap_memory) 791 | 792 | # Unpack final output if not using output tuples. 793 | final_outputs = tuple(ta.stack() for ta in output_final_ta) 794 | 795 | # Restore some shape information 796 | for output, output_size in zip(final_outputs, flat_output_size): 797 | shape = _concat( 798 | [const_time_steps, const_batch_size], output_size, static=True) 799 | output.set_shape(shape) 800 | 801 | final_outputs = nest.pack_sequence_as( 802 | structure=cell.output_size, flat_sequence=final_outputs) 803 | 804 | return (final_outputs, final_state) 805 | 806 | 807 | def raw_rnn(cell, loop_fn, 808 | parallel_iterations=None, swap_memory=False, scope=None): 809 | """Creates an `RNN` specified by RNNCell `cell` and loop function `loop_fn`. 810 | 811 | **NOTE: This method is still in testing, and the API may change.** 812 | 813 | This function is a more primitive version of `dynamic_rnn` that provides 814 | more direct access to the inputs each iteration. It also provides more 815 | control over when to start and finish reading the sequence, and 816 | what to emit for the output. 817 | 818 | For example, it can be used to implement the dynamic decoder of a seq2seq 819 | model. 
820 | 821 | Instead of working with `Tensor` objects, most operations work with 822 | `TensorArray` objects directly. 823 | 824 | The operation of `raw_rnn`, in pseudo-code, is basically the following: 825 | 826 | ```python 827 | time = tf.constant(0, dtype=tf.int32) 828 | (finished, next_input, initial_state, _, loop_state) = loop_fn( 829 | time=time, cell_output=None, cell_state=None, loop_state=None) 830 | emit_ta = TensorArray(dynamic_size=True, dtype=initial_state.dtype) 831 | state = initial_state 832 | while not all(finished): 833 | (output, cell_state) = cell(next_input, state) 834 | (next_finished, next_input, next_state, emit, loop_state) = loop_fn( 835 | time=time + 1, cell_output=output, cell_state=cell_state, 836 | loop_state=loop_state) 837 | # Emit zeros and copy forward state for minibatch entries that are finished. 838 | state = tf.where(finished, state, next_state) 839 | emit = tf.where(finished, tf.zeros_like(emit), emit) 840 | emit_ta = emit_ta.write(time, emit) 841 | # If any new minibatch entries are marked as finished, mark these. 842 | finished = tf.logical_or(finished, next_finished) 843 | time += 1 844 | return (emit_ta, state, loop_state) 845 | ``` 846 | 847 | with the additional properties that output and state may be (possibly nested) 848 | tuples, as determined by `cell.output_size` and `cell.state_size`, and 849 | as a result the final `state` and `emit_ta` may themselves be tuples. 
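The pseudo-code above runs essentially as-is in pure Python when states and emits are per-example lists instead of tensors. A minimal, self-contained emulation of the `loop_fn` protocol with a running-sum cell (all names illustrative):

```python
def toy_raw_rnn(cell, loop_fn):
    """Pure-Python analogue of the raw_rnn loop sketched above."""
    time = 0
    finished, next_input, state, _, loop_state = loop_fn(time, None, None, None)
    emits = []
    while not all(finished):
        output, cell_state = cell(next_input, state)
        next_finished, next_input, next_state, emit, loop_state = loop_fn(
            time + 1, output, cell_state, loop_state)
        # Emit zeros and copy forward state for finished minibatch entries.
        state = [s if f else ns for f, s, ns in zip(finished, state, next_state)]
        emit = [0 if f else e for f, e in zip(finished, emit)]
        emits.append(emit)
        finished = [f or nf for f, nf in zip(finished, next_finished)]
        time += 1
    return emits, state

# Running-sum cell over time-major inputs [time][batch], ragged lengths.
inputs = [[1, 10], [2, 20], [3, 30]]
lengths = [3, 1]

def cell(x, state):
    new = [s + xi for s, xi in zip(state, x)]
    return new, new

def loop_fn(time, cell_output, cell_state, loop_state):
    next_state = [0, 0] if cell_output is None else cell_state
    elements_finished = [time >= n for n in lengths]
    next_input = inputs[time] if time < len(inputs) else [0, 0]
    return elements_finished, next_input, next_state, cell_output, loop_state

emits, final = toy_raw_rnn(cell, loop_fn)
# emits == [[1, 10], [3, 0], [6, 0]], final == [6, 10]
```

Note how the second example stops contributing after its length of 1: its later emits are zeroed and its state is frozen, exactly the copy-through behavior the pseudo-code describes.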
850 | 851 | A simple implementation of `dynamic_rnn` via `raw_rnn` looks like this: 852 | 853 | ```python 854 | inputs = tf.placeholder(shape=(max_time, batch_size, input_depth), 855 | dtype=tf.float32) 856 | sequence_length = tf.placeholder(shape=(batch_size,), dtype=tf.int32) 857 | inputs_ta = tf.TensorArray(dtype=tf.float32, size=max_time) 858 | inputs_ta = inputs_ta.unstack(inputs) 859 | 860 | cell = tf.contrib.rnn.LSTMCell(num_units) 861 | 862 | def loop_fn(time, cell_output, cell_state, loop_state): 863 | emit_output = cell_output # == None for time == 0 864 | if cell_output is None: # time == 0 865 | next_cell_state = cell.zero_state(batch_size, tf.float32) 866 | else: 867 | next_cell_state = cell_state 868 | elements_finished = (time >= sequence_length) 869 | finished = tf.reduce_all(elements_finished) 870 | next_input = tf.cond( 871 | finished, 872 | lambda: tf.zeros([batch_size, input_depth], dtype=tf.float32), 873 | lambda: inputs_ta.read(time)) 874 | next_loop_state = None 875 | return (elements_finished, next_input, next_cell_state, 876 | emit_output, next_loop_state) 877 | 878 | outputs_ta, final_state, _ = raw_rnn(cell, loop_fn) 879 | outputs = outputs_ta.stack() 880 | ``` 881 | 882 | Args: 883 | cell: An instance of RNNCell. 884 | loop_fn: A callable that takes inputs 885 | `(time, cell_output, cell_state, loop_state)` 886 | and returns the tuple 887 | `(finished, next_input, next_cell_state, emit_output, next_loop_state)`. 888 | Here `time` is an int32 scalar `Tensor`, `cell_output` is a 889 | `Tensor` or (possibly nested) tuple of tensors as determined by 890 | `cell.output_size`, and `cell_state` is a `Tensor` 891 | or (possibly nested) tuple of tensors, as determined by the `loop_fn` 892 | on its first call (and should match `cell.state_size`). 
893 | The outputs are: `finished`, a boolean `Tensor` of 894 | shape `[batch_size]`, `next_input`: the next input to feed to `cell`, 895 | `next_cell_state`: the next state to feed to `cell`, 896 | and `emit_output`: the output to store for this iteration. 897 | 898 | Note that `emit_output` should be a `Tensor` or (possibly nested) 899 | tuple of tensors with shapes and structure matching `cell.output_size` 900 | and `cell_output` above. The parameter `cell_state` and output 901 | `next_cell_state` may be either a single or (possibly nested) tuple 902 | of tensors. The parameter `loop_state` and 903 | output `next_loop_state` may be either a single or (possibly nested) tuple 904 | of `Tensor` and `TensorArray` objects. This last parameter 905 | may be ignored by `loop_fn` and the return value may be `None`. If it 906 | is not `None`, then the `loop_state` will be propagated through the RNN 907 | loop, for use purely by `loop_fn` to keep track of its own state. 908 | The `next_loop_state` parameter returned may be `None`. 909 | 910 | The first call to `loop_fn` will be `time = 0`, `cell_output = None`, 911 | `cell_state = None`, and `loop_state = None`. For this call: 912 | The `next_cell_state` value should be the value with which to initialize 913 | the cell's state. It may be a final state from a previous RNN or it 914 | may be the output of `cell.zero_state()`. It should be a 915 | (possibly nested) tuple structure of tensors. 916 | If `cell.state_size` is an integer, this must be 917 | a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 918 | If `cell.state_size` is a `TensorShape`, this must be a `Tensor` of 919 | appropriate type and shape `[batch_size] + cell.state_size`. 920 | If `cell.state_size` is a (possibly nested) tuple of ints or 921 | `TensorShape`, this will be a tuple having the corresponding shapes. 
922 | The `emit_output` value may be either `None` or a (possibly nested) 923 | tuple structure of tensors, e.g., 924 | `(tf.zeros(shape_0, dtype=dtype_0), tf.zeros(shape_1, dtype=dtype_1))`. 925 | If this first `emit_output` return value is `None`, 926 | then the `emit_ta` result of `raw_rnn` will have the same structure and 927 | dtypes as `cell.output_size`. Otherwise `emit_ta` will have the same 928 | structure, shapes (prepended with a `batch_size` dimension), and dtypes 929 | as `emit_output`. The actual values returned for `emit_output` at this 930 | initializing call are ignored. Note, this emit structure must be 931 | consistent across all time steps. 932 | 933 | parallel_iterations: (Default: 32). The number of iterations to run in 934 | parallel. Those operations which do not have any temporal dependency 935 | and can be run in parallel, will be. This parameter trades off 936 | time for space. Values >> 1 use more memory but take less time, 937 | while smaller values use less memory but computations take longer. 938 | swap_memory: Transparently swap the tensors produced in forward inference 939 | but needed for back prop from GPU to CPU. This allows training RNNs 940 | which would typically not fit on a single GPU, with very minimal (or no) 941 | performance penalty. 942 | scope: VariableScope for the created subgraph; defaults to "rnn". 943 | 944 | Returns: 945 | A tuple `(emit_ta, final_state, final_loop_state)` where: 946 | 947 | `emit_ta`: The RNN output `TensorArray`. 948 | If `loop_fn` returns a (possibly nested) set of Tensors for 949 | `emit_output` during initialization, (inputs `time = 0`, 950 | `cell_output = None`, and `loop_state = None`), then `emit_ta` will 951 | have the same structure, dtypes, and shapes as `emit_output` instead. 
952 | If `loop_fn` returns `emit_output = None` during this call, 953 | the structure of `cell.output_size` is used: 954 | If `cell.output_size` is a (possibly nested) tuple of integers 955 | or `TensorShape` objects, then `emit_ta` will be a tuple having the 956 | same structure as `cell.output_size`, containing TensorArrays whose 957 | elements' shapes correspond to the shape data in `cell.output_size`. 958 | 959 | `final_state`: The final cell state. If `cell.state_size` is an int, this 960 | will be shaped `[batch_size, cell.state_size]`. If it is a 961 | `TensorShape`, this will be shaped `[batch_size] + cell.state_size`. 962 | If it is a (possibly nested) tuple of ints or `TensorShape`, this will 963 | be a tuple having the corresponding shapes. 964 | 965 | `final_loop_state`: The final loop state as returned by `loop_fn`. 966 | 967 | Raises: 968 | TypeError: If `cell` is not an instance of RNNCell, or `loop_fn` is not 969 | a `callable`. 970 | """ 971 | 972 | assert_like_rnncell(cell.name, cell) 973 | if not callable(loop_fn): 974 | raise TypeError("loop_fn must be a callable") 975 | 976 | parallel_iterations = parallel_iterations or 32 977 | 978 | # Create a new scope in which the caching device is either 979 | # determined by the parent scope, or is set to place the cached 980 | # Variable using the same placement as for the rest of the RNN. 981 | with vs.variable_scope(scope or "rnn") as varscope: 982 | if varscope.caching_device is None: 983 | varscope.set_caching_device(lambda op: op.device) 984 | 985 | time = constant_op.constant(0, dtype=dtypes.int32) 986 | (elements_finished, next_input, initial_state, emit_structure, 987 | init_loop_state) = loop_fn( 988 | time, None, None, None) # time, cell_output, cell_state, loop_state 989 | flat_input = nest.flatten(next_input) 990 | 991 | # Need a surrogate loop state for the while_loop if none is available. 
992 | loop_state = (init_loop_state if init_loop_state is not None 993 | else constant_op.constant(0, dtype=dtypes.int32)) 994 | 995 | input_shape = [input_.get_shape() for input_ in flat_input] 996 | static_batch_size = input_shape[0][0] 997 | 998 | for input_shape_i in input_shape: 999 | # Static verification that batch sizes all match 1000 | static_batch_size.merge_with(input_shape_i[0]) 1001 | 1002 | batch_size = static_batch_size.value 1003 | if batch_size is None: 1004 | batch_size = array_ops.shape(flat_input[0])[0] 1005 | 1006 | nest.assert_same_structure(initial_state, cell.state_size) 1007 | state = initial_state 1008 | flat_state = nest.flatten(state) 1009 | flat_state = [ops.convert_to_tensor(s) for s in flat_state] 1010 | state = nest.pack_sequence_as(structure=state, 1011 | flat_sequence=flat_state) 1012 | 1013 | if emit_structure is not None: 1014 | flat_emit_structure = nest.flatten(emit_structure) 1015 | flat_emit_size = [emit.shape if emit.shape.is_fully_defined() else 1016 | array_ops.shape(emit) for emit in flat_emit_structure] 1017 | flat_emit_dtypes = [emit.dtype for emit in flat_emit_structure] 1018 | else: 1019 | emit_structure = cell.output_size 1020 | flat_emit_size = nest.flatten(emit_structure) 1021 | flat_emit_dtypes = [flat_state[0].dtype] * len(flat_emit_size) 1022 | 1023 | flat_emit_ta = [ 1024 | tensor_array_ops.TensorArray( 1025 | dtype=dtype_i, dynamic_size=True, size=0, name="rnn_output_%d" % i) 1026 | for i, dtype_i in enumerate(flat_emit_dtypes)] 1027 | emit_ta = nest.pack_sequence_as(structure=emit_structure, 1028 | flat_sequence=flat_emit_ta) 1029 | flat_zero_emit = [ 1030 | array_ops.zeros(_concat(batch_size, size_i), dtype_i) 1031 | for size_i, dtype_i in zip(flat_emit_size, flat_emit_dtypes)] 1032 | zero_emit = nest.pack_sequence_as(structure=emit_structure, 1033 | flat_sequence=flat_zero_emit) 1034 | 1035 | def condition(unused_time, elements_finished, *_): 1036 | return 
math_ops.logical_not(math_ops.reduce_all(elements_finished)) 1037 | 1038 | def body(time, elements_finished, current_input, 1039 | emit_ta, state, loop_state): 1040 | """Internal while loop body for raw_rnn. 1041 | 1042 | Args: 1043 | time: time scalar. 1044 | elements_finished: batch-size vector. 1045 | current_input: possibly nested tuple of input tensors. 1046 | emit_ta: possibly nested tuple of output TensorArrays. 1047 | state: possibly nested tuple of state tensors. 1048 | loop_state: possibly nested tuple of loop state tensors. 1049 | 1050 | Returns: 1051 | Tuple having the same size as Args but with updated values. 1052 | """ 1053 | (next_output, cell_state) = cell(current_input, state) 1054 | 1055 | nest.assert_same_structure(state, cell_state) 1056 | nest.assert_same_structure(cell.output_size, next_output) 1057 | 1058 | next_time = time + 1 1059 | (next_finished, next_input, next_state, emit_output, 1060 | next_loop_state) = loop_fn( 1061 | next_time, next_output, cell_state, loop_state) 1062 | 1063 | nest.assert_same_structure(state, next_state) 1064 | nest.assert_same_structure(current_input, next_input) 1065 | nest.assert_same_structure(emit_ta, emit_output) 1066 | 1067 | # If loop_fn returns None for next_loop_state, just reuse the 1068 | # previous one. 
1069 | loop_state = loop_state if next_loop_state is None else next_loop_state 1070 | 1071 | def _copy_some_through(current, candidate): 1072 | """Copy some tensors through via array_ops.where.""" 1073 | def copy_fn(cur_i, cand_i): 1074 | with ops.colocate_with(cand_i): 1075 | return array_ops.where(elements_finished, cur_i, cand_i) 1076 | return nest.map_structure(copy_fn, current, candidate) 1077 | 1078 | emit_output = _copy_some_through(zero_emit, emit_output) 1079 | next_state = _copy_some_through(state, next_state) 1080 | 1081 | emit_ta = nest.map_structure( 1082 | lambda ta, emit: ta.write(time, emit), emit_ta, emit_output) 1083 | 1084 | elements_finished = math_ops.logical_or(elements_finished, next_finished) 1085 | 1086 | return (next_time, elements_finished, next_input, 1087 | emit_ta, next_state, loop_state) 1088 | 1089 | returned = control_flow_ops.while_loop( 1090 | condition, body, loop_vars=[ 1091 | time, elements_finished, next_input, 1092 | emit_ta, state, loop_state], 1093 | parallel_iterations=parallel_iterations, 1094 | swap_memory=swap_memory) 1095 | 1096 | (emit_ta, final_state, final_loop_state) = returned[-3:] 1097 | 1098 | if init_loop_state is None: 1099 | final_loop_state = None 1100 | 1101 | return (emit_ta, final_state, final_loop_state) 1102 | 1103 | 1104 | def static_rnn(cell, 1105 | inputs, 1106 | initial_state=None, 1107 | dtype=None, 1108 | sequence_length=None, 1109 | scope=None): 1110 | """Creates a recurrent neural network specified by RNNCell `cell`. 1111 | 1112 | The simplest form of RNN network generated is: 1113 | 1114 | ```python 1115 | state = cell.zero_state(...) 1116 | outputs = [] 1117 | for input_ in inputs: 1118 | output, state = cell(input_, state) 1119 | outputs.append(output) 1120 | return (outputs, state) 1121 | ``` 1122 | However, a few other options are available: 1123 | 1124 | An initial state can be provided. 1125 | If the sequence_length vector is provided, dynamic calculation is performed. 
1126 | This method of calculation does not compute the RNN steps past the maximum 1127 | sequence length of the minibatch (thus saving computational time), 1128 | and properly propagates the state at an example's sequence length 1129 | to the final state output. 1130 | 1131 | The dynamic calculation performed is, at time `t` for batch row `b`, 1132 | 1133 | ```python 1134 | (output, state)(b, t) = 1135 | (t >= sequence_length(b)) 1136 | ? (zeros(cell.output_size), states(b, sequence_length(b) - 1)) 1137 | : cell(input(b, t), state(b, t - 1)) 1138 | ``` 1139 | 1140 | Args: 1141 | cell: An instance of RNNCell. 1142 | inputs: A length T list of inputs, each a `Tensor` of shape 1143 | `[batch_size, input_size]`, or a nested tuple of such elements. 1144 | initial_state: (optional) An initial state for the RNN. 1145 | If `cell.state_size` is an integer, this must be 1146 | a `Tensor` of appropriate type and shape `[batch_size, cell.state_size]`. 1147 | If `cell.state_size` is a tuple, this should be a tuple of 1148 | tensors having shapes `[batch_size, s] for s in cell.state_size`. 1149 | dtype: (optional) The data type for the initial state and expected output. 1150 | Required if initial_state is not provided or RNN state has a heterogeneous 1151 | dtype. 1152 | sequence_length: Specifies the length of each sequence in inputs. 1153 | An int32 or int64 vector (tensor) size `[batch_size]`, values in `[0, T)`. 1154 | scope: VariableScope for the created subgraph; defaults to "rnn". 1155 | 1156 | Returns: 1157 | A pair (outputs, state) where: 1158 | 1159 | - outputs is a length T list of outputs (one for each input), or a nested 1160 | tuple of such elements. 1161 | - state is the final state 1162 | 1163 | Raises: 1164 | TypeError: If `cell` is not an instance of RNNCell. 1165 | ValueError: If `inputs` is `None` or an empty list, or if the input depth 1166 | (column size) cannot be inferred from inputs via shape inference. 
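The ternary rule quoted above can be executed directly in pure Python. A reference sketch with a toy additive cell (illustrative, not this module's API) that zeros outputs and freezes state past `sequence_length(b)`:

```python
def toy_static_rnn(inputs, sequence_length):
    """inputs: length-T list of per-batch values [T][batch].
    Toy cell: state' = state + x, output = state'.
    Implements the dynamic-calculation rule: for t >= sequence_length(b),
    emit zeros(cell.output_size) and keep state(b, sequence_length(b) - 1)."""
    batch = len(inputs[0])
    state = [0] * batch
    outputs = []
    for t, input_ in enumerate(inputs):
        out_t = []
        for b in range(batch):
            if t >= sequence_length[b]:
                out_t.append(0)                    # zeros(cell.output_size)
            else:
                state[b] = state[b] + input_[b]    # cell(input(b, t), state)
                out_t.append(state[b])
        outputs.append(out_t)
    return outputs, state

outs, final = toy_static_rnn([[1, 10], [2, 20], [3, 30]], sequence_length=[2, 3])
# Row b=0 stops after two steps: its later outputs are zero, its state stays at 3.
# outs == [[1, 10], [3, 30], [0, 60]], final == [3, 60]
```

This is the saving mentioned above: no cell computation is spent on steps past an example's length, yet the final state is still the state at that example's last valid step.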
1167 | """ 1168 | 1169 | assert_like_rnncell(cell.name, cell) 1170 | if not nest.is_sequence(inputs): 1171 | raise TypeError("inputs must be a sequence") 1172 | if not inputs: 1173 | raise ValueError("inputs must not be empty") 1174 | 1175 | outputs = [] 1176 | # Create a new scope in which the caching device is either 1177 | # determined by the parent scope, or is set to place the cached 1178 | # Variable using the same placement as for the rest of the RNN. 1179 | with vs.variable_scope(scope or "rnn") as varscope: 1180 | if varscope.caching_device is None: 1181 | varscope.set_caching_device(lambda op: op.device) 1182 | 1183 | # Obtain the first sequence of the input 1184 | first_input = inputs 1185 | while nest.is_sequence(first_input): 1186 | first_input = first_input[0] 1187 | 1188 | # Temporarily avoid EmbeddingWrapper and seq2seq badness 1189 | # TODO(lukaszkaiser): remove EmbeddingWrapper 1190 | if first_input.get_shape().ndims != 1: 1191 | 1192 | input_shape = first_input.get_shape().with_rank_at_least(2) 1193 | fixed_batch_size = input_shape[0] 1194 | 1195 | flat_inputs = nest.flatten(inputs) 1196 | for flat_input in flat_inputs: 1197 | input_shape = flat_input.get_shape().with_rank_at_least(2) 1198 | batch_size, input_size = input_shape[0], input_shape[1:] 1199 | fixed_batch_size.merge_with(batch_size) 1200 | for i, size in enumerate(input_size): 1201 | if size.value is None: 1202 | raise ValueError( 1203 | "Input size (dimension %d of inputs) must be accessible via " 1204 | "shape inference, but saw value None." 
% i) 1205 | else: 1206 | fixed_batch_size = first_input.get_shape().with_rank_at_least(1)[0] 1207 | 1208 | if fixed_batch_size.value: 1209 | batch_size = fixed_batch_size.value 1210 | else: 1211 | batch_size = array_ops.shape(first_input)[0] 1212 | if initial_state is not None: 1213 | state = initial_state 1214 | else: 1215 | if not dtype: 1216 | raise ValueError("If no initial_state is provided, " 1217 | "dtype must be specified") 1218 | state = cell.zero_state(batch_size, dtype) 1219 | 1220 | if sequence_length is not None: # Prepare variables 1221 | sequence_length = ops.convert_to_tensor( 1222 | sequence_length, name="sequence_length") 1223 | if sequence_length.get_shape().ndims not in (None, 1): 1224 | raise ValueError( 1225 | "sequence_length must be a vector of length batch_size") 1226 | 1227 | def _create_zero_output(output_size): 1228 | # convert int to TensorShape if necessary 1229 | size = _concat(batch_size, output_size) 1230 | output = array_ops.zeros( 1231 | array_ops.stack(size), _infer_state_dtype(dtype, state)) 1232 | shape = _concat(fixed_batch_size.value, output_size, static=True) 1233 | output.set_shape(tensor_shape.TensorShape(shape)) 1234 | return output 1235 | 1236 | output_size = cell.output_size 1237 | flat_output_size = nest.flatten(output_size) 1238 | flat_zero_output = tuple( 1239 | _create_zero_output(size) for size in flat_output_size) 1240 | zero_output = nest.pack_sequence_as( 1241 | structure=output_size, flat_sequence=flat_zero_output) 1242 | 1243 | sequence_length = math_ops.to_int32(sequence_length) 1244 | min_sequence_length = math_ops.reduce_min(sequence_length) 1245 | max_sequence_length = math_ops.reduce_max(sequence_length) 1246 | 1247 | for time, input_ in enumerate(inputs): 1248 | if time > 0: 1249 | varscope.reuse_variables() 1250 | # pylint: disable=cell-var-from-loop 1251 | call_cell = lambda: cell(input_, state) 1252 | # pylint: enable=cell-var-from-loop 1253 | if sequence_length is not None: 1254 | (output, state) = 
_rnn_step( 1255 | time=time, 1256 | sequence_length=sequence_length, 1257 | min_sequence_length=min_sequence_length, 1258 | max_sequence_length=max_sequence_length, 1259 | zero_output=zero_output, 1260 | state=state, 1261 | call_cell=call_cell, 1262 | state_size=cell.state_size) 1263 | else: 1264 | (output, state) = call_cell() 1265 | 1266 | outputs.append(output) 1267 | 1268 | return (outputs, state) 1269 | 1270 | 1271 | def static_state_saving_rnn(cell, 1272 | inputs, 1273 | state_saver, 1274 | state_name, 1275 | sequence_length=None, 1276 | scope=None): 1277 | """RNN that accepts a state saver for time-truncated RNN calculation. 1278 | 1279 | Args: 1280 | cell: An instance of `RNNCell`. 1281 | inputs: A length T list of inputs, each a `Tensor` of shape 1282 | `[batch_size, input_size]`. 1283 | state_saver: A state saver object with methods `state` and `save_state`. 1284 | state_name: Python string or tuple of strings. The name to use with the 1285 | state_saver. If the cell returns tuples of states (i.e., 1286 | `cell.state_size` is a tuple) then `state_name` should be a tuple of 1287 | strings having the same length as `cell.state_size`. Otherwise it should 1288 | be a single string. 1289 | sequence_length: (optional) An int32/int64 vector size [batch_size]. 1290 | See the documentation for rnn() for more details about sequence_length. 1291 | scope: VariableScope for the created subgraph; defaults to "rnn". 1292 | 1293 | Returns: 1294 | A pair (outputs, state) where: 1295 | outputs is a length T list of outputs (one for each input) 1296 | states is the final state 1297 | 1298 | Raises: 1299 | TypeError: If `cell` is not an instance of RNNCell. 1300 | ValueError: If `inputs` is `None` or an empty list, or if the arity and 1301 | type of `state_name` does not match that of `cell.state_size`. 
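The `state_saver` contract above is just two methods, `state(name)` and `save_state(name, value)`. A minimal dict-backed sketch of the protocol (hypothetical helpers, not part of this module) showing how state is carried across time-truncated segments:

```python
class DictStateSaver:
    """Minimal state saver: state(name) reads the carried-over state,
    save_state(name, value) records the final state for the next segment."""
    def __init__(self, zero_state):
        self._states = {}
        self._zero = zero_state

    def state(self, name):
        return self._states.get(name, self._zero)

    def save_state(self, name, value):
        self._states[name] = value
        return value

def run_segment(saver, name, segment):
    """Toy running-sum 'RNN' over one truncated segment of a long sequence."""
    state = saver.state(name)          # resume from the saved state
    for x in segment:
        state = state + x
    saver.save_state(name, state)      # persist for the next segment
    return state

saver = DictStateSaver(zero_state=0)
run_segment(saver, "h", [1, 2, 3])     # first segment starts from zero state
final = run_segment(saver, "h", [4, 5])  # resumes: 6 + 4 + 5 == 15
```

For tuple states (`cell.state_size` a tuple), `state_name` becomes a tuple of strings and each sub-state is saved under its own name, as the docstring above describes.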
1302 | """ 1303 | state_size = cell.state_size 1304 | state_is_tuple = nest.is_sequence(state_size) 1305 | state_name_tuple = nest.is_sequence(state_name) 1306 | 1307 | if state_is_tuple != state_name_tuple: 1308 | raise ValueError("state_name should be the same type as cell.state_size. " 1309 | "state_name: %s, cell.state_size: %s" % (str(state_name), 1310 | str(state_size))) 1311 | 1312 | if state_is_tuple: 1313 | state_name_flat = nest.flatten(state_name) 1314 | state_size_flat = nest.flatten(state_size) 1315 | 1316 | if len(state_name_flat) != len(state_size_flat): 1317 | raise ValueError("#elems(state_name) != #elems(state_size): %d vs. %d" % 1318 | (len(state_name_flat), len(state_size_flat))) 1319 | 1320 | initial_state = nest.pack_sequence_as( 1321 | structure=state_size, 1322 | flat_sequence=[state_saver.state(s) for s in state_name_flat]) 1323 | else: 1324 | initial_state = state_saver.state(state_name) 1325 | 1326 | (outputs, state) = static_rnn( 1327 | cell, 1328 | inputs, 1329 | initial_state=initial_state, 1330 | sequence_length=sequence_length, 1331 | scope=scope) 1332 | 1333 | if state_is_tuple: 1334 | flat_state = nest.flatten(state) 1335 | state_name = nest.flatten(state_name) 1336 | save_state = [ 1337 | state_saver.save_state(name, substate) 1338 | for name, substate in zip(state_name, flat_state) 1339 | ] 1340 | else: 1341 | save_state = [state_saver.save_state(state_name, state)] 1342 | 1343 | with ops.control_dependencies(save_state): 1344 | last_output = outputs[-1] 1345 | flat_last_output = nest.flatten(last_output) 1346 | flat_last_output = [ 1347 | array_ops.identity(output) for output in flat_last_output 1348 | ] 1349 | outputs[-1] = nest.pack_sequence_as( 1350 | structure=last_output, flat_sequence=flat_last_output) 1351 | 1352 | return (outputs, state) 1353 | 1354 | 1355 | def static_bidirectional_rnn(cell_fw, 1356 | cell_bw, 1357 | inputs, 1358 | initial_state_fw=None, 1359 | initial_state_bw=None, 1360 | dtype=None, 1361 | 
sequence_length=None, 1362 | scope=None): 1363 | """Creates a bidirectional recurrent neural network. 1364 | 1365 | Similar to the unidirectional case above (rnn) but takes input and builds 1366 | independent forward and backward RNNs with the final forward and backward 1367 | outputs depth-concatenated, such that the output will have the format 1368 | [time][batch][cell_fw.output_size + cell_bw.output_size]. The input_size of 1369 | forward and backward cell must match. The initial state for both directions 1370 | is zero by default (but can be set optionally) and no intermediate states are 1371 | ever returned -- the network is fully unrolled for the given (passed in) 1372 | length(s) of the sequence(s) or completely unrolled if length(s) is not given. 1373 | 1374 | Args: 1375 | cell_fw: An instance of RNNCell, to be used for forward direction. 1376 | cell_bw: An instance of RNNCell, to be used for backward direction. 1377 | inputs: A length T list of inputs, each a tensor of shape 1378 | [batch_size, input_size], or a nested tuple of such elements. 1379 | initial_state_fw: (optional) An initial state for the forward RNN. 1380 | This must be a tensor of appropriate type and shape 1381 | `[batch_size, cell_fw.state_size]`. 1382 | If `cell_fw.state_size` is a tuple, this should be a tuple of 1383 | tensors having shapes `[batch_size, s] for s in cell_fw.state_size`. 1384 | initial_state_bw: (optional) Same as for `initial_state_fw`, but using 1385 | the corresponding properties of `cell_bw`. 1386 | dtype: (optional) The data type for the initial state. Required if 1387 | either of the initial states are not provided. 1388 | sequence_length: (optional) An int32/int64 vector, size `[batch_size]`, 1389 | containing the actual lengths for each of the sequences. 
1390 | scope: VariableScope for the created subgraph; defaults to 1391 | "bidirectional_rnn" 1392 | 1393 | Returns: 1394 | A tuple (outputs, output_state_fw, output_state_bw) where: 1395 | outputs is a length `T` list of outputs (one for each input), which 1396 | are depth-concatenated forward and backward outputs. 1397 | output_state_fw is the final state of the forward rnn. 1398 | output_state_bw is the final state of the backward rnn. 1399 | 1400 | Raises: 1401 | TypeError: If `cell_fw` or `cell_bw` is not an instance of `RNNCell`. 1402 | ValueError: If inputs is None or an empty list. 1403 | """ 1404 | 1405 | assert_like_rnncell(cell_fw.name, cell_fw) 1406 | assert_like_rnncell(cell_bw.name, cell_bw) 1407 | if not nest.is_sequence(inputs): 1408 | raise TypeError("inputs must be a sequence") 1409 | if not inputs: 1410 | raise ValueError("inputs must not be empty") 1411 | 1412 | with vs.variable_scope(scope or "bidirectional_rnn"): 1413 | # Forward direction 1414 | with vs.variable_scope("fw") as fw_scope: 1415 | output_fw, output_state_fw = static_rnn( 1416 | cell_fw, 1417 | inputs, 1418 | initial_state_fw, 1419 | dtype, 1420 | sequence_length, 1421 | scope=fw_scope) 1422 | 1423 | # Backward direction 1424 | with vs.variable_scope("bw") as bw_scope: 1425 | reversed_inputs = _reverse_seq(inputs, sequence_length) 1426 | tmp, output_state_bw = static_rnn( 1427 | cell_bw, 1428 | reversed_inputs, 1429 | initial_state_bw, 1430 | dtype, 1431 | sequence_length, 1432 | scope=bw_scope) 1433 | 1434 | output_bw = _reverse_seq(tmp, sequence_length) 1435 | # Concat each of the forward/backward outputs 1436 | flat_output_fw = nest.flatten(output_fw) 1437 | flat_output_bw = nest.flatten(output_bw) 1438 | 1439 | flat_outputs = tuple( 1440 | array_ops.concat([fw, bw], 1) 1441 | for fw, bw in zip(flat_output_fw, flat_output_bw)) 1442 | 1443 | outputs = nest.pack_sequence_as( 1444 | structure=output_fw, flat_sequence=flat_outputs) 1445 | 1446 | return (outputs, output_state_fw, 
output_state_bw) 1447 | -------------------------------------------------------------------------------- /DIEN/tianchi.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | ## Dataset 3 | We randomly selected about 8 million users with over 15M shopping samples. Each sample contains up to 100 history interactions for the user. The dataset is organized in a form very similar to MovieLens-20M, i.e., each line represents a specific user-item interaction together with the user's history interactions, including key information such as user ID, item ID, category ID, shop ID, node ID, product ID, and brand ID. The dataset is split into training and testing sets: training has 15M samples and testing has 0.97M samples. 4 | 5 | Dimensions of the dataset are 6 | 7 | | Dimension | Dimension Size | Feature Explanation| 8 | |----------------------|----------------|---------------------| 9 | | Number of users | 7956430 |An integer that represents a user| 10 | | Number of items | 34196611 |An integer that represents an item| 11 | | Number of categories | 5596 |An integer that represents the category which the corresponding item belongs to| 12 | | Number of shops | 4377722 |An integer that represents a shop | 13 | | Number of nodes | 2975349 |An integer that represents a cluster which some items belong to| 14 | | Number of products | 65624 |An integer that represents a product| 15 | | Number of brands | 584181 |An integer that represents a brand| 16 | | Number of interactions| 15M |The total number of interaction samples| 17 | 18 | **Citations** 19 | Guorui Zhou, Na Mou, Ying Fan, Qi Pi, Weijie Bian, Chang Zhou, Xiaoqiang Zhu, and Kun Gai. Deep Interest Evolution Network for Click-Through Rate Prediction. In Proceedings of the Thirty-Third _AAAI_ Conference on Artificial Intelligence, 5941–5948.
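The per-line layout described above can be illustrated with a small parser sketch. The tab delimiter and the exact field order below are assumptions for illustration only — the authoritative format is defined by `data_iterator.py` / `data_loader.py` in this repo. Missing ids are encoded as `-1` in the processed data and are remapped to the vocabulary size at training time by `train_taobao_processed_allfea.py`.

```python
# Hypothetical parser for one interaction sample. The delimiter and the
# field order are assumptions for illustration -- check data_iterator.py /
# data_loader.py for the real on-disk format.
ID_FIELDS = ["user_id", "item_id", "cate_id", "shop_id",
             "node_id", "product_id", "brand_id", "label"]

def parse_sample(line, sep="\t"):
    """Split one sample line into a dict of integer id features.

    A value of -1 marks a missing id; the training script later remaps
    -1 to the corresponding vocabulary size.
    """
    values = line.rstrip("\n").split(sep)
    return {name: int(v) for name, v in zip(ID_FIELDS, values)}

sample = parse_sample("42\t1001\t55\t7\t-1\t13\t9\t1")
print(sample["node_id"])  # -1 marks a missing node id
print(sample["label"])    # 1
```

The history columns (the up-to-100 past interactions per sample) would extend this field list in the same way.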
-------------------------------------------------------------------------------- /DIEN/train_taobao_processed_allfea.py: -------------------------------------------------------------------------------- 1 | #coding:utf-8 2 | import numpy as np 3 | from data_iterator import DataIterator 4 | import tensorflow as tf 5 | from model_taobao_allfea import * 6 | import time 7 | import random 8 | import sys 9 | import json 10 | from utils import * 11 | import multiprocessing 12 | from multiprocessing import Process, Value, Array 13 | from wrap_time import time_it 14 | from data_loader import DataLoader 15 | import threading 16 | from collections import deque 17 | 18 | import os 19 | if "BRIDGE_ENABLE_TAO" in os.environ and os.environ["BRIDGE_ENABLE_TAO"] == "true": 20 | tf.load_op_library("/home/admin/.tao_compiler/libtao_ops.so") 21 | 22 | def file_num(x): 23 | if x < 10: 24 | return '0' + str(x) 25 | else: 26 | return str(x) 27 | 28 | 29 | EMBEDDING_DIM = 4 30 | HIDDEN_SIZE = EMBEDDING_DIM * 6 31 | MEMORY_SIZE = 4 32 | 33 | 34 | def eval(sess, test_file, model, model_path, batch_size, maxlen, best_auc = [1.0]): 35 | print("Testing starts------------") 36 | data_load_test= DataLoader(test_file, 'test_sample', batch_size, 4 ) 37 | producer1 = threading.Thread(target=data_load_test.data_read, args=(0,2)) 38 | producer2 = threading.Thread(target=data_load_test.data_read, args=(1,2)) 39 | producer1.start() 40 | producer2.start() 41 | loss_sum = 0. 42 | accuracy_sum = 0. 43 | aux_loss_sum = 0. 
44 | iterations = 0 45 | stored_arr = [] 46 | 47 | for data in data_load_test.next(): 48 | iterations += 1 49 | user_id, item_id, cate_id, shop_id, node_id, product_id, brand_id, \ 50 | label, hist_item, hist_cate, hist_shop, hist_node, hist_product, \ 51 | hist_brand, hist_mask, neg_hist_item, neg_hist_cate, \ 52 | neg_hist_shop, neg_hist_node, neg_hist_product, neg_hist_brand = data 53 | target = label 54 | prob, loss, acc, aux_loss = model.calculate(sess, [user_id, item_id, \ 55 | cate_id, shop_id, node_id, product_id, brand_id, \ 56 | hist_item, hist_cate, hist_shop, hist_node, hist_product, hist_brand, \ 57 | neg_hist_item, neg_hist_cate, neg_hist_shop, neg_hist_node, neg_hist_product, neg_hist_brand, hist_mask, label]) 58 | loss_sum += loss 59 | aux_loss_sum += aux_loss  # accumulate across batches before averaging 60 | accuracy_sum += acc 61 | prob_1 = prob[:, 0].tolist() 62 | target_1 = target[:, 0].tolist() 63 | # user_l = user_id.tolist() 64 | for p, t in zip(prob_1, target_1): 65 | stored_arr.append([p, t]) 66 | 67 | #test_auc = calc_gauc(stored_arr, user_l) 68 | test_auc = calc_auc(stored_arr) 69 | accuracy_sum = accuracy_sum / iterations 70 | loss_sum = loss_sum / iterations 71 | aux_loss_sum = aux_loss_sum / iterations 72 | if best_auc[0] < test_auc: 73 | best_auc[0] = test_auc 74 | model.save(sess, model_path) 75 | producer1.join() 76 | producer2.join() 77 | return test_auc, loss_sum, accuracy_sum, aux_loss_sum, best_auc[0] 78 | 79 | def train( 80 | train_file, 81 | test_file, 82 | batch_size = 256, 83 | maxlen = 100, 84 | test_iter = 500, 85 | save_iter = 5000, 86 | model_type = 'DNN', 87 | Memory_Size = 4, 88 | ): 89 | TEM_MEMORY_SIZE = Memory_Size 90 | model_path = "dnn_save_path/taobao_ckpt_noshuff" + model_type 91 | best_model_path = "dnn_best_model/taobao_ckpt_noshuff" + model_type 92 | gpu_options = tf.GPUOptions(allow_growth=True) 93 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 94 | 95 | # Obtained in the data preprocess stage.
To save time, the json files are not needed here. 96 | uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n = [7956430, 34196611, 5596, 4377722, 2975349, 65624, 584181] 97 | BATCH_SIZE = batch_size 98 | SEQ_LEN = maxlen 99 | 100 | if model_type == 'DNN': 101 | model = Model_DNN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 102 | elif model_type == 'PNN': 103 | model = Model_PNN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 104 | elif model_type == 'GRU4REC': 105 | model = Model_GRU4REC(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 106 | elif model_type == 'DIN': 107 | model = Model_DIN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 108 | elif model_type == 'ARNN': 109 | model = Model_ARNN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 110 | elif model_type == 'DIEN': 111 | model = Model_DIEN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 112 | elif model_type == 'DIEN_with_neg': 113 | model = Model_DIEN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN, use_negsample=True) 114 | else: 115 | print("Invalid model_type: %s" % model_type) 116 | return 117 | 118 | # Initialize model parameters 119 | sess.run(tf.global_variables_initializer()) 120 | sess.run(tf.local_variables_initializer()) 121 | 122 | sys.stdout.flush() 123 | 124 | start_time = time.time() 125 | last_time = start_time 126 | iter = 0 127 | lr = 0.001 128 | best_auc = [0.0] 129 | loss_sum = 0.0 130 | accuracy_sum = 0. 131 | left_loss_sum = 0. 132 | aux_loss_sum = 0. 133 | mem_loss_sum = 0.
134 | # train for 1 epoch only 135 | epoch = 1 136 | for itr in range(epoch): 137 | print("epoch" + str(itr)) 138 | # load data 139 | data_load = DataLoader(train_file, 'train_sample', batch_size, 15) 140 | producer1 = threading.Thread(target=data_load.data_read, args=(0,2)) 141 | producer2 = threading.Thread(target=data_load.data_read, args=(1,2)) 142 | producer1.start() 143 | producer2.start() 144 | #for iteration in range(number_samples): 145 | for data in data_load.next(): 146 | user_id, item_id, cate_id, shop_id, node_id, product_id, brand_id, \ 147 | label, hist_item, hist_cate, hist_shop, hist_node, hist_product, \ 148 | hist_brand, hist_mask, neg_hist_item, neg_hist_cate, neg_hist_shop, \ 149 | neg_hist_node, neg_hist_product, neg_hist_brand = data 150 | # remap the -1 (missing) index in batch data to the vocabulary size 151 | item_id[item_id == -1] = item_n 152 | shop_id[shop_id == -1] = shop_n 153 | node_id[node_id == -1] = node_n 154 | product_id[product_id == -1] = product_n 155 | brand_id[brand_id == -1] = brand_n 156 | 157 | loss, acc, aux_loss, mem_loss, left_loss = model.train(sess, [user_id, item_id, \ 158 | cate_id, shop_id, node_id, product_id, brand_id, \ 159 | hist_item, hist_cate, hist_shop, hist_node, hist_product, hist_brand, \ 160 | neg_hist_item, neg_hist_cate, neg_hist_shop, neg_hist_node, neg_hist_product, \ 161 | neg_hist_brand, hist_mask, label, lr]) 162 | # cate_id, hist_item, hist_cate, neg_hist_item, neg_hist_cate, hist_mask, label, lr]) 163 | loss_sum += loss 164 | accuracy_sum += acc 165 | left_loss_sum += left_loss 166 | aux_loss_sum += aux_loss 167 | mem_loss_sum += mem_loss 168 | iter += 1 169 | sys.stdout.flush() 170 | 171 | if (iter % test_iter) == 0: 172 | test_time = time.time() 173 | print('[Iteration]=%d, train_loss=%.4f, train_accuracy=%.4f, train_aux_loss=%.4f, train_left_loss=%.4f, throughput=%.4f, total time=%.4f' % (iter, loss_sum / test_iter, accuracy_sum / test_iter, \ 174 | aux_loss_sum / test_iter, left_loss_sum / test_iter,
batch_size*test_iter/(test_time-last_time),\ 175 | test_time-start_time)) 176 | loss_sum = 0.0 177 | accuracy_sum = 0.0 178 | left_loss_sum = 0.0 179 | aux_loss_sum = 0. 180 | mem_loss_sum = 0. 181 | if (iter % save_iter) == 0: 182 | print('save model iter: %d' %(iter)) 183 | model.save(sess, model_path+"--"+str(iter)) 184 | print('Testing finishes-------test_auc=%.4f, test_loss=%.4f, test_accuracy=%.4f, test_aux_loss=%.4f, best_auc=%.4f ' % eval(sess, test_file, model, best_model_path, batch_size, maxlen, best_auc)) 185 | last_time = test_time 186 | 187 | producer1.join() 188 | producer2.join() 189 | 190 | def test( 191 | train_file, 192 | test_file, 193 | batch_size = 256, 194 | maxlen = 100, 195 | test_iter = 100, 196 | save_iter = 400, 197 | model_type = 'DNN', 198 | Memory_Size = 4, 199 | ): 200 | 201 | TEM_MEMORY_SIZE = Memory_Size 202 | Ntm_Flag = "base" 203 | if Ntm_Flag == "base": 204 | model_path = "dnn_save_path/taobao_ckpt_noshuff" + model_type 205 | else: 206 | model_path = "dnn_save_path/taobao_ckpt_noshuff" + model_type+str(Memory_Size) 207 | gpu_options = tf.GPUOptions(allow_growth=True) 208 | with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 209 | uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n = [7956430, 34196611, 5596, 4377722, 2975349, 65624, 584181] 210 | BATCH_SIZE = batch_size 211 | SEQ_LEN = maxlen 212 | 213 | if model_type == 'DNN': 214 | model = Model_DNN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 215 | elif model_type == 'PNN': 216 | model = Model_PNN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 217 | elif model_type == 'GRU4REC': 218 | model = Model_GRU4REC(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 219 | elif model_type == 'DIN': 220 | model = Model_DIN(uid_n, item_n, cate_n, 
shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 221 | elif model_type == 'ARNN': 222 | model = Model_ARNN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 223 | elif model_type == 'DIEN': 224 | model = Model_DIEN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN) 225 | elif model_type == 'DIEN_with_neg': 226 | model = Model_DIEN(uid_n, item_n, cate_n, shop_n, node_n, product_n, brand_n, EMBEDDING_DIM, HIDDEN_SIZE, MEMORY_SIZE, BATCH_SIZE, SEQ_LEN, use_negsample=True) 227 | else: 228 | print("Invalid model_type: %s" % model_type) 229 | return 230 | 231 | model.restore(sess, model_path + '--50000') 232 | print('test_auc: %.4f ----test_loss: %.4f ---- test_accuracy: %.4f ---- test_aux_loss: %.4f ---- best_auc=%.4f ' % eval(sess, test_file, model, model_path, batch_size, maxlen)) 233 | 234 | 235 | if __name__ == '__main__': 236 | SEED = int(sys.argv[3]) 237 | if len(sys.argv) > 4: 238 | Memory_Size = int(sys.argv[4]) 239 | else: 240 | Memory_Size = 4 241 | tf.set_random_seed(SEED) 242 | np.random.seed(SEED) 243 | random.seed(SEED) 244 | train_file = './process_data_maxlen100_0225/' 245 | test_file = train_file 246 | if sys.argv[1] == 'train': 247 | train(train_file=train_file, test_file=test_file, model_type=sys.argv[2], Memory_Size=Memory_Size) 248 | elif sys.argv[1] == 'test': 249 | test(train_file=train_file, test_file=test_file, model_type=sys.argv[2], Memory_Size=Memory_Size) 250 | else: 251 | print('do nothing...') 252 | -------------------------------------------------------------------------------- /DIEN/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops.rnn_cell import * 3 | #from tensorflow.python.ops.rnn_cell_impl import _Linear 4 | from tensorflow import keras 5 | from
tensorflow.python.ops import math_ops 6 | from tensorflow.python.ops import init_ops 7 | from tensorflow.python.ops import array_ops 8 | from tensorflow.python.ops import variable_scope as vs 9 | from tensorflow.python.ops import nn_ops 10 | from tensorflow.python.util import nest 11 | _BIAS_VARIABLE_NAME = "bias" 12 | _WEIGHTS_VARIABLE_NAME = "kernel" 13 | 14 | class _Linear(object): 15 | """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. 16 | Args: 17 | args: a 2D Tensor or a list of 2D, batch x n, Tensors. 18 | output_size: int, second dimension of weight variable. 19 | dtype: data type for variables. 20 | build_bias: boolean, whether to build a bias variable. 21 | bias_initializer: starting value to initialize the bias 22 | (default is all zeros). 23 | kernel_initializer: starting value to initialize the weight. 24 | Raises: 25 | ValueError: if inputs_shape is wrong. 26 | """ 27 | 28 | 29 | def __init__(self, 30 | args, 31 | output_size, 32 | build_bias, 33 | bias_initializer=None, 34 | kernel_initializer=None): 35 | self._build_bias = build_bias 36 | 37 | if args is None or (nest.is_sequence(args) and not args): 38 | raise ValueError("`args` must be specified") 39 | if not nest.is_sequence(args): 40 | args = [args] 41 | self._is_sequence = False 42 | else: 43 | self._is_sequence = True 44 | 45 | # Calculate the total size of arguments on dimension 1.
46 | total_arg_size = 0 47 | shapes = [a.get_shape() for a in args] 48 | for shape in shapes: 49 | if shape.ndims != 2: 50 | raise ValueError("linear is expecting 2D arguments: %s" % shapes) 51 | if shape[1].value is None: 52 | raise ValueError("linear expects shape[1] to be provided for shape %s, " 53 | "but saw %s" % (shape, shape[1])) 54 | else: 55 | total_arg_size += shape[1].value 56 | 57 | dtype = [a.dtype for a in args][0] 58 | 59 | scope = vs.get_variable_scope() 60 | with vs.variable_scope(scope) as outer_scope: 61 | self._weights = vs.get_variable( 62 | _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size], 63 | dtype=dtype, 64 | initializer=kernel_initializer) 65 | if build_bias: 66 | with vs.variable_scope(outer_scope) as inner_scope: 67 | inner_scope.set_partitioner(None) 68 | if bias_initializer is None: 69 | bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) 70 | self._biases = vs.get_variable( 71 | _BIAS_VARIABLE_NAME, [output_size], 72 | dtype=dtype, 73 | initializer=bias_initializer) 74 | 75 | def __call__(self, args): 76 | if not self._is_sequence: 77 | args = [args] 78 | 79 | if len(args) == 1: 80 | res = math_ops.matmul(args[0], self._weights) 81 | else: 82 | res = math_ops.matmul(array_ops.concat(args, 1), self._weights) 83 | if self._build_bias: 84 | res = nn_ops.bias_add(res, self._biases) 85 | return res 86 | 87 | def din_attention(query, facts, attention_size, mask=None, stag='null', mode='SUM', softmax_stag=1, time_major=False, return_alphas=False): 88 | if isinstance(facts, tuple): 89 | # In case of Bi-RNN, concatenate the forward and the backward RNN outputs. 
90 | facts = tf.concat(facts, 2) 91 | print("query_size mismatch") 92 | query = tf.concat(values=[ 93 | query, 94 | query, 95 | ], axis=1) 96 | 97 | if time_major: 98 | # (T,B,D) => (B,T,D) 99 | facts = tf.transpose(facts, [1, 0, 2]) 100 | facts_size = facts.get_shape().as_list()[-1] # D value - hidden size of the RNN layer 101 | querry_size = query.get_shape().as_list()[-1] 102 | queries = tf.tile(query, [1, tf.shape(facts)[1]]) 103 | queries = tf.reshape(queries, tf.shape(facts)) 104 | din_all = tf.concat([queries, facts, queries-facts, queries*facts], axis=-1) 105 | d_layer_1_all = tf.layers.dense(din_all, 80, activation=tf.nn.sigmoid, name='f1_att' + stag) 106 | d_layer_2_all = tf.layers.dense(d_layer_1_all, 40, activation=tf.nn.sigmoid, name='f2_att' + stag) 107 | d_layer_3_all = tf.layers.dense(d_layer_2_all, 1, activation=None, name='f3_att' + stag) 108 | d_layer_3_all = tf.reshape(d_layer_3_all, [-1, 1, tf.shape(facts)[1]]) 109 | scores = d_layer_3_all 110 | 111 | if mask is not None: 112 | mask = tf.equal(mask, tf.ones_like(mask)) 113 | key_masks = tf.expand_dims(mask, 1) # [B, 1, T] 114 | paddings = tf.ones_like(scores) * (-2 ** 32 + 1) 115 | scores = tf.where(key_masks, scores, paddings) # [B, 1, T] 116 | 117 | # Activation 118 | if softmax_stag: 119 | scores = tf.nn.softmax(scores) # [B, 1, T] 120 | 121 | # Weighted sum 122 | if mode == 'SUM': 123 | output = tf.matmul(scores, facts) # [B, 1, H] 124 | # output = tf.reshape(output, [-1, tf.shape(facts)[-1]]) 125 | else: 126 | scores = tf.reshape(scores, [-1, tf.shape(facts)[1]]) 127 | output = facts * tf.expand_dims(scores, -1) 128 | output = tf.reshape(output, tf.shape(facts)) 129 | 130 | if return_alphas: 131 | return output, scores 132 | 133 | return output 134 | 135 | 136 | class VecAttGRUCell(RNNCell): 137 | """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078). 138 | Args: 139 | num_units: int, The number of units in the GRU cell. 140 | activation: Nonlinearity to use.
Default: `tanh`. 141 | reuse: (optional) Python boolean describing whether to reuse variables 142 | in an existing scope. If not `True`, and the existing scope already has 143 | the given variables, an error is raised. 144 | kernel_initializer: (optional) The initializer to use for the weight and 145 | projection matrices. 146 | bias_initializer: (optional) The initializer to use for the bias. 147 | """ 148 | 149 | def __init__(self, 150 | num_units, 151 | activation=None, 152 | reuse=None, 153 | kernel_initializer=None, 154 | bias_initializer=None): 155 | super(VecAttGRUCell, self).__init__(_reuse=reuse) 156 | self._num_units = num_units 157 | self._activation = activation or math_ops.tanh 158 | self._kernel_initializer = kernel_initializer 159 | self._bias_initializer = bias_initializer 160 | self._gate_linear = None 161 | self._candidate_linear = None 162 | 163 | @property 164 | def state_size(self): 165 | return self._num_units 166 | 167 | @property 168 | def output_size(self): 169 | return self._num_units 170 | def __call__(self, inputs, state, att_score): 171 | return self.call(inputs, state, att_score) 172 | def call(self, inputs, state, att_score=None): 173 | """Gated recurrent unit (GRU) with nunits cells.""" 174 | if self._gate_linear is None: 175 | bias_ones = self._bias_initializer 176 | if self._bias_initializer is None: 177 | bias_ones = init_ops.constant_initializer(1.0, dtype=inputs.dtype) 178 | with vs.variable_scope("gates"): # Reset gate and update gate. 
179 | self._gate_linear = _Linear( 180 | [inputs, state], 181 | 2 * self._num_units, 182 | True, 183 | bias_initializer=bias_ones, 184 | kernel_initializer=self._kernel_initializer) 185 | 186 | value = math_ops.sigmoid(self._gate_linear([inputs, state])) 187 | r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) 188 | 189 | r_state = r * state 190 | if self._candidate_linear is None: 191 | with vs.variable_scope("candidate"): 192 | self._candidate_linear = _Linear( 193 | [inputs, r_state], 194 | self._num_units, 195 | True, 196 | bias_initializer=self._bias_initializer, 197 | kernel_initializer=self._kernel_initializer) 198 | c = self._activation(self._candidate_linear([inputs, r_state])) 199 | u = (1.0 - att_score) * u 200 | new_h = u * state + (1 - u) * c 201 | return new_h, new_h 202 | 203 | def prelu(_x, scope=''): 204 | """parametric ReLU activation""" 205 | with tf.variable_scope(name_or_scope=scope, default_name="prelu"): 206 | _alpha = tf.get_variable("prelu_"+scope, shape=_x.get_shape()[-1], 207 | dtype=_x.dtype, initializer=tf.constant_initializer(0.1)) 208 | return tf.maximum(0.0, _x) + _alpha * tf.minimum(0.0, _x) 209 | 210 | def calc_auc(raw_arr): 211 | """Summary 212 | 213 | Args: 214 | raw_arr (TYPE): Description 215 | 216 | Returns: 217 | TYPE: Description 218 | """ 219 | 220 | arr = sorted(raw_arr, key=lambda d:d[0], reverse=True) 221 | pos, neg = 0., 0. 222 | for record in arr: 223 | if record[1] == 1.: 224 | pos += 1 225 | else: 226 | neg += 1 227 | 228 | fp, tp = 0., 0. 229 | xy_arr = [] 230 | for record in arr: 231 | if record[1] == 1.: 232 | tp += 1 233 | else: 234 | fp += 1 235 | xy_arr.append([fp/neg, tp/pos]) 236 | 237 | auc = 0. 238 | prev_x = 0. 239 | prev_y = 0. 240 | for x, y in xy_arr: 241 | if x != prev_x: 242 | auc += ((x - prev_x) * (y + prev_y) / 2.) 
243 | prev_x = x 244 | prev_y = y 245 | 246 | return auc 247 | 248 | def calc_gauc(raw_arr, nick_index): 249 | """Summary 250 | 251 | Args: 252 | raw_arr (TYPE): Description 253 | 254 | Returns: 255 | TYPE: Description 256 | """ 257 | last_index = 0 258 | gauc = 0. 259 | pv_sum = 0 260 | for idx in range(len(nick_index)): 261 | if nick_index[idx] != nick_index[last_index]: 262 | input_arr = raw_arr[last_index:idx] 263 | auc_val = calc_auc(input_arr) 264 | if auc_val >= 0.0: 265 | gauc += auc_val * len(input_arr) 266 | pv_sum += len(input_arr) 267 | else: 268 | pv_sum += len(input_arr) 269 | last_index = idx 270 | return gauc / pv_sum 271 | 272 | 273 | 274 | 275 | def attention(query, facts, attention_size, mask, stag='null', mode='LIST', softmax_stag=1, time_major=False, return_alphas=False): 276 | if isinstance(facts, tuple): 277 | # In case of Bi-RNN, concatenate the forward and the backward RNN outputs. 278 | facts = tf.concat(facts, 2) 279 | 280 | if time_major: 281 | # (T,B,D) => (B,T,D) 282 | facts = tf.transpose(facts, [1, 0, 2]) 283 | 284 | mask = tf.equal(mask, tf.ones_like(mask)) 285 | hidden_size = facts.get_shape().as_list()[-1] # D value - hidden size of the RNN layer 286 | input_size = query.get_shape().as_list()[-1] 287 | 288 | # Trainable parameters 289 | w1 = tf.Variable(tf.random_normal([hidden_size, attention_size], stddev=0.1)) 290 | w2 = tf.Variable(tf.random_normal([input_size, attention_size], stddev=0.1)) 291 | b = tf.Variable(tf.random_normal([attention_size], stddev=0.1)) 292 | v = tf.Variable(tf.random_normal([attention_size], stddev=0.1)) 293 | 294 | with tf.name_scope('v'): 295 | # Applying fully connected layer with non-linear activation to each of the B*T timestamps; 296 | # the shape of `tmp` is (B,T,D)*(D,A)=(B,T,A), where A=attention_size 297 | tmp1 = tf.tensordot(facts, w1, axes=1) 298 | tmp2 = tf.tensordot(query, w2, axes=1) 299 | tmp2 = tf.reshape(tmp2, [-1, 1, tf.shape(tmp2)[-1]]) 300 | tmp = tf.tanh((tmp1 + tmp2) + b)
301 | 302 | # For each of the timestamps its vector of size A from `tmp` is reduced with `v` vector 303 | v_dot_tmp = tf.tensordot(tmp, v, axes=1, name='v_dot_tmp') # (B,T) shape 304 | key_masks = mask # [B, 1, T] 305 | # key_masks = tf.expand_dims(mask, 1) # [B, 1, T] 306 | paddings = tf.ones_like(v_dot_tmp) * (-2 ** 32 + 1) 307 | v_dot_tmp = tf.where(key_masks, v_dot_tmp, paddings) # [B, 1, T] 308 | alphas = tf.nn.softmax(v_dot_tmp, name='alphas') # (B,T) shape 309 | 310 | # Output of (Bi-)RNN is reduced with attention vector; the result has (B,D) shape 311 | #output = tf.reduce_sum(facts * tf.expand_dims(alphas, -1), 1) 312 | output = facts * tf.expand_dims(alphas, -1) 313 | output = tf.reshape(output, tf.shape(facts)) 314 | # output = output / (facts.get_shape().as_list()[-1] ** 0.5) 315 | if not return_alphas: 316 | return output 317 | else: 318 | return output, alphas 319 | 320 | 321 | def din_fcn_attention(query, facts, attention_size, mask, stag='null', mode='SUM', softmax_stag=1, time_major=False, return_alphas=False, forCnn=False): 322 | if isinstance(facts, tuple): 323 | # In case of Bi-RNN, concatenate the forward and the backward RNN outputs. 
324 | facts = tf.concat(facts, 2) 325 | if len(facts.get_shape().as_list()) == 2: 326 | facts = tf.expand_dims(facts, 1) 327 | 328 | if time_major: 329 | # (T,B,D) => (B,T,D) 330 | facts = tf.transpose(facts, [1, 0, 2]) 331 | # Trainable parameters 332 | facts_size = facts.get_shape().as_list()[-1] # D value - hidden size of the RNN layer 333 | querry_size = query.get_shape().as_list()[-1] 334 | query = tf.layers.dense(query, facts_size, activation=None, name='f1' + stag) 335 | query = prelu(query) 336 | queries = tf.tile(query, [1, tf.shape(facts)[1]]) 337 | queries = tf.reshape(queries, tf.shape(facts)) 338 | din_all = tf.concat([queries, facts, queries-facts, queries*facts], axis=-1) 339 | d_layer_1_all = tf.layers.dense(din_all, 80, activation=tf.nn.sigmoid, name='f1_att' + stag) 340 | d_layer_2_all = tf.layers.dense(d_layer_1_all, 40, activation=tf.nn.sigmoid, name='f2_att' + stag) 341 | d_layer_3_all = tf.layers.dense(d_layer_2_all, 1, activation=None, name='f3_att' + stag) 342 | d_layer_3_all = tf.reshape(d_layer_3_all, [-1, 1, tf.shape(facts)[1]]) 343 | scores = d_layer_3_all 344 | # Mask 345 | if mask is not None: 346 | # key_masks = tf.sequence_mask(facts_length, tf.shape(facts)[1]) # [B, T] 347 | key_masks = tf.expand_dims(mask, 1) # [B, 1, T] 348 | paddings = tf.ones_like(scores) * (-2 ** 32 + 1) 349 | if not forCnn: 350 | scores = tf.where(key_masks, scores, paddings) # [B, 1, T] 351 | 352 | # Scale 353 | # scores = scores / (facts.get_shape().as_list()[-1] ** 0.5) 354 | 355 | # Activation 356 | if softmax_stag: 357 | scores = tf.nn.softmax(scores) # [B, 1, T] 358 | 359 | # Weighted sum 360 | if mode == 'SUM': 361 | output = tf.matmul(scores, facts) # [B, 1, H] 362 | # output = tf.reshape(output, [-1, tf.shape(facts)[-1]]) 363 | else: 364 | scores = tf.reshape(scores, [-1, tf.shape(facts)[1]]) 365 | output = facts * tf.expand_dims(scores, -1) 366 | output = tf.reshape(output, tf.shape(facts)) 367 | if return_alphas: 368 | return output,
scores 369 | return output 370 | 371 | def self_attention(facts, ATTENTION_SIZE, mask, stag='null'): 372 | if len(facts.get_shape().as_list()) == 2: 373 | facts = tf.expand_dims(facts, 1) 374 | 375 | def cond(batch, output, i): 376 | return tf.less(i, tf.shape(batch)[1]) 377 | 378 | def body(batch, output, i): 379 | self_attention_tmp = din_fcn_attention(batch[:, i, :], batch[:, 0:i+1, :], 380 | ATTENTION_SIZE, mask[:, 0:i+1], softmax_stag=1, stag=stag, 381 | mode='LIST') 382 | self_attention_tmp = tf.reduce_sum(self_attention_tmp, 1) 383 | output = output.write(i, self_attention_tmp) 384 | return batch, output, i + 1 385 | 386 | output_ta = tf.TensorArray(dtype=tf.float32, 387 | size=0, 388 | dynamic_size=True, 389 | element_shape=(facts[:, 0, :].get_shape())) 390 | _, output_op, _ = tf.while_loop(cond, body, [facts, output_ta, 0]) 391 | self_attention = output_op.stack() 392 | self_attention = tf.transpose(self_attention, perm = [1, 0, 2]) 393 | return self_attention 394 | 395 | def self_all_attention(facts, ATTENTION_SIZE, mask, stag='null'): 396 | if len(facts.get_shape().as_list()) == 2: 397 | facts = tf.expand_dims(facts, 1) 398 | 399 | def cond(batch, output, i): 400 | return tf.less(i, tf.shape(batch)[1]) 401 | 402 | def body(batch, output, i): 403 | self_attention_tmp = din_fcn_attention(batch[:, i, :], batch, 404 | ATTENTION_SIZE, mask, softmax_stag=1, stag=stag, 405 | mode='LIST') 406 | self_attention_tmp = tf.reduce_sum(self_attention_tmp, 1) 407 | output = output.write(i, self_attention_tmp) 408 | return batch, output, i + 1 409 | 410 | output_ta = tf.TensorArray(dtype=tf.float32, 411 | size=0, 412 | dynamic_size=True, 413 | element_shape=(facts[:, 0, :].get_shape())) 414 | _, output_op, _ = tf.while_loop(cond, body, [facts, output_ta, 0]) 415 | self_attention = output_op.stack() 416 | self_attention = tf.transpose(self_attention, perm = [1, 0, 2]) 417 | return self_attention 418 | 419 | def din_fcn_shine(query, facts, attention_size, mask, 
                  stag='null', mode='SUM', softmax_stag=1, time_major=False, return_alphas=False):
420 |     if isinstance(facts, tuple):
421 |         # In case of Bi-RNN, concatenate the forward and the backward RNN outputs.
422 |         facts = tf.concat(facts, 2)
423 | 
424 |     if time_major:
425 |         # (T,B,D) => (B,T,D)
426 |         facts = tf.transpose(facts, [1, 0, 2])
427 |     # Trainable parameters
428 |     mask = tf.equal(mask, tf.ones_like(mask))
429 |     facts_size = facts.get_shape().as_list()[-1]  # D value - hidden size of the RNN layer
430 |     query_size = query.get_shape().as_list()[-1]
431 |     query = tf.layers.dense(query, facts_size, activation=None, name='f1_trans_shine' + stag)
432 |     query = prelu(query)
433 |     queries = tf.tile(query, [1, tf.shape(facts)[1]])
434 |     queries = tf.reshape(queries, tf.shape(facts))
435 |     din_all = tf.concat([queries, facts, queries-facts, queries*facts], axis=-1)
436 |     d_layer_1_all = tf.layers.dense(din_all, facts_size, activation=tf.nn.sigmoid, name='f1_shine_att' + stag)
437 |     d_layer_2_all = tf.layers.dense(d_layer_1_all, facts_size, activation=tf.nn.sigmoid, name='f2_shine_att' + stag)
438 |     d_layer_2_all = tf.reshape(d_layer_2_all, tf.shape(facts))
439 |     output = d_layer_2_all
440 |     return output
441 | 
442 | 
--------------------------------------------------------------------------------
/DIEN/wrap_time.py:
--------------------------------------------------------------------------------
1 | """Timing decorator utilities.
2 | """
3 | import time
4 | import logging
5 | import types
6 | 
7 | 
8 | def time_it(freq=1):
9 |     """Decorator factory that reports the average run time of a callable.
10 | 
11 |     Args:
12 |         freq (int): print the average elapsed time once every `freq` calls.
13 |     """
14 |     class time_cls(object):
15 | 
16 |         """Callable wrapper that accumulates timing statistics.
17 | 
18 |         Attributes:
19 |             acc_time (float): accumulated run time since the last report.
20 |             call_count (int): number of calls since the last report.
21 |             class_name (str): name of the owning class for bound methods.
22 |             freq (int): reporting frequency.
23 |             func (callable): the wrapped function.
24 |         """
25 | 
26 |         def __init__(self, func):
27 |             """Wrap `func` and copy over its identifying metadata.
28 | 
29 |             Args:
30 |                 func (callable): the function to time.
31 |             """
32 |             # wraps(func)(self)
33 |             self.__name__ = func.__name__
34 |             self.__module__ = func.__module__
35 |             self.__doc__ = func.__doc__
36 |             self.func = func
37 |             self.class_name = ""
38 | 
39 |             # used to calculate average run time
40 |             self.call_count = 0
41 |             self.acc_time = 0.0
42 |             self.freq = freq
43 | 
44 |         def __call__(self, *args, **kwargs):
45 |             """Invoke the wrapped function and record its elapsed time.
46 | 
47 |             Args:
48 |                 *args: positional arguments forwarded to the wrapped function.
49 |                 **kwargs: keyword arguments forwarded to the wrapped function.
50 | 
51 |             Returns:
52 |                 The wrapped function's return value.
53 |             """
54 |             start = time.time()
55 |             # for a bound method, args[0] is the instance itself;
56 |             # print args[0] and you will see
57 |             result = self.func(*args, **kwargs)
58 |             end = time.time()
59 | 
60 |             self.call_count += 1
61 |             self.acc_time += (end - start)
62 |             if self.call_count > 0 and self.call_count % self.freq == 0:
63 |                 print("Avg time cost of %s%s is %f" %
64 |                       (self.class_name, self.__name__,
65 |                        self.acc_time / self.call_count))
66 |                 self.call_count = 0
67 |                 self.acc_time = 0.0
68 | 
69 |             return result
70 | 
71 |         def __get__(self, instance, cls):
72 |             """Support use as a method decorator via the descriptor protocol.
73 | 
74 |             Args:
75 |                 instance: the object the method is accessed through, or None.
76 |                 cls: the class owning the method.
77 | 
78 |             Returns:
79 |                 The wrapper itself, or a bound method when accessed on an instance.
80 |             """
81 |             if instance is None:
82 |                 return self
83 |             else:
84 |                 self.class_name = "%s." % type(instance).__name__
85 |                 # bind this method to instance
86 |                 return types.MethodType(self, instance)
87 | 
88 |     return time_cls
89 | 
90 | 
91 | @time_it(freq=2)
92 | def test():
93 |     """Print a fixed string; used to exercise the decorator.
94 |     """
95 |     print("abc")
96 | 
97 | 
98 | if __name__ == '__main__':
99 |     test()
100 |     test()
101 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------
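Editor's note, not part of the repository: the masking trick in `din_fcn_attention` (in `DIEN/utils.py` above) replaces scores at padded history positions with a large negative constant before `tf.nn.softmax`, so those positions receive near-zero attention weight. A minimal framework-free sketch of that step, using NumPy with illustrative shapes and a hypothetical helper name:

```python
import numpy as np

def masked_attention_weights(scores, mask):
    """Mimic din_fcn_attention's masking: padded positions get a huge
    negative score, so softmax assigns them (near-)zero weight."""
    paddings = np.full_like(scores, -2.0 ** 32 + 1)  # same constant as utils.py
    scores = np.where(mask, scores, paddings)        # keep valid, drown padded
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    return e / e.sum(axis=-1, keepdims=True)

# One query scored against a history of T=4 behaviors, last position padded.
scores = np.array([[1.0, 2.0, 0.5, 3.0]])
mask = np.array([[True, True, True, False]])
w = masked_attention_weights(scores, mask)
# w[0, 3] is ~0 and the three valid weights sum to 1
```

Because the padded score is about -2^32, its exponential underflows to zero and the remaining weights renormalize over the valid positions only.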
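Editor's note, not part of the repository: the `self_attention` helper in `DIEN/utils.py` above builds causal attention with `tf.while_loop` and a `tf.TensorArray`, summarizing positions `0..i` at every step `i`. A framework-free sketch of the same control flow (shapes and the masked-mean stand-in for the learned `din_fcn_attention` scores are illustrative only):

```python
import numpy as np

def causal_average_attention(facts, mask):
    """Loop over positions like self_attention: at step i, summarize the
    prefix 0..i (a plain masked mean stands in for learned scores)."""
    B, T, D = facts.shape
    out = np.zeros((B, T, D))
    for i in range(T):                  # plays the role of tf.while_loop
        prefix = facts[:, :i + 1, :]    # batch[:, 0:i+1, :]
        m = mask[:, :i + 1, None]       # mask[:, 0:i+1]
        out[:, i, :] = (prefix * m).sum(1) / np.maximum(m.sum(1), 1)
    return out

facts = np.arange(2 * 3 * 2, dtype=float).reshape(2, 3, 2)
mask = np.ones((2, 3))
res = causal_average_attention(facts, mask)
# res[:, 0, :] equals facts[:, 0, :]: the first step sees only itself
```

The TensorFlow version writes each step's summary into a `TensorArray` and stacks it to `(T, B, D)` before transposing; the NumPy loop writes directly into the `(B, T, D)` output instead.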
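Editor's note, not part of the repository: the non-obvious part of the `time_it` decorator in `DIEN/wrap_time.py` above is the `__get__` hook, which lets the class-based wrapper also decorate methods. A condensed re-sketch (same structure, simplified body; the `Model` class is hypothetical) showing why the descriptor protocol matters:

```python
import time
import types

def time_it(freq=1):
    """Condensed re-sketch of DIEN/wrap_time.py's timing decorator."""
    class time_cls(object):
        def __init__(self, func):
            self.func = func
            self.__name__ = func.__name__
            self.class_name = ""
            self.call_count = 0
            self.acc_time = 0.0
            self.freq = freq

        def __call__(self, *args, **kwargs):
            start = time.time()
            result = self.func(*args, **kwargs)
            self.acc_time += time.time() - start
            self.call_count += 1
            if self.call_count % self.freq == 0:
                print("Avg time cost of %s%s is %f" %
                      (self.class_name, self.__name__,
                       self.acc_time / self.call_count))
                self.call_count, self.acc_time = 0, 0.0
            return result

        def __get__(self, instance, cls):
            # Without this descriptor hook, Model().forward(x) would not
            # pass the instance through: the wrapper would swallow `self`.
            if instance is None:
                return self
            self.class_name = "%s." % type(instance).__name__
            return types.MethodType(self, instance)

    return time_cls

class Model(object):          # hypothetical example class
    @time_it(freq=2)
    def forward(self, x):
        return x * 2

m = Model()
m.forward(1)
m.forward(2)  # second call triggers the periodic "Avg time cost" report
```

Binding via `types.MethodType` is what makes `self.func(*args)` receive the `Model` instance as `args[0]`, exactly as noted in the comment inside `__call__` in the original file.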