├── data └── tmall │ ├── process_data_format1.py │ ├── process_data_format2.py │ └── readme.md ├── datamanager.py ├── maml.py ├── matcher.py ├── networks.py ├── readme.md ├── run.py └── utils.py /data/tmall/process_data_format1.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import random 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | '''1. count item''' 7 | # log = pd.read_csv("./data_format1/user_log_format1.csv", nrows=None, usecols=["user_id", "item_id", "time_stamp"]) 8 | # item_gp = log.drop_duplicates().groupby("item_id").size().reset_index().rename(columns={0:"cnt"}).sort_values("cnt", ascending=True) 9 | # item_gp.to_csv("./data_format1/item_cnt_duplicated.csv", index=False) 10 | # item_gp = log.groupby("item_id").size().reset_index().rename(columns={0:"cnt"}).sort_values("cnt", ascending=True) 11 | # item_gp.to_csv("./data_format1/item_cnt.csv", index=False) 12 | 13 | '''2. meta-sequence''' 14 | # log = pd.read_csv("./data_format1/user_log_format1.csv", nrows=None, usecols=["user_id", "item_id", "time_stamp", "action_type"]) 15 | # log["order"] = log.index.values # origin order 16 | 17 | # log = log.sort_values(["user_id", "time_stamp", "order"]).drop(["action_type", "order"], axis=1) 18 | # log = log.drop_duplicates() 19 | 20 | # item_gp = pd.read_csv("./data_format1/item_cnt_duplicated.csv") 21 | # item_gp = item_gp[item_gp.cnt >= 10] 22 | 23 | # cold_items = item_gp.sort_values("cnt").head(int(0.2 * len(item_gp))) 24 | # log = pd.merge(log, cold_items, "left", ["item_id"]).rename(columns={"cnt": "flag"}) 25 | # log["flag"] = log["flag"].fillna(0) 26 | 27 | # fout = open("./data_format1/meta_sequence.txt", "w") 28 | # for user_id, gp in tqdm( log.groupby("user_id") ): 29 | # user_log = gp.reset_index(drop=True) 30 | # index = user_log[user_log.flag != 0].index 31 | # for i in index: 32 | # seq = list(user_log.item_id[:i + 1].values) 33 | # if len(seq) <= 1: 34 | # continue 35 | # label = str(seq[len(seq) - 1]) 36 | # fout.write(str(user_id) + "\t" + label + "\t" + ",".join([str(v) for v in seq]) + "\n") 37 | # fout.close() 38 | 39 | # log = pd.read_csv("./data_format1/user_log_format1.csv", nrows=None, usecols=["user_id", "item_id", "time_stamp", "action_type"]) 40 | # log["order"] = log.index.values # origin order 41 | 42 | # log = log.sort_values(["user_id", "time_stamp", "order"]).drop(["action_type", "order"], axis=1) 43 | # log = log.drop_duplicates() 44 | 45 | # item_gp = pd.read_csv("./data_format1/item_cnt_duplicated.csv") 46 | # log = pd.merge(log, item_gp, "left", ["item_id"]).rename(columns={"cnt": "item_cnt"}) 47 | # print(log) 48 | 49 | # item_gp = item_gp[item_gp.cnt >= 10] 50 | # cold_items = item_gp.sort_values("cnt").head(int(0.2 * len(item_gp))) 51 | # log = pd.merge(log, cold_items, "left", ["item_id"]).rename(columns={"cnt": "flag"}) 52 | 53 | # log["flag"] = log["flag"].fillna(0) 54 | # log["flag1"] = log.apply(lambda x: 1 if x["flag"] > 0 or x["item_cnt"] <= 10 else 0, axis=1) 55 | 56 | # # print(len(log[log.flag > 0]), len(log[(log.flag > 0) | (log.item_cnt <= 10)]), len(log) ) 57 | 58 | # gp = log.groupby("user_id")["flag1"].sum().reset_index() 59 | 60 | 61 | '''3. 
train-test split''' 62 | 63 | meta_sequence = pd.read_csv("./data_format1/meta_sequence.txt", sep="\t", header=None, nrows=None) 64 | meta_sequence.columns = ["user_id", "item_id", "seq"] 65 | 66 | cold_items = meta_sequence.item_id.unique() 67 | 68 | random.seed(2021) 69 | random.shuffle(cold_items) 70 | 71 | cold_items_train = pd.DataFrame({"item_id": cold_items[:int(0.7 * len(cold_items))]}) 72 | cold_items_val = pd.DataFrame({"item_id": cold_items[int(0.7 * len(cold_items)):int(0.8 * len(cold_items))]}) 73 | cold_items_test = pd.DataFrame({"item_id": cold_items[int(0.8 * len(cold_items)):]}) 74 | print(len(cold_items_train), len(cold_items_val), len(cold_items_test)) 75 | 76 | pd.merge(meta_sequence, cold_items_train, "inner", "item_id").to_csv("./data_format1/meta_sequence_train.txt", sep="\t", header=None, index=False) 77 | pd.merge(meta_sequence, cold_items_val, "inner", "item_id").to_csv("./data_format1/meta_sequence_val.txt", sep="\t", header=None, index=False) 78 | pd.merge(meta_sequence, cold_items_test, "inner", "item_id").to_csv("./data_format1/meta_sequence_test.txt", sep="\t", header=None, index=False) -------------------------------------------------------------------------------- /data/tmall/process_data_format2.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SongDark/Mecos-tf/40a7799b0567c67c66e2be05e7c1f492bcb54f4f/data/tmall/process_data_format2.py -------------------------------------------------------------------------------- /data/tmall/readme.md: -------------------------------------------------------------------------------- 1 | # Tmall Dataset 2 | 3 | You can download the dataset from https://tianchi.aliyun.com/dataset/dataDetail?dataId=47 4 | 5 | Pre-processing stage follows this paper: [Personalized Top-N Sequential Recommendation via Convolutional Sequence Embedding](https://arxiv.org/abs/1809.07426). 6 | 7 | The authors provide the source code at https://github.com/graytowne/caser_pytorch . 8 | 9 | Shop Info and User Behavior data from IJCAI-15. 
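A quick way to sanity-check the download before running the preprocessing scripts is to load the raw log directly. A minimal sketch, assuming `user_log_format1.csv` has been extracted into `./data_format1/` (the path the scripts read from); the printed counts should match the `user_log_format1` row of the table below:

```python
import pandas as pd

# Same columns that process_data_format1.py reads from the raw behaviour log.
log = pd.read_csv("./data_format1/user_log_format1.csv",
                  usecols=["user_id", "item_id", "time_stamp", "action_type"])
print(log["user_id"].nunique(), "users,", log["item_id"].nunique(), "items")
```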
10 | 11 | | file | user | item | cat | merchant | brand | 12 | |--|--|--|--|--|--| 13 | | train_format1 | 212062 | | |1993 | | 14 | | test_format1 | 212108 | | |1993 | | 15 | | user_log_format1 | 424170 | 1090390 | 1658 | 4995 | 8444 | -------------------------------------------------------------------------------- /datamanager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import random 4 | from tqdm import tqdm 5 | from tensorflow.python.keras.preprocessing.sequence import pad_sequences 6 | 7 | class RandomSample: 8 | '''fot test''' 9 | def __init__(self, batch_size=1, n_ways=128, k_shots=3, q_query=1) -> None: 10 | 11 | self.n_ways = n_ways 12 | self.k_shots = k_shots 13 | self.q_query = q_query 14 | self.batch_size = batch_size 15 | 16 | def get_one_meta_batch(self): 17 | 18 | meta_batchsize = self.n_ways * self.k_shots 19 | maxlen = 10 20 | 21 | support_seqs = np.random.randint(0, 10000, (self.batch_size,meta_batchsize, maxlen), dtype=np.int32) 22 | support_lens = np.ones((self.batch_size,meta_batchsize,1)) * maxlen 23 | support_labels = np.random.randint(0, self.n_ways, (self.batch_size,meta_batchsize,)) 24 | 25 | query_seqs = np.random.randint(0, 10000, (self.batch_size,self.n_ways, maxlen), dtype=np.int32) 26 | query_lens = np.ones((self.batch_size,self.n_ways,1)) * maxlen 27 | query_labels = np.random.randint(0, self.n_ways, (self.batch_size,self.n_ways,)) 28 | 29 | yield support_seqs, support_lens, support_labels, \ 30 | query_seqs, query_lens, query_labels 31 | 32 | class Tmall: 33 | 34 | def __init__(self, data_path, batch_size=1, n_ways=128, k_shots=3, q_query=1): 35 | 36 | self.n_ways = n_ways 37 | self.k_shots = k_shots 38 | self.q_query = q_query 39 | self.batch_size = batch_size 40 | 41 | self.dataset = {} 42 | 43 | df = pd.read_csv(data_path, sep="\t", header=None, usecols=[1,2], nrows=None) 44 | df.columns = ["label", "seq"] 45 | for label, seq in tqdm(df.values): 46 | if label not in self.dataset: 47 | self.dataset[label] = [] 48 | self.dataset[label].append([int(v) for v in seq.split(",")]) 49 | self.steps = int(len(df) // (batch_size * n_ways * k_shots) ) 50 | del df 51 | 52 | self.items = list(self.dataset.keys()) 53 | 54 | def get_one_meta_task(self): 55 | 56 | chosen_items = random.sample(self.items, self.n_ways) 57 | 58 | support_seqs, support_lens, support_labels = [], [], [] 59 | query_seqs, query_lens, query_labels = [], [], [] 60 | 61 | for label, chosen_item in enumerate(chosen_items): 62 | while len(self.dataset[chosen_item]) < self.k_shots + self.q_query: 63 | chosen_item = random.sample(self.items, 1)[0] 64 | 65 | seqs = random.sample(self.dataset[chosen_item], self.k_shots + self.q_query) 66 | for i in range(len(seqs)): 67 | if len(seqs[i]) > 64: 68 | seqs[i] = seqs[i][-64:] 69 | 70 | for i in range(self.k_shots): 71 | support_seqs.append(seqs[i]) 72 | support_lens.append(len(seqs[i])) 73 | support_labels.append(label) 74 | 75 | for i in range(self.k_shots, self.k_shots + self.q_query): 76 | query_seqs.append(seqs[i][:-1]) 77 | query_lens.append(len(seqs[i]) - 1) 78 | query_labels.append(label) 79 | 80 | support_index = list(range(len(support_seqs))) 81 | random.shuffle(support_index) 82 | support_seqs = [support_seqs[i] for i in support_index] 83 | support_lens = [support_lens[i] for i in support_index] 84 | support_labels = [support_labels[i] for i in support_index] 85 | 86 | query_index = list(range(len(query_seqs))) 87 | random.shuffle(query_index) 88 | query_seqs 
= [query_seqs[i] for i in query_index] 89 | query_lens = [query_lens[i] for i in query_index] 90 | query_labels = [query_labels[i] for i in query_index] 91 | 92 | support_seqs = pad_sequences(support_seqs, padding="post") 93 | support_lens = np.expand_dims(np.array(support_lens), -1) 94 | support_labels = np.array(support_labels) 95 | 96 | query_seqs = pad_sequences(query_seqs, padding="post") 97 | query_lens = np.expand_dims(np.array(query_lens), -1) 98 | query_labels = np.array(query_labels) 99 | 100 | return support_seqs, support_lens, support_labels,\ 101 | query_seqs, query_lens, query_labels 102 | 103 | def get_one_meta_batch(self): 104 | 105 | meta_support_seqs, meta_support_lens, meta_support_labels = [], [], [] 106 | meta_query_seqs, meta_query_lens, meta_query_labels = [], [], [] 107 | 108 | for _ in range(self.batch_size): 109 | support_seqs, support_lens, support_labels,\ 110 | query_seqs, query_lens, query_labels = self.get_one_meta_task() 111 | 112 | # print(support_seqs.shape) 113 | 114 | meta_support_seqs.append(support_seqs) 115 | meta_support_lens.append(support_lens) 116 | meta_support_labels.append(support_labels) 117 | 118 | meta_query_seqs.append(query_seqs) 119 | meta_query_lens.append(query_lens) 120 | meta_query_labels.append(query_labels) 121 | 122 | yield np.array(meta_support_seqs), np.array(meta_support_lens), np.array(meta_support_labels), \ 123 | np.array(meta_query_seqs), np.array(meta_query_lens), np.array(meta_query_labels) 124 | 125 | # data = RandomSample() 126 | # x1,x2,x3,x4,x5,x6 = next(data.get_one_meta_batch()) 127 | # print(x1.shape) 128 | # print(x2.shape) 129 | # print(x3.shape) 130 | # print(x4.shape) 131 | # print(x5.shape) 132 | # print(x6.shape) 133 | 134 | # data = Tmall(batch_size=2, n_ways=10, k_shots=1, q_query=1) 135 | 136 | # # x1,x2,x3,x4,x5,x6 = data.get_one_meta_task() 137 | # x1,x2,x3,x4,x5,x6 = next(data.get_one_meta_batch()) 138 | 139 | # print(x1.shape) 140 | # print(x2.shape) 141 | # print(x3.shape) 142 | # print(x4.shape) 143 | # print(x5.shape) 144 | # print(x6.shape) 145 | 146 | # print(list(x1[0])) 147 | # print(x2[0]) -------------------------------------------------------------------------------- /maml.py: -------------------------------------------------------------------------------- 1 | # encoding:utf-8 2 | 3 | import tensorflow as tf 4 | from tensorflow.python.keras.losses import sparse_categorical_crossentropy 5 | from tensorflow.python.keras.layers import Input 6 | from networks import Mecos 7 | import numpy as np 8 | 9 | 10 | class MAML: 11 | 12 | def __init__(self, 13 | n_ways, k_shots, 14 | vocabulary_size, 15 | embedding_size, 16 | matching_steps 17 | ): 18 | 19 | self.n_ways = n_ways 20 | self.k_shots = k_shots 21 | self.vocabulary_size = vocabulary_size 22 | self.embedding_size = embedding_size 23 | self.matching_steps = matching_steps 24 | 25 | self.meta_model = self.build_model() 26 | 27 | self.inner_writer_step = 0 28 | self.outer_writer_step = 0 29 | 30 | def build_model(self): 31 | 32 | maxlen = None 33 | meta_batchsize = self.n_ways * self.k_shots 34 | 35 | support_seqs = Input(shape=(meta_batchsize, maxlen,), dtype=tf.int32, name="support_seqs") 36 | support_lens = Input(shape=(meta_batchsize,1,), dtype=tf.int32, name="support_lens") 37 | support_labels = Input(shape=(meta_batchsize,), dtype=tf.int32, name="support_labels") 38 | query_seqs = Input(shape=(self.n_ways, maxlen,), dtype=tf.int32, name="query_seqs") 39 | query_lens = Input(shape=(self.n_ways,1,), dtype=tf.int32, name="query_lens") 40 | # 
query_labels = Input(shape=(), dtype=tf.int32, batch_size=self.n_ways, name="query_labels")
41 | 
42 |         mecos = Mecos(n_ways=self.n_ways, matching_steps=self.matching_steps, k_shots=self.k_shots, vocabulary_size=self.vocabulary_size, embedding_size=self.embedding_size)
43 |         logits = mecos([support_seqs, support_lens, support_labels, query_seqs, query_lens])
44 | 
45 |         model = tf.keras.Model(inputs=[support_seqs, support_lens, support_labels, query_seqs, query_lens],
46 |                                outputs=[logits])
47 |         return model
48 | 
49 |     def train_on_meta_batch(self, train_tasks_iterator, inner_optimizer=None, inner_step=1, outer_optimizer=None, writer=None):
50 | 
51 |         meta_support_seqs, meta_support_seqlens, meta_support_labels, \
52 |             meta_query_seqs, meta_query_seqlens, meta_query_labels = next(train_tasks_iterator)
53 | 
54 |         for support_seqs, support_seqlens, support_labels, \
55 |             query_seqs, query_seqlens, query_labels in zip(meta_support_seqs, meta_support_seqlens, meta_support_labels, meta_query_seqs, meta_query_seqlens, meta_query_labels):
56 | 
57 |             support_seqs = np.expand_dims(support_seqs, 0)
58 |             support_seqlens = np.expand_dims(support_seqlens, 0)
59 |             support_labels = np.expand_dims(support_labels, 0)
60 |             query_seqs = np.expand_dims(query_seqs, 0)
61 |             query_seqlens = np.expand_dims(query_seqlens, 0)
62 |             query_labels = np.expand_dims(query_labels, 0)
63 | 
64 |             '''
65 |             Single Task:
66 |                 support_seqs: N x K x seqlen
67 |                 support_seqlens: N x K x 1
68 |                 support_labels: N x K x 1
69 |                 query_seqs: N x seqlen
70 |                 query_seqlens: N x 1
71 |                 query_labels: N x 1
72 |             '''
73 |             losses = []
74 |             accs = []
75 |             # Record all inner-step forward passes on one non-persistent tape; re-entering a single GradientTape context (as the previous version did) is not allowed in TF2.
76 |             with tf.GradientTape() as tape:
77 |                 for _ in range(inner_step):
78 |                     logits = self.meta_model([support_seqs, support_seqlens, support_labels, query_seqs, query_seqlens])
79 |                     loss = tf.reduce_mean(sparse_categorical_crossentropy(query_labels, logits))
80 |                     acc = (np.argmax(logits, -1) == query_labels).astype(np.int32).mean()
81 | 
82 |                     losses.append(loss)
83 |                     accs.append(acc)
84 | 
85 |                     if writer is not None:
86 |                         with writer.as_default():
87 |                             tf.summary.scalar("loss", loss, step=self.inner_writer_step)
88 |                             tf.summary.scalar("acc", acc, step=self.inner_writer_step)
89 |                             self.inner_writer_step += 1
90 | 
91 |                 total_loss = tf.reduce_sum(losses)
92 | 
93 |             # Meta update on the accumulated query loss; the gradient is taken after the tape context closes.
94 |             if outer_optimizer is not None:
95 |                 grads = tape.gradient(total_loss, self.meta_model.trainable_variables)
96 |                 outer_optimizer.apply_gradients(zip(grads, self.meta_model.trainable_variables))
97 | 
98 |         return np.array(losses), np.array(accs)
99 | 
--------------------------------------------------------------------------------
/matcher.py:
--------------------------------------------------------------------------------
1 | # encoding:utf-8
2 | 
3 | '''
4 | LSTM Matching Network
5 | Match between Q and S
6 | h_t, C_t = LSTM(Q, [h_(t-1), S], C_(t-1))
7 | with input Q, hidden state [h_(t-1), S], and previous cell state C_(t-1)
8 | 
9 | This is a keras version
10 | '''
11 | 
12 | import numpy as np
13 | import tensorflow as tf
14 | from tensorflow.python.keras import backend as K
15 | from tensorflow.python.keras.layers import Layer, RNN
16 | from tensorflow.python.keras import initializers, activations
17 | 
18 | class MinimalRNNCell(Layer):
19 |     def __init__(self, units, **kwargs):
20 |         self.units = units
21 |         self.state_size = units
22 |         super(MinimalRNNCell, self).__init__(**kwargs)
23 | 
24 |     def build(self, input_shape):
25 |         self.kernel = self.add_weight(
26 |             shape=(input_shape[-1], self.units),
27 | 
initializer="glorot_uniform", 28 | dtype=tf.float32, trainable=True, 29 | name="kernel") 30 | self.recurrent_kernel = self.add_weight( 31 | shape=(self.units, self.units), 32 | initializer="glorot_uniform", 33 | dtype=tf.float32, trainable=True, 34 | name="recurrent_kernel") 35 | super(MinimalRNNCell, self).build(input_shape) 36 | 37 | def call(self, inputs, states): 38 | prev_output = states[0] 39 | h = tf.tensordot(inputs, self.kernel, axes=(-1, 0)) 40 | output = h + tf.tensordot(prev_output, self.recurrent_kernel, axes=(-1, 0)) 41 | return output, [output] 42 | 43 | class MinimalLSTMCell(Layer): 44 | def __init__(self, units, **kwargs): 45 | self.units = units 46 | # Control the output size 47 | self.state_size = [units, units] 48 | self.output_size = units 49 | 50 | self.activation = activations.get("tanh") 51 | self.recurrent_activation = activations.get("hard_sigmoid") 52 | super(MinimalLSTMCell, self).__init__(**kwargs) 53 | 54 | def build(self, input_shape): 55 | input_dim = input_shape[-1] 56 | self.kernel = self.add_weight( 57 | shape=(input_dim, self.units * 4), 58 | initializer="glorot_uniform", 59 | dtype=tf.float32, trainable=True, 60 | name="kernel" 61 | ) 62 | self.bias = self.add_weight( 63 | shape=(self.units * 4, ), 64 | initializer="Zeros", 65 | dtype=tf.float32, trainable=True, 66 | name="bias" 67 | ) 68 | self.recurrent_kernel = self.add_weight( 69 | shape=(self.units, self.units * 4), 70 | initializer="glorot_uniform", 71 | dtype=tf.float32, trainable=True, 72 | name="recurrent_kernel" 73 | ) 74 | super(MinimalLSTMCell, self).build(input_shape) 75 | 76 | def call(self, inputs, states): 77 | h_tm1, c_tm1 = states 78 | 79 | inputs_i, inputs_f, inputs_c, inputs_o = inputs, inputs, inputs, inputs 80 | W_xi, W_xf, W_xc, W_xo = tf.split(self.kernel, num_or_size_splits=4, axis=1) 81 | b_i, b_f, b_c, b_o = tf.split(self.bias, num_or_size_splits=4, axis=0) 82 | h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1, h_tm1, h_tm1, h_tm1 83 | 84 | x_i = tf.nn.bias_add(tf.tensordot(inputs_i, W_xi, axes=(-1, 0)), b_i) 85 | x_f = tf.nn.bias_add(tf.tensordot(inputs_f, W_xf, axes=(-1, 0)), b_f) 86 | x_c = tf.nn.bias_add(tf.tensordot(inputs_c, W_xc, axes=(-1, 0)), b_c) 87 | x_o = tf.nn.bias_add(tf.tensordot(inputs_o, W_xo, axes=(-1, 0)), b_o) 88 | 89 | i = self.recurrent_activation(x_i + tf.tensordot(h_tm1_i, self.recurrent_kernel[:, :self.units], axes=(-1, 0)) ) 90 | f = self.recurrent_activation(x_f + tf.tensordot(h_tm1_f, self.recurrent_kernel[:, self.units: self.units * 2], axes=(-1, 0)) ) 91 | c = f * c_tm1 + i * self.activation(x_c + tf.tensordot(h_tm1_c, self.recurrent_kernel[:, self.units * 2: self.units * 3], axes=(-1, 0)) ) 92 | o = self.recurrent_activation(x_o + tf.tensordot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:], axes=(-1, 0)) ) 93 | 94 | h = o * self.activation(c) 95 | 96 | return h, [h, c] 97 | 98 | from tensorflow.python.training.tracking import data_structures 99 | class CustomLSTMCell(Layer): 100 | def __init__(self, units, **kwargs): 101 | self.units = units 102 | # Control the output size 103 | self.state_size = [units, units] 104 | self.output_size = units 105 | 106 | self.activation = activations.get("tanh") 107 | self.recurrent_activation = activations.get("hard_sigmoid") 108 | 109 | super(CustomLSTMCell, self).__init__(**kwargs) 110 | 111 | def build(self, input_shape): 112 | # (input_dim + 1 + units * 2) * (units * 4) 113 | input_dim = input_shape[-1] 114 | self.kernel = self.add_weight( 115 | shape=(input_dim, self.units * 4), 116 | 
initializer="glorot_uniform", 117 | dtype=tf.float32, trainable=True, 118 | name="kernel" 119 | ) 120 | self.bias = self.add_weight( 121 | shape=(self.units * 4, ), 122 | initializer="Zeros", 123 | dtype=tf.float32, trainable=True, 124 | name="bias" 125 | ) 126 | self.recurrent_kernel = self.add_weight( 127 | shape=(self.units, self.units * 4), 128 | initializer="glorot_uniform", 129 | dtype=tf.float32, trainable=True, 130 | name="recurrent_kernel" 131 | ) 132 | self.additional_kernel = self.add_weight( 133 | shape=(self.units, self.units * 4), 134 | initializer="glorot_uniform", 135 | dtype=tf.float32, trainable=True, 136 | name="additional_kernel" 137 | ) 138 | super(CustomLSTMCell, self).build(input_shape) 139 | 140 | def call(self, inputs, states, additional_states): 141 | h_tm1, c_tm1 = states 142 | s = additional_states 143 | 144 | inputs_i, inputs_f, inputs_c, inputs_o = inputs, inputs, inputs, inputs 145 | W_xi, W_xf, W_xc, W_xo = tf.split(self.kernel, num_or_size_splits=4, axis=1) 146 | b_i, b_f, b_c, b_o = tf.split(self.bias, num_or_size_splits=4, axis=0) 147 | h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1, h_tm1, h_tm1, h_tm1 148 | 149 | x_i = tf.nn.bias_add(tf.tensordot(inputs_i, W_xi, axes=(-1, 0)), b_i) 150 | x_f = tf.nn.bias_add(tf.tensordot(inputs_f, W_xf, axes=(-1, 0)), b_f) 151 | x_c = tf.nn.bias_add(tf.tensordot(inputs_c, W_xc, axes=(-1, 0)), b_c) 152 | x_o = tf.nn.bias_add(tf.tensordot(inputs_o, W_xo, axes=(-1, 0)), b_o) 153 | 154 | i = self.recurrent_activation(x_i \ 155 | + tf.tensordot(h_tm1_i, self.recurrent_kernel[:, :self.units], axes=(-1, 0)) \ 156 | + tf.tensordot(s, self.additional_kernel[:, :self.units], axes=(-1, 0)) ) 157 | f = self.recurrent_activation(x_f \ 158 | + tf.tensordot(h_tm1_f, self.recurrent_kernel[:, self.units: self.units * 2], axes=(-1, 0)) \ 159 | + tf.tensordot(s, self.additional_kernel[:, self.units: self.units * 2], axes=(-1, 0)) ) 160 | c = f * c_tm1 + i * self.activation(x_c \ 161 | + tf.tensordot(h_tm1_c, self.recurrent_kernel[:, self.units * 2: self.units * 3], axes=(-1, 0)) \ 162 | + tf.tensordot(s, self.additional_kernel[:, self.units * 2: self.units * 3], axes=(-1, 0)) ) 163 | o = self.recurrent_activation(x_o \ 164 | + tf.tensordot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:], axes=(-1, 0)) \ 165 | + tf.tensordot(s, self.additional_kernel[:, self.units * 3:], axes=(-1, 0)) ) 166 | 167 | h = o * self.activation(c) 168 | 169 | return h, [h, c] 170 | 171 | 172 | from tensorflow.python.ops import state_ops 173 | from tensorflow.python.util import nest 174 | from tensorflow.python.keras.utils import generic_utils 175 | from tensorflow.python.keras.layers.recurrent import _standardize_args 176 | from tensorflow.python.keras.layers.recurrent import StackedRNNCells 177 | from tensorflow.python.keras.engine.input_spec import InputSpec 178 | 179 | class CustomRNN(tf.keras.layers.RNN): 180 | def __init__(self, **kwargs): 181 | super(CustomRNN, self).__init__(**kwargs) 182 | 183 | def __call__(self, inputs, additional_state, initial_state=None, constants=None, **kwargs): 184 | 185 | inputs, initial_state, constants = _standardize_args(inputs, 186 | initial_state, 187 | constants, 188 | self._num_constants) 189 | 190 | if initial_state is None and constants is None: 191 | # return super(CustomRNN, self).__call__([inputs, additional_state], **kwargs) 192 | return Layer.__call__(self, inputs, additional_state, **kwargs) 193 | 194 | # If any of `initial_state` or `constants` are specified and are Keras 195 | # tensors, then add them to the 
inputs and temporarily modify the 196 | # input_spec to include them. 197 | 198 | additional_inputs = [] 199 | additional_specs = [] 200 | if initial_state is not None: 201 | additional_inputs += initial_state 202 | self.state_spec = nest.map_structure( 203 | lambda s: InputSpec(shape=K.int_shape(s)), initial_state) 204 | additional_specs += self.state_spec 205 | if constants is not None: 206 | additional_inputs += constants 207 | self.constants_spec = [ 208 | InputSpec(shape=K.int_shape(constant)) for constant in constants 209 | ] 210 | self._num_constants = len(constants) 211 | additional_specs += self.constants_spec 212 | # additional_inputs can be empty if initial_state or constants are provided 213 | # but empty (e.g. the cell is stateless). 214 | flat_additional_inputs = nest.flatten(additional_inputs) 215 | is_keras_tensor = K.is_keras_tensor( 216 | flat_additional_inputs[0]) if flat_additional_inputs else True 217 | for tensor in flat_additional_inputs: 218 | if K.is_keras_tensor(tensor) != is_keras_tensor: 219 | raise ValueError('The initial state or constants of an RNN' 220 | ' layer cannot be specified with a mix of' 221 | ' Keras tensors and non-Keras tensors' 222 | ' (a "Keras tensor" is a tensor that was' 223 | ' returned by a Keras layer, or by `Input`)') 224 | 225 | if is_keras_tensor: 226 | # Compute the full input spec, including state and constants 227 | full_input = [inputs] + additional_inputs 228 | if self.built: 229 | # Keep the input_spec since it has been populated in build() method. 230 | full_input_spec = self.input_spec + additional_specs 231 | else: 232 | # The original input_spec is None since there could be a nested tensor 233 | # input. Update the input_spec to match the inputs. 234 | full_input_spec = generic_utils.to_list( 235 | nest.map_structure(lambda _: None, inputs)) + additional_specs 236 | # Perform the call with temporarily replaced input_spec 237 | self.input_spec = full_input_spec 238 | output = super(CustomRNN, self).__call__(full_input, **kwargs) 239 | # Remove the additional_specs from input spec and keep the rest. It is 240 | # important to keep since the input spec was populated by build(), and 241 | # will be reused in the stateful=True. 242 | self.input_spec = self.input_spec[:-len(additional_specs)] 243 | return output 244 | else: 245 | if initial_state is not None: 246 | kwargs['initial_state'] = initial_state 247 | if constants is not None: 248 | kwargs['constants'] = constants 249 | return super(CustomRNN, self).__call__(inputs, **kwargs) 250 | 251 | def call(self, 252 | inputs, 253 | additional_states, 254 | mask=None, 255 | training=None, 256 | initial_state=None, 257 | constants=None): 258 | # The input should be dense, padded with zeros. If a ragged input is fed 259 | # into the layer, it is padded and the row lengths are used for masking. 260 | 261 | inputs, row_lengths = K.convert_inputs_if_ragged(inputs) 262 | is_ragged_input = (row_lengths is not None) 263 | self._validate_args_if_ragged(is_ragged_input, mask) 264 | 265 | inputs, initial_state, constants = self._process_inputs( 266 | inputs, initial_state, constants) 267 | 268 | self._maybe_reset_cell_dropout_mask(self.cell) 269 | if isinstance(self.cell, StackedRNNCells): 270 | for cell in self.cell.cells: 271 | self._maybe_reset_cell_dropout_mask(cell) 272 | 273 | if mask is not None: 274 | # Time step masks must be the same for each input. 275 | # TODO(scottzhu): Should we accept multiple different masks? 
276 | mask = nest.flatten(mask)[0] 277 | 278 | if nest.is_nested(inputs): 279 | # In the case of nested input, use the first element for shape check. 280 | input_shape = K.int_shape(nest.flatten(inputs)[0]) 281 | else: 282 | input_shape = K.int_shape(inputs) 283 | timesteps = input_shape[0] if self.time_major else input_shape[1] 284 | if self.unroll and timesteps is None: 285 | raise ValueError('Cannot unroll a RNN if the ' 286 | 'time dimension is undefined. \n' 287 | '- If using a Sequential model, ' 288 | 'specify the time dimension by passing ' 289 | 'an `input_shape` or `batch_input_shape` ' 290 | 'argument to your first layer. If your ' 291 | 'first layer is an Embedding, you can ' 292 | 'also use the `input_length` argument.\n' 293 | '- If using the functional API, specify ' 294 | 'the time dimension by passing a `shape` ' 295 | 'or `batch_shape` argument to your Input layer.') 296 | 297 | kwargs = {} 298 | if generic_utils.has_arg(self.cell.call, 'training'): 299 | kwargs['training'] = training 300 | 301 | # TF RNN cells expect single tensor as state instead of list wrapped tensor. 302 | is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None 303 | # Use the __call__ function for callable objects, eg layers, so that it 304 | # will have the proper name scopes for the ops, etc. 305 | cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call 306 | if constants: 307 | if not generic_utils.has_arg(self.cell.call, 'constants'): 308 | raise ValueError('RNN cell does not support constants') 309 | 310 | def step(inputs, states): 311 | constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type 312 | states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type 313 | 314 | states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 315 | # output, new_states = cell_call_fn( 316 | # inputs, states, constants=constants, **kwargs) 317 | output, new_states = cell_call_fn( 318 | inputs, states, additional_states, constants=constants, **kwargs) 319 | if not nest.is_nested(new_states): 320 | new_states = [new_states] 321 | return output, new_states 322 | else: 323 | 324 | def step(inputs, states): 325 | states = states[0] if len(states) == 1 and is_tf_rnn_cell else states 326 | # output, new_states = cell_call_fn(inputs, states, **kwargs) 327 | output, new_states = cell_call_fn(inputs, states, additional_states, **kwargs) 328 | if not nest.is_nested(new_states): 329 | new_states = [new_states] 330 | return output, new_states 331 | 332 | # inputs = [inputs, additional_states] 333 | last_output, outputs, states = K.rnn( 334 | step, 335 | inputs, 336 | initial_state, 337 | constants=constants, 338 | go_backwards=self.go_backwards, 339 | mask=mask, 340 | unroll=self.unroll, 341 | input_length=row_lengths if row_lengths is not None else timesteps, 342 | time_major=self.time_major, 343 | zero_output_for_mask=self.zero_output_for_mask) 344 | 345 | if self.stateful: 346 | updates = [ 347 | state_ops.assign(self_state, state) for self_state, state in zip( 348 | nest.flatten(self.states), nest.flatten(states)) 349 | ] 350 | self.add_update(updates) 351 | 352 | if self.return_sequences: 353 | output = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths) 354 | else: 355 | output = last_output 356 | 357 | if self.return_state: 358 | if not isinstance(states, (list, tuple)): 359 | states = [states] 360 | else: 361 | states = list(states) 362 | return generic_utils.to_list(output) + states 363 | else: 364 | 
return output 365 | 366 | class CustomLSTM(tf.keras.layers.RNN): 367 | 368 | def __init__(self, units, **kwargs): 369 | cell = MinimalLSTMCell(units) 370 | super(CustomLSTM, self).__init__(cell) 371 | self.input_spect = [tf.keras.layers.InputSpec(ndim=3)] 372 | 373 | def call(self, inputs): 374 | return super(CustomLSTM, self).call(inputs) 375 | 376 | class LSTMMatcher(Layer): 377 | 378 | def __init__(self, **kwargs): 379 | super(LSTMMatcher, self).__init__(**kwargs) 380 | 381 | def build(self, input_shape): 382 | return super(LSTMMatcher, self).build(input_shape) 383 | 384 | def call(self, inputs, steps, **kwargs): 385 | ''' 386 | inputs: 387 | S: bs x dim 388 | Q: bs x dim 389 | ''' 390 | assert len(inputs) == 2 391 | assert steps > 0 392 | S, Q = inputs 393 | 394 | Q = tf.reshape(tf.tile(Q, multiples=(1, steps)), shape=(-1, steps, tf.shape(Q)[-1])) # bs x steps x dim 395 | 396 | return None 397 | 398 | def get_config(self): 399 | return super(LSTMMatcher, self).get_config() 400 | 401 | 402 | 403 | # # encoding:utf-8 404 | 405 | # import tensorflow as tf 406 | # from tensorflow.python.keras.layers import Layer 407 | # from tensorflow.python.keras.layers.recurrent import LSTMCell, GRUCell, SimpleRNNCell, StackedRNNCells 408 | 409 | # class Matcher(Layer): 410 | 411 | # def __init__(self, 412 | # units = [1, ], 413 | # cell_type="lstm", 414 | # steps = 2, 415 | # **kwargs): 416 | 417 | # super(Matcher, self).__init__(**kwargs) 418 | 419 | # self.cell_type = cell_type 420 | # self.steps = steps 421 | # self.units = units # array 422 | 423 | 424 | # def build(self, input_shape): 425 | 426 | # super(Matcher, self).build(input_shape) 427 | 428 | # if self.cell_type.lower() == "lstm": 429 | # self.core_cell = [tf.compat.v1.nn.rnn_cell.LSTMCell(units) for units in self.units] 430 | # self.cells = tf.compat.v1.nn.rnn_cell.MultiRNNCell(self.core_cell) 431 | # else: 432 | # raise ValueError("bad cell type=%s" % self.cell_type) 433 | 434 | # def call(self, inputs): 435 | 436 | # ''' 437 | # inputs: 438 | # s: N x dim 439 | # q: N x dim 440 | # ''' 441 | # s, q = inputs 442 | # batch_size = tf.shape(inputs[0])[0] 443 | 444 | # eos_time_slice = tf.ones_like(inputs[0], dtype=tf.float32, name="eos") 445 | # pad_time_slice = tf.zeros_like(inputs[0], dtype=tf.float32, name="pad") 446 | 447 | # iteration_steps = tf.multiply(tf.ones((batch_size,)), self.steps) 448 | # iteration_steps = tf.cast(iteration_steps, dtype=tf.int32) 449 | 450 | # def loop_fn_initial(): 451 | 452 | # initial_elements_finished = (iteration_steps <= 0) # All Flase 453 | # initial_input = q 454 | 455 | # # initial_cell_state = [tf.concat([q, s], axis=1)] 456 | # initial_cell_state = [tf.concat([q, s], axis=1)] 457 | # for i in range(1, len(self.units)): 458 | # initial_cell_state.append(self.core_cell[i].zero_state(batch_size, dtype=tf.float32) ) 459 | 460 | # return (initial_elements_finished, 461 | # initial_input, 462 | # tuple(initial_cell_state), 463 | # None, None) 464 | 465 | # def loop_fn_transition(time, cell_output, cell_state, loop_state): 466 | 467 | # _elements_finished = (iteration_steps <= time) 468 | 469 | # _finished = tf.reduce_all(_elements_finished) 470 | # _inputs = tf.cond(_finished, lambda:pad_time_slice, q ) 471 | 472 | # _states = tf.concat() 473 | # _outputs = cell_output 474 | # _loop_state = None 475 | # return (_elements_finished, 476 | # _inputs, 477 | # _states, 478 | # _outputs, 479 | # _loop_state) 480 | 481 | # def loop_fn(time, cell_output, cell_state, loop_state): 482 | # if cell_state is None: 483 
| # return loop_fn_initial() 484 | # else: 485 | # return loop_fn_transition(time, cell_output, cell_state, loop_state) 486 | 487 | # # with tf.variable_scope("matcher"): 488 | # outputs_ta, final_state, _ = tf.compat.v1.nn.raw_rnn(self.cells, loop_fn) 489 | # outputs = outputs_ta.stack() 490 | 491 | # return outputs, final_state 492 | 493 | # def get_config(self): 494 | # config = { 495 | # } 496 | # base_config = super(Matcher, self).get_config() 497 | # return dict(list(base_config.items()) + list(config.items())) 498 | 499 | 500 | 501 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | #encoding:utf-8 2 | 3 | import tensorflow as tf 4 | from tensorflow.python.keras import initializers, regularizers 5 | from tensorflow.python.keras.layers import Layer, LSTM, RNN 6 | from utils import SequenceEncoder, Aggregator 7 | # from lstmcell import TestLSTMCell 8 | # from recurrent import TestRNN 9 | from matcher import CustomLSTMCell, CustomRNN 10 | 11 | class Mecos(Layer): 12 | 13 | def __init__(self, 14 | n_ways, k_shots, 15 | vocabulary_size, 16 | embedding_size, 17 | matching_steps, 18 | embeddings_initializer="glorot_normal", 19 | embeddings_regularizer=None, 20 | **kwargs): 21 | 22 | self.n_ways = n_ways 23 | self.k_shots = k_shots 24 | self.matching_steps = matching_steps 25 | self.vocabulary_size = vocabulary_size 26 | self.embedding_size = embedding_size 27 | self.embeddings_initializer = initializers.get(embeddings_initializer) 28 | self.embeddings_regularizer = regularizers.get(embeddings_regularizer) 29 | 30 | super(Mecos, self).__init__(**kwargs) 31 | 32 | def build(self, input_shape): 33 | 34 | self.item_embeddings = self.add_weight( 35 | name="", 36 | shape=(self.vocabulary_size, self.embedding_size), 37 | initializer=self.embeddings_initializer, 38 | regularizer=self.embeddings_regularizer, 39 | dtype=tf.float32, trainable=True 40 | ) 41 | 42 | self.sequence_encoder = SequenceEncoder(feedforword_layers=2, name="seq_enc") 43 | 44 | # cell = TestLSTMCell(input_shape[0][-1]) 45 | # self.lstm = TestRNN(cell=cell, return_state=True) 46 | 47 | self.lstm_cell = CustomLSTMCell(self.embedding_size * 2) 48 | self.lstm_matcher = CustomRNN(cell=self.lstm_cell) 49 | 50 | super(Mecos, self).build(input_shape) 51 | 52 | def call(self, inputs): 53 | ''' 54 | inputs: 55 | support seq: bs x (N x K) x maxlen 56 | support len: bs x (N x K) x 1 57 | support labels: bs x (N x K) 58 | query seq: bs x N x maxlen 59 | query len: bs x N x 1 60 | query labels: bs x N 61 | outputs 62 | ''' 63 | 64 | support_seqs, support_lens, support_labels, query_seqs, query_lens = inputs 65 | 66 | support_seqs = tf.nn.embedding_lookup(self.item_embeddings, support_seqs) # bs x (N x K) x maxlen x dim 67 | support_labels = tf.nn.embedding_lookup(self.item_embeddings, support_labels) # bs x (N x K) x dim 68 | query_seqs = tf.nn.embedding_lookup(self.item_embeddings, query_seqs) # bs x N x maxlen x dim 69 | 70 | support_embs = self.sequence_encoder([support_seqs, support_lens, support_labels]) # bs x (N x K) x (2xdim) 71 | support_embs = tf.reshape(support_embs, (-1, self.n_ways, self.k_shots, self.embedding_size * 2)) # bs x N x K x (2*dim) 72 | 73 | # aggregation for S 74 | support_embs = tf.reduce_mean(support_embs, axis=-2) # bs x N x (2xdim) 75 | 76 | # Q 77 | query_embs = self.sequence_encoder([query_seqs, query_lens]) # bs x N x (2xdim) 78 | 79 | # matching 80 | # support_embs = 
lstm_encoder([support_embs, query_embs]) 81 | query_embs = tf.tile(query_embs, [1, 1, self.matching_steps]) 82 | query_embs = tf.reshape(query_embs, (-1, self.matching_steps, self.embedding_size * 2)) 83 | support_embs = tf.reshape(support_embs, (-1, self.embedding_size * 2)) 84 | query_embs = self.lstm_matcher(query_embs, additional_state=support_embs) 85 | query_embs = tf.reshape(query_embs, (-1, self.n_ways, self.embedding_size * 2)) 86 | support_embs = tf.reshape(support_embs, (-1, self.n_ways, self.embedding_size * 2)) 87 | 88 | # cos 89 | support_embs = tf.tile(support_embs, [1, self.n_ways, 1]) 90 | query_embs = tf.reshape(tf.tile(query_embs, [1, 1, self.n_ways]), (-1, self.n_ways**2, self.embedding_size*2)) 91 | 92 | support_embs = tf.nn.l2_normalize(support_embs, axis=-1) 93 | query_embs = tf.nn.l2_normalize(query_embs, axis=-1) 94 | 95 | cos_similarity = tf.reduce_sum(tf.multiply(query_embs, support_embs), axis=-1) 96 | cos_similarity = tf.reshape(cos_similarity, (-1, self.n_ways, self.n_ways)) 97 | 98 | outputs = tf.nn.softmax(cos_similarity, axis=-1) 99 | 100 | return outputs 101 | 102 | def get_config(self): 103 | config = {} 104 | base_config = super(Mecos, self).get_config() 105 | return dict(list(base_config.items()) + list(config.items())) 106 | 107 | # bs = 3 108 | # n_ways = 10 109 | # k_shots = 3 110 | # vocabulary_size = 10000 111 | # embedding_size = 32 112 | # meta_batchsize = n_ways * k_shots 113 | 114 | # # seqs = tf.ones([meta_batchsize, 32, 1]) 115 | # # maxlen = tf.multiply(tf.ones((meta_batchsize, 1)), 32) 116 | # # labels = tf.ones([meta_batchsize, 1]) 117 | 118 | # maxlen = None 119 | # support_seqs = tf.keras.layers.Input(shape=(meta_batchsize, maxlen,), batch_size=bs, dtype=tf.int32, name="support_seqs") 120 | # support_lens = tf.keras.layers.Input(shape=(meta_batchsize, 1,), batch_size=bs, dtype=tf.int32, name="support_lens") 121 | # support_labels = tf.keras.layers.Input(shape=(meta_batchsize,), batch_size=bs, dtype=tf.int32, name="support_labels") 122 | # query_seqs = tf.keras.layers.Input(shape=(n_ways, maxlen,), dtype=tf.int32, batch_size=bs) 123 | # query_lens = tf.keras.layers.Input(shape=(n_ways, 1,), dtype=tf.int32, batch_size=bs) 124 | # query_labels = tf.keras.layers.Input(shape=(n_ways,), dtype=tf.int32, batch_size=bs) 125 | 126 | # # encoder = SequenceEncoder(feedforword_layers=2, name="my") 127 | # mecos = Mecos(n_ways=n_ways, k_shots=k_shots, matching_steps=2, vocabulary_size=vocabulary_size, embedding_size=embedding_size) 128 | 129 | # cos = mecos([support_seqs, support_lens, support_labels, query_seqs, query_lens]) 130 | # print(cos.shape) 131 | # model = tf.keras.Model(inputs=[support_seqs, support_lens, support_labels, query_seqs, query_lens], 132 | # outputs=[cos]) 133 | # model.summary() 134 | 135 | 136 | # a = tf.constant([[[1,1,1],[2,2,2]]], tf.int32) 137 | # b = tf.constant([[[3,3,3],[4,4,4]]], tf.int32) 138 | 139 | # a = tf.tile(a, [1,2,1]) 140 | # b = tf.reshape(tf.tile(b, [1,1,2]), a.shape) 141 | # print(a) 142 | # print(b) 143 | 144 | # support_embs = tf.constant([[-1,1],[2,2]], dtype=tf.float32) 145 | # query_embs = tf.constant([[3,3],[4,4]], dtype=tf.float32) 146 | 147 | # support_embs = tf.tile(support_embs, [2, 1]) 148 | # query_embs = tf.reshape(tf.tile(query_embs, [1, 2]), support_embs.shape) 149 | 150 | # support_embs = tf.nn.l2_normalize(support_embs, axis=1) 151 | # query_embs = tf.nn.l2_normalize(query_embs, axis=1) 152 | # cos_similarity = tf.reduce_sum(tf.multiply(query_embs, support_embs), axis=1) 153 | 154 | # 
cos_similarity = tf.reshape(cos_similarity, (2,2))
155 | # print(cos_similarity)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Meta-learning-based Cold-Start Sequential Recommendation
2 | 
3 | TensorFlow implementation of the paper: [Yujia Zheng, Siyi Liu, Zekun Li, Shu Wu, Cold-start Sequential Recommendation via Meta Learner, AAAI, 2021](https://arxiv.org/abs/2012.05462)
4 | 
5 | The official implementation is not available yet. This version may still contain mistakes.
6 | 
7 | ## Settings
8 | 
9 | 1. Pre-Training
10 | 
11 | Randomly initialized here; no pre-training is used.
12 | 
13 | 
14 | ## Training
15 | 
16 | The loss converges within a few steps, but the accuracy stays very low (nearly zero).
17 | 
18 | ![企业微信截图_16219091472275.png](https://i.loli.net/2021/05/25/lmPXadkxWvbD9wh.png)
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | import tensorflow as tf
4 | import numpy as np
5 | from datamanager import RandomSample, Tmall
6 | from maml import MAML
7 | 
8 | maml = MAML(n_ways=128, matching_steps=2, k_shots=3, vocabulary_size=1090391, embedding_size=100)
9 | 
10 | train_data = Tmall("./data/tmall/data_format1/meta_sequence_train.txt", batch_size=1, n_ways=128, k_shots=3, q_query=1)
11 | val_data = Tmall("./data/tmall/data_format1/meta_sequence_val.txt", batch_size=1, n_ways=128, k_shots=3, q_query=1)
12 | # train_data.steps = 100
13 | val_data.steps = 100
14 | 
15 | writer = tf.summary.create_file_writer("./logs/")
16 | 
17 | optimizer = tf.keras.optimizers.Adam(0.0001)
18 | 
19 | for epoch in range(10):
20 | 
21 |     train_progbar = tf.keras.utils.Progbar(train_data.steps)
22 |     val_progbar = tf.keras.utils.Progbar(val_data.steps)
23 | 
24 |     train_meta_loss, train_meta_acc = [], []
25 |     val_meta_loss, val_meta_acc = [], []
26 | 
27 |     for i in range(train_data.steps):
28 |         loss, acc = maml.train_on_meta_batch(train_data.get_one_meta_batch(),
29 |                                              outer_optimizer=optimizer,
30 |                                              writer=writer)
31 |         train_meta_loss.append(loss)
32 |         train_meta_acc.append(acc)
33 | 
34 |         train_progbar.update(i + 1, [("loss", np.mean(train_meta_loss)), ("accuracy", np.mean(train_meta_acc))] )
35 | 
36 |     for i in range(val_data.steps):
37 |         loss, acc = maml.train_on_meta_batch(val_data.get_one_meta_batch(),
38 |                                              outer_optimizer=None,
39 |                                              writer=None)
40 |         val_meta_loss.append(loss)
41 |         val_meta_acc.append(acc)
42 |         val_progbar.update(i + 1, [("loss", np.mean(val_meta_loss)), ("accuracy", np.mean(val_meta_acc))] )
43 | 
44 |     maml.meta_model.save_weights("./models/maml_epoch%d.h5" % epoch)
45 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #encoding:utf-8
2 | 
3 | import tensorflow as tf
4 | 
5 | from tensorflow.python.keras.layers import Layer
6 | from tensorflow.python.keras import initializers, regularizers
7 | 
8 | class SequenceEncoder(Layer):
9 | 
10 |     def __init__(self,
11 |                  feedforword_layers=2,
12 |                  embeddings_initializer=None,
13 |                  embeddings_regularizer=None,
14 |                  **kwargs):
15 | 
16 |         super(SequenceEncoder, self).__init__(**kwargs)
17 | 
18 |         self.feedforword_layers = feedforword_layers
19 | 
20 |         self.embeddings_initializer = initializers.get(embeddings_initializer)
21 |         self.embeddings_regularizer = regularizers.get(embeddings_regularizer)
22 | 
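    # What build() and call() implement below: build() creates W1/W2/W3, b and p for an additive
    # attention over the item sequence (last item, every step, and the sequence mean), a stack of
    # `feedforword_layers` dense layers (Wl_*/bl_*), and a projection Wq that is used when no
    # target-label embedding is passed. call() softmax-normalises the attention scores, takes the
    # attention-weighted sum of the sequence, concatenates it with the label embedding (support
    # sequences) or projects it with Wq (query sequences), and adds a residual connection around
    # the feedforward stack, giving a representation of size 2 x embedding dim.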
23 | def build(self, input_shape): 24 | emb_size = input_shape[0][-1] 25 | 26 | self.kernel_weights = {} 27 | for key in ["W1", "W2", "W3"]: 28 | self.kernel_weights[key] = \ 29 | self.add_weight( 30 | name=key, shape=(emb_size, emb_size), 31 | initializer=self.embeddings_initializer, 32 | regularizer=self.embeddings_regularizer, 33 | trainable=True, dtype=tf.float32 34 | ) 35 | self.kernel_weights["b"] = \ 36 | self.add_weight( 37 | name="b", shape=(emb_size, ), 38 | initializer="Zeros", 39 | trainable=True, dtype=tf.float32 40 | ) 41 | self.kernel_weights["p"] = \ 42 | self.add_weight( 43 | name="p", shape=(emb_size, ), 44 | initializer=self.embeddings_initializer, 45 | trainable=True, dtype=tf.float32 46 | ) 47 | for l in range(self.feedforword_layers): 48 | self.kernel_weights["Wl_%d" % l] = \ 49 | self.add_weight( 50 | name="Wl_%d" % l, shape=(emb_size * 2, emb_size * 2), 51 | initializer=self.embeddings_initializer, 52 | regularizer=self.embeddings_regularizer, 53 | trainable=True, dtype=tf.float32 54 | ) 55 | self.kernel_weights["bl_%d" % l] = \ 56 | self.add_weight( 57 | name="bl_%d" % l, shape=(emb_size * 2, ), 58 | initializer="Zeros", 59 | trainable=True, dtype=tf.float32 60 | ) 61 | self.kernel_weights["Wq"] = \ 62 | self.add_weight( 63 | name="Wq", shape=(emb_size, emb_size * 2), 64 | initializer=self.embeddings_initializer, 65 | regularizer=self.embeddings_regularizer, 66 | trainable=True, dtype=tf.float32 67 | ) 68 | super(SequenceEncoder, self).build(input_shape) 69 | 70 | def call(self, inputs): 71 | ''' 72 | input: 73 | seqs: bs x (N x K) x maxlen x dim 74 | lens: bs x (N x K) x 1 75 | labels: bs x (N x K) x dim 76 | output: 77 | representation: bs x (N x K) x (dimx2) 78 | ''' 79 | if len(inputs) == 2: 80 | seqs, lens = inputs 81 | labels = None 82 | elif len(inputs) == 3: 83 | seqs, lens, labels = inputs 84 | else: 85 | raise ValueError("wrong size inputs=%d" % len(inputs)) 86 | 87 | # bs x (N x K) x 1 x dim 88 | V_last = tf.expand_dims(tf.tensordot( 89 | seqs[:, :, -1, :], 90 | self.kernel_weights["W1"], axes=(-1, 0)), -2) 91 | # bs x (N x K) x maxlen x dim 92 | V_seq = tf.tensordot( 93 | seqs, 94 | self.kernel_weights["W2"], axes=(-1, 0)) 95 | # bs x (N x K) x 1 x dim 96 | V_avg = tf.expand_dims(tf.tensordot( 97 | tf.divide(tf.reduce_sum(seqs, axis=-2), tf.cast(lens, tf.float32) ), 98 | self.kernel_weights["W3"], axes=(-1, 0)), -2) 99 | 100 | # bs x (N x K) x maxlen 101 | emb = tf.tensordot( 102 | tf.nn.bias_add(V_last + V_seq + V_avg, self.kernel_weights["b"]), 103 | self.kernel_weights["p"], axes=(-1, 0) 104 | ) 105 | 106 | # attention bs x (N x K) x maxlen 107 | attn = tf.nn.softmax(emb, axis=-1) 108 | 109 | # weighted bs x (N x K) x dim 110 | weighted_seqs = tf.reduce_sum(tf.multiply(tf.expand_dims(attn, -1), seqs), axis=-2) 111 | 112 | # feedforward bs x (N x K) x (2 * dim) 113 | if labels is None: 114 | hidden_proj_init = tf.tensordot(weighted_seqs, self.kernel_weights["Wq"], axes=(-1,0)) 115 | else: 116 | hidden_proj_init = tf.concat([weighted_seqs, labels], axis=-1) 117 | 118 | hidden_proj = hidden_proj_init 119 | for l in range(self.feedforword_layers): 120 | hidden_proj = tf.nn.bias_add(tf.tensordot( 121 | hidden_proj, self.kernel_weights["Wl_%d" % l], axes=(-1,0)), self.kernel_weights["bl_%d" % l]) 122 | hidden_proj = tf.nn.relu(hidden_proj) 123 | 124 | # N x 2d 125 | seq_representation = hidden_proj_init + hidden_proj 126 | 127 | return seq_representation 128 | 129 | def get_config(self): 130 | config = { 131 | "embeddings_initializer": 
self.embeddings_initializer, 132 | "embeddings_regularizer": self.embeddings_regularizer, 133 | "feedforword_layers": self.feedforword_layers 134 | } 135 | base_config = super(SequenceEncoder, self).get_config() 136 | return dict(list(base_config.items()) + list(config.items())) 137 | 138 | class Aggregator(Layer): 139 | 140 | def __init__(self, mode="mean", axis=1, **kwargs): 141 | super(Aggregator, self).__init__(**kwargs) 142 | self.mode = mode 143 | self.axis = axis 144 | self.support_modes = ["max", "mean", "last"] 145 | 146 | def call(self, inputs): 147 | 148 | if self.mode == "mean": 149 | return tf.reduce_mean(inputs, axis=self.axis) 150 | elif self.mode == "max": 151 | return tf.reduce_max(inputs, axis=self.axis) 152 | elif self.mode == "last": 153 | return inputs[:, -1, :] 154 | else: 155 | raise ValueError("fatal aggregator mode=%s" % self.mode) 156 | 157 | def get_config(self): 158 | config = { 159 | "mode": self.mode, 160 | "support_modes": self.support_modes 161 | } 162 | base_config = super(Aggregator, self).get_config() 163 | return dict(list(base_config.items()) + list(config.items())) 164 | 165 | # def compute_output_shape(self, input_shape): 166 | # if self.mode == "mean": 167 | # return (input_shape[0][0], input_shape[0][-1]) 168 | 169 | 170 | 171 | # seqs = tf.ones([2, 10, 32, 7]) 172 | # maxlen = tf.multiply(tf.ones((2, 10, 1)), 32) 173 | # labels = tf.ones([2, 10, 7]) 174 | 175 | # mecos = SequenceEncoder() 176 | # out = mecos([seqs, maxlen, labels]) 177 | # print(out.shape) --------------------------------------------------------------------------------
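run.py only covers training and validation; the held-out `meta_sequence_test.txt` split produced by `process_data_format1.py` can be evaluated with the same machinery. Below is a minimal sketch of such a script (it is not part of the repository): it reuses `Tmall` and `MAML.train_on_meta_batch` with `outer_optimizer=None` and loads the weights that run.py saves; the chosen epoch and the `steps` cap are assumptions.

```python
import numpy as np
from datamanager import Tmall
from maml import MAML

# Rebuild the meta-model with the same hyper-parameters as run.py, then load saved weights.
maml = MAML(n_ways=128, matching_steps=2, k_shots=3, vocabulary_size=1090391, embedding_size=100)
maml.meta_model.load_weights("./models/maml_epoch9.h5")  # assumed: the last checkpoint written by run.py

test_data = Tmall("./data/tmall/data_format1/meta_sequence_test.txt",
                  batch_size=1, n_ways=128, k_shots=3, q_query=1)
test_data.steps = 100  # cap the number of sampled meta-batches, as run.py does for validation

test_acc = []
for _ in range(test_data.steps):
    # With outer_optimizer=None no gradient step is applied, so this is evaluation only.
    _, acc = maml.train_on_meta_batch(test_data.get_one_meta_batch(),
                                      outer_optimizer=None, writer=None)
    test_acc.append(acc)

print("test accuracy: %.4f" % np.mean(test_acc))
```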