├── sum_tree.py
├── README.md
├── graph_init.py
├── Graph_generate
│   ├── yelp_graph.py
│   ├── lastfm_graph.py
│   ├── lastfm_data_process.py
│   ├── lastfm_star_data_process.py
│   └── yelp_data_process.py
├── gcn.py
├── utils.py
├── RL
│   ├── RL_evaluate.py
│   ├── env_binary_question.py
│   └── env_enumerated_question.py
├── evaluate.py
└── RL_model.py

/sum_tree.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | # SumTree
4 | # a binary tree data structure where the parent’s value is the sum of its children
5 | class SumTree(object):
6 |     write = 0
7 | 
8 |     def __init__(self, capacity):
9 |         self.capacity = capacity
10 |         self.tree = np.zeros(2 * capacity - 1)
11 |         self.data = np.zeros(capacity, dtype=object)
12 |         self.n_entries = 0
13 | 
14 |     # update to the root node
15 |     def _propagate(self, idx, change):
16 |         parent = (idx - 1) // 2
17 | 
18 |         self.tree[parent] += change
19 | 
20 |         if parent != 0:
21 |             self._propagate(parent, change)
22 | 
23 |     # find sample on leaf node
24 |     def _retrieve(self, idx, s):
25 |         left = 2 * idx + 1
26 |         right = left + 1
27 | 
28 |         if left >= len(self.tree):
29 |             return idx
30 | 
31 |         if s <= self.tree[left]:
32 |             return self._retrieve(left, s)
33 |         else:
34 |             return self._retrieve(right, s - self.tree[left])
35 | 
36 |     def total(self):
37 |         return self.tree[0]
38 | 
39 |     # store priority and sample
40 |     def add(self, p, data):
41 |         idx = self.write + self.capacity - 1
42 | 
43 |         self.data[self.write] = data
44 |         self.update(idx, p)
45 | 
46 |         self.write += 1
47 |         if self.write >= self.capacity:
48 |             self.write = 0
49 | 
50 |         if self.n_entries < self.capacity:
51 |             self.n_entries += 1
52 | 
53 |     # update priority
54 |     def update(self, idx, p):
55 |         change = p - self.tree[idx]
56 | 
57 |         self.tree[idx] = p
58 |         self._propagate(idx, change)
59 | 
60 |     # get priority and sample
61 |     def get(self, s):
62 |         idx = self._retrieve(0, s)
63 |         dataIdx = idx - self.capacity + 1
64 | 
65 |         return (idx, self.tree[idx], self.data[dataIdx])
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UNICORN
2 | 
3 | The implementation of _Unified Conversational Recommendation Policy Learning via Graph-based Reinforcement Learning_ (SIGIR 2021).
4 | 
5 | The code is partially adapted from https://cpr-conv-rec.github.io/.
6 | 
7 | ## Data Preparation
8 | 1. Please download the datasets "SCPR_Data.zip" from https://cpr-conv-rec.github.io/, including lastfm, lastfm_star, and yelp. (If you would like to use your own dataset, please follow the same data format.)
9 | 2. Unzip "SCPR_Data.zip" and put the "data" folder in the path "unicorn/".
10 | 3. Process the data: `python graph_init.py --data_name <data_name>`
11 | 4. Use TransE from [[OpenKE](https://github.com/thunlp/OpenKE)] to pretrain the graph embeddings, and put the pretrained embeddings under "unicorn/tmp/<data_name>/embeds/". Or you can directly download the pretrained TransE embeddings from https://drive.google.com/file/d/1qoZMbYCBi2Y4IsJBdJ8Eg6y30Ap0gsQY/view?usp=sharing.
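A quick way to confirm the embeddings end up where the code expects them: `load_embed` in `utils.py` reads `tmp/<TMP_DIR folder>/embeds/<embed>.pkl`, where the folder name comes from `TMP_DIR` (e.g. `last_fm`, `last_fm_star`, `yelp`, `yelp_star`) and `<embed>` defaults to `transe`. Below is a minimal sanity-check sketch; the `last_fm` path is only an example, and the internal layout of the pickle depends on how the TransE embeddings were exported, so it is inspected rather than assumed:

```python
import os
import pickle

# Example: checking the LAST_FM embeddings before training.
path = './tmp/last_fm/embeds/transe.pkl'   # TMP_DIR[LAST_FM] + '/embeds/transe.pkl'

assert os.path.exists(path), 'pretrained embedding file not found: {}'.format(path)
with open(path, 'rb') as f:
    embeds = pickle.load(f)

# Only inspect the object; its exact structure depends on the export format.
print(type(embeds))
if isinstance(embeds, dict):
    print(list(embeds.keys()))
```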
12 | 
13 | ## Training
14 | `python RL_model.py --data_name <data_name>`
15 | 
16 | ## Evaluation
17 | `python evaluate.py --data_name <data_name> --load_rl_epoch <epoch_num>`
18 | 
19 | ## Citation
20 | If the code is used in your research, please star this repo and cite our paper as follows:
21 | ```
22 | @inproceedings{DBLP:conf/sigir/DengL0DL21,
23 |   author    = {Yang Deng and
24 |                Yaliang Li and
25 |                Fei Sun and
26 |                Bolin Ding and
27 |                Wai Lam},
28 |   title     = {Unified Conversational Recommendation Policy Learning via Graph-based
29 |                Reinforcement Learning},
30 |   booktitle = {{SIGIR} '21: The 44th International {ACM} {SIGIR} Conference on Research
31 |                and Development in Information Retrieval, Virtual Event, Canada, July
32 |                11-15, 2021},
33 |   pages     = {1431--1441},
34 |   publisher = {{ACM}},
35 |   year      = {2021},
36 | }
37 | ```
38 | 
--------------------------------------------------------------------------------
/graph_init.py:
--------------------------------------------------------------------------------
1 | 
2 | import argparse
3 | from utils import *
4 | from Graph_generate.lastfm_data_process import LastFmDataset
5 | from Graph_generate.lastfm_star_data_process import LastFmStarDataset
6 | from Graph_generate.lastfm_graph import LastFmGraph
7 | from Graph_generate.yelp_data_process import YelpDataset
8 | from Graph_generate.yelp_graph import YelpGraph
9 | 
10 | def main():
11 |     parser = argparse.ArgumentParser()
12 |     parser.add_argument('--data_name', type=str, default=LAST_FM, choices=[LAST_FM, LAST_FM_STAR, YELP, YELP_STAR],
13 |                         help='One of {LAST_FM, LAST_FM_STAR, YELP, YELP_STAR}.')
14 |     args = parser.parse_args()
15 |     DatasetDict = {
16 |         LAST_FM: LastFmDataset,
17 |         LAST_FM_STAR: LastFmStarDataset,
18 |         YELP: YelpDataset,
19 |         YELP_STAR: YelpDataset,
20 |     }
21 |     GraphDict = {
22 |         LAST_FM: LastFmGraph,
23 |         LAST_FM_STAR: LastFmGraph,
24 |         YELP: YelpGraph,
25 |         YELP_STAR: YelpGraph,
26 |     }
27 | 
28 |     # Create dataset instance for 'data_name'.
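    # The rest of main() runs in two phases: the raw files under
    # DATA_DIR[data_name] are parsed into a dataset object and pickled via
    # save_dataset(); the dataset is then reloaded and converted into a graph
    # object that is pickled via save_kg(). Note that YELP and YELP_STAR share
    # the same dataset/graph classes above but are written to different TMP_DIR
    # folders ('./tmp/yelp' vs './tmp/yelp_star', see utils.py).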
29 | print('Load', args.data_name, 'from file...') 30 | print(TMP_DIR[args.data_name]) 31 | if not os.path.isdir(TMP_DIR[args.data_name]): 32 | os.makedirs(TMP_DIR[args.data_name]) 33 | dataset = DatasetDict[args.data_name](DATA_DIR[args.data_name]) 34 | save_dataset(args.data_name, dataset) 35 | print('Save', args.data_name, 'dataset successfully!') 36 | 37 | # Generate graph instance for 'data_name' 38 | print('Create', args.data_name, 'graph from data_name...') 39 | dataset = load_dataset(args.data_name) 40 | kg = GraphDict[args.data_name](dataset) 41 | save_kg(args.data_name, kg) 42 | print('Save', args.data_name, 'graph successfully!') 43 | 44 | 45 | def construct(kg): 46 | users = kg.G['user'].keys() 47 | items = kg.G['item'].keys() 48 | features = kg.G['feature'].keys() 49 | num_node = len(users) + len(items) + len(features) 50 | graph = np.zeros((num_node, num_node)) 51 | for i in range(num_node): 52 | for j in range(num_node): 53 | if i < len(users) and j < len(users)+len(items): 54 | graph[i][j] = 1 55 | graph[j][i] = 1 56 | elif i >= len(users) and i < len(users)+len(items): 57 | if j-len(users)-len(items) in kg.G['item'][i-len(users)]['belong_to']: 58 | graph[i][j] = 1 59 | graph[j][i] = 1 60 | else: 61 | pass 62 | print(graph) 63 | return graph 64 | 65 | 66 | if __name__ == '__main__': 67 | main() 68 | 69 | -------------------------------------------------------------------------------- /Graph_generate/yelp_graph.py: -------------------------------------------------------------------------------- 1 | 2 | class YelpGraph(object): 3 | 4 | def __init__(self, dataset): 5 | self.G = dict() 6 | self._load_entities(dataset) 7 | self._load_knowledge(dataset) 8 | self._clean() 9 | 10 | def _load_entities(self, dataset): 11 | print('load entities...') 12 | num_nodes = 0 13 | data_relations, _, _ = dataset.get_relation() # entity_relations, relation_name, link_entity_type 14 | entity_list = list(data_relations.keys()) 15 | for entity in entity_list: 16 | self.G[entity] = {} 17 | entity_size = getattr(dataset, entity).value_len 18 | for eid in range(entity_size): 19 | entity_rela_list = data_relations[entity].keys() 20 | self.G[entity][eid] = {r: [] for r in entity_rela_list} 21 | num_nodes += entity_size 22 | print('load entity:{:s} : Total {:d} nodes.'.format(entity, entity_size)) 23 | print('ALL total {:d} nodes.'.format(num_nodes)) 24 | print('===============END==============') 25 | 26 | def _load_knowledge(self, dataset): 27 | _, data_relations_name, link_entity_type = dataset.get_relation() # entity_relations, relation_name, link_entity_type 28 | for relation in data_relations_name: 29 | print('Load knowledge {}...'.format(relation)) 30 | data = getattr(dataset, relation).data 31 | num_edges = 0 32 | for he_id, te_ids in enumerate(data): # head_entity_id , tail_entity_ids 33 | if len(te_ids) <= 0: 34 | continue 35 | e_head_type = link_entity_type[relation][0] 36 | e_tail_type = link_entity_type[relation][1] 37 | for te_id in set(te_ids): 38 | self._add_edge(e_head_type, he_id, relation, e_tail_type, te_id) 39 | num_edges += 2 40 | print('Total {:d} {:s} edges.'.format(num_edges, relation)) 41 | print('===============END==============') 42 | 43 | def _add_edge(self, etype1, eid1, relation, etype2, eid2): 44 | self.G[etype1][eid1][relation].append(eid2) 45 | self.G[etype2][eid2][relation].append(eid1) 46 | 47 | def _clean(self): 48 | print('Remove duplicates...') 49 | for etype in self.G: 50 | for eid in self.G[etype]: 51 | for r in self.G[etype][eid]: 52 | data = 
self.G[etype][eid][r] 53 | data = tuple(sorted(set(data))) 54 | self.G[etype][eid][r] = data 55 | -------------------------------------------------------------------------------- /Graph_generate/lastfm_graph.py: -------------------------------------------------------------------------------- 1 | 2 | class LastFmGraph(object): 3 | 4 | def __init__(self, dataset): 5 | self.G = dict() 6 | self._load_entities(dataset) 7 | self._load_knowledge(dataset) 8 | self._clean() 9 | 10 | def _load_entities(self, dataset): 11 | print('load entities...') 12 | num_nodes = 0 13 | data_relations, _, _ = dataset.get_relation() # entity_relations, relation_name, link_entity_type 14 | entity_list = list(data_relations.keys()) 15 | for entity in entity_list: 16 | self.G[entity] = {} 17 | entity_size = getattr(dataset, entity).value_len 18 | for eid in range(entity_size): 19 | entity_rela_list = data_relations[entity].keys() 20 | self.G[entity][eid] = {r: [] for r in entity_rela_list} 21 | num_nodes += entity_size 22 | print('load entity:{:s} : Total {:d} nodes.'.format(entity, entity_size)) 23 | print('ALL total {:d} nodes.'.format(num_nodes)) 24 | print('===============END==============') 25 | 26 | def _load_knowledge(self, dataset): 27 | _, data_relations_name, link_entity_type = dataset.get_relation() # entity_relations, relation_name, link_entity_type 28 | for relation in data_relations_name: 29 | print('Load knowledge {}...'.format(relation)) 30 | data = getattr(dataset, relation).data 31 | num_edges = 0 32 | for he_id, te_ids in enumerate(data): # head_entity_id , tail_entity_ids 33 | if len(te_ids) <= 0: 34 | continue 35 | e_head_type = link_entity_type[relation][0] 36 | e_tail_type = link_entity_type[relation][1] 37 | for te_id in set(te_ids): 38 | self._add_edge(e_head_type, he_id, relation, e_tail_type, te_id) 39 | num_edges += 2 40 | print('Total {:d} {:s} edges.'.format(num_edges, relation)) 41 | print('===============END==============') 42 | 43 | def _add_edge(self, etype1, eid1, relation, etype2, eid2): 44 | self.G[etype1][eid1][relation].append(eid2) 45 | self.G[etype2][eid2][relation].append(eid1) 46 | 47 | def _clean(self): 48 | print('Remove duplicates...') 49 | for etype in self.G: 50 | for eid in self.G[etype]: 51 | for r in self.G[etype][eid]: 52 | data = self.G[etype][eid][r] 53 | data = tuple(sorted(set(data))) 54 | self.G[etype][eid][r] = data 55 | -------------------------------------------------------------------------------- /Graph_generate/lastfm_data_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from easydict import EasyDict as edict 4 | 5 | 6 | class LastFmDataset(object): 7 | def __init__(self, data_dir): 8 | self.data_dir = data_dir + '/Graph_generate_data' 9 | self.load_entities() 10 | self.load_relations() 11 | def get_relation(self): 12 | #Entities 13 | USER = 'user' 14 | ITEM = 'item' 15 | FEATURE = 'feature' 16 | 17 | #Relations 18 | INTERACT = 'interact' 19 | FRIEND = 'friends' 20 | LIKE = 'like' 21 | BELONG_TO = 'belong_to' 22 | relation_name = [INTERACT, FRIEND, LIKE, BELONG_TO] 23 | 24 | fm_relation = { 25 | USER: { 26 | INTERACT: ITEM, 27 | FRIEND: USER, 28 | LIKE: FEATURE, 29 | }, 30 | ITEM: { 31 | BELONG_TO: FEATURE, 32 | INTERACT: USER 33 | }, 34 | FEATURE: { 35 | LIKE: USER, 36 | BELONG_TO: ITEM 37 | } 38 | } 39 | fm_relation_link_entity_type = { 40 | INTERACT: [USER, ITEM], 41 | FRIEND: [USER, USER], 42 | LIKE: [USER, FEATURE], 43 | BELONG_TO: [ITEM, FEATURE] 44 | } 45 | return 
fm_relation, relation_name, fm_relation_link_entity_type 46 | def load_entities(self): 47 | entity_files = edict( 48 | user='user_dict.json', 49 | item='item_dict.json', 50 | feature='merged_tag_map.json', 51 | ) 52 | for entity_name in entity_files: 53 | with open(os.path.join(self.data_dir,entity_files[entity_name]), encoding='utf-8') as f: 54 | mydict = json.load(f) 55 | if entity_name == 'feature': 56 | entity_id = list(mydict.values()) 57 | else: 58 | entity_id = list(map(int, list(mydict.keys()))) 59 | setattr(self, entity_name, edict(id=entity_id, value_len=max(entity_id)+1)) 60 | print('Load', entity_name, 'of size', len(entity_id)) 61 | print(entity_name, 'of max id is', max(entity_id)) 62 | 63 | def load_relations(self): 64 | """ 65 | relation: head entity---> tail entity 66 | -- 67 | """ 68 | LastFm_relations = edict( 69 | interact=('user_item.json', self.user, self.item), #(filename, head_entity, tail_entity) 70 | friends=('user_dict.json', self.user, self.user), 71 | like=('user_dict.json', self.user, self.feature), 72 | belong_to=('item_dict.json', self.item, self.feature), 73 | ) 74 | for name in LastFm_relations: 75 | # Save tail_entity 76 | relation = edict( 77 | data=[], 78 | ) 79 | knowledge = [list([]) for i in range(LastFm_relations[name][1].value_len)] 80 | # load relation files 81 | with open(os.path.join(self.data_dir, LastFm_relations[name][0]), encoding='utf-8') as f: 82 | mydict = json.load(f) 83 | if name in ['interact']: 84 | for key, value in mydict.items(): 85 | head_id = int(key) 86 | tail_ids = value 87 | knowledge[head_id] = tail_ids 88 | elif name in ['friends', 'like']: 89 | for key in mydict.keys(): 90 | head_str = key 91 | head_id = int(key) 92 | tail_ids = mydict[head_str][name] 93 | knowledge[head_id] = tail_ids 94 | elif name in ['belong_to']: 95 | for key in mydict.keys(): 96 | head_str = key 97 | head_id = int(key) 98 | tail_ids = mydict[head_str]['feature_index'] 99 | knowledge[head_id] = tail_ids 100 | relation.data = knowledge 101 | setattr(self, name, relation) 102 | tuple_num = 0 103 | for i in knowledge: 104 | tuple_num += len(i) 105 | print('Load', name, 'of size', tuple_num) 106 | 107 | 108 | 109 | 110 | 111 | 112 | -------------------------------------------------------------------------------- /Graph_generate/lastfm_star_data_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from easydict import EasyDict as edict 4 | 5 | 6 | class LastFmStarDataset(object): 7 | def __init__(self, data_dir): 8 | self.data_dir = data_dir + '/Graph_generate_data' 9 | self.load_entities() 10 | self.load_relations() 11 | 12 | def get_relation(self): 13 | # Entities 14 | USER = 'user' 15 | ITEM = 'item' 16 | FEATURE = 'feature' 17 | 18 | # Relations 19 | INTERACT = 'interact' 20 | FRIEND = 'friends' 21 | LIKE = 'like' 22 | BELONG_TO = 'belong_to' 23 | relation_name = [INTERACT, FRIEND, LIKE, BELONG_TO] 24 | 25 | fm_relation = { 26 | USER: { 27 | INTERACT: ITEM, 28 | FRIEND: USER, 29 | LIKE: FEATURE, 30 | }, 31 | ITEM: { 32 | BELONG_TO: FEATURE, 33 | INTERACT: USER 34 | }, 35 | FEATURE: { 36 | LIKE: USER, 37 | BELONG_TO: ITEM 38 | } 39 | } 40 | fm_relation_link_entity_type = { 41 | INTERACT: [USER, ITEM], 42 | FRIEND: [USER, USER], 43 | LIKE: [USER, FEATURE], 44 | BELONG_TO: [ITEM, FEATURE] 45 | } 46 | return fm_relation, relation_name, fm_relation_link_entity_type 47 | 48 | def load_entities(self): 49 | entity_files = edict( 50 | user='user_dict.json', 51 | item='item_dict.json', 52 | 
feature='original_tag_map.json', 53 | ) 54 | for entity_name in entity_files: 55 | with open(os.path.join(self.data_dir, entity_files[entity_name]), encoding='utf-8') as f: 56 | mydict = json.load(f) 57 | if entity_name == 'feature': 58 | entity_id = list(mydict.values()) 59 | else: 60 | entity_id = list(map(int, list(mydict.keys()))) 61 | setattr(self, entity_name, edict(id=entity_id, value_len=max(entity_id) + 1)) 62 | print('Load', entity_name, 'of size', len(entity_id)) 63 | print(entity_name, 'of max id is', max(entity_id)) 64 | 65 | def load_relations(self): 66 | """ 67 | relation: head entity---> tail entity 68 | -- 69 | """ 70 | LastFm_relations = edict( 71 | interact=('user_item.json', self.user, self.item), # (filename, head_entity, tail_entity) 72 | friends=('user_dict.json', self.user, self.user), 73 | like=('user_dict.json', self.user, self.feature), 74 | belong_to=('item_dict.json', self.item, self.feature), 75 | ) 76 | for name in LastFm_relations: 77 | # Save tail_entity 78 | relation = edict( 79 | data=[], 80 | ) 81 | knowledge = [list([]) for i in 82 | range(LastFm_relations[name][1].value_len)] 83 | # load relation files 84 | with open(os.path.join(self.data_dir, LastFm_relations[name][0]), encoding='utf-8') as f: 85 | mydict = json.load(f) 86 | if name in ['interact']: 87 | for key, value in mydict.items(): 88 | head_id = int(key) 89 | tail_ids = value 90 | knowledge[head_id] = tail_ids 91 | elif name in ['friends', 'like']: 92 | for key in mydict.keys(): 93 | head_str = key 94 | head_id = int(key) 95 | tail_ids = mydict[head_str][name] 96 | knowledge[head_id] = tail_ids 97 | elif name in ['belong_to']: 98 | for key in mydict.keys(): 99 | head_str = key 100 | head_id = int(key) 101 | tail_ids = mydict[head_str]['feature_index'] 102 | knowledge[head_id] = tail_ids 103 | relation.data = knowledge 104 | setattr(self, name, relation) 105 | tuple_num = 0 106 | for i in knowledge: 107 | tuple_num += len(i) 108 | print('Load', name, 'of size', tuple_num) -------------------------------------------------------------------------------- /gcn.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | from torch.nn.modules.module import Module 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from tqdm import tqdm 8 | import pickle 9 | import gzip 10 | import numpy as np 11 | import time 12 | 13 | 14 | class GraphConvolution(Module): 15 | 16 | def __init__(self, in_features, out_features, bias=True): 17 | super(GraphConvolution, self).__init__() 18 | self.in_features = in_features 19 | self.out_features = out_features 20 | self.weight = Parameter(torch.FloatTensor(in_features, out_features)) 21 | if bias: 22 | self.bias = Parameter(torch.FloatTensor(out_features)) 23 | else: 24 | self.register_parameter('bias', None) 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 29 | self.weight.data.uniform_(-stdv, stdv) 30 | if self.bias is not None: 31 | self.bias.data.uniform_(-stdv, stdv) 32 | 33 | def forward(self, input, adj): 34 | support = torch.mm(input, self.weight) 35 | output = torch.sparse.mm(adj, support) 36 | if self.bias is not None: 37 | return output + self.bias 38 | else: 39 | return output 40 | 41 | 42 | class GraphEncoder(Module): 43 | def __init__(self, device, entity, emb_size, kg, embeddings=None, fix_emb=True, seq='rnn', gcn=True, hidden_size=100, layers=1, rnn_layer=1): 44 | super(GraphEncoder, self).__init__() 45 | self.embedding = nn.Embedding(entity, emb_size, padding_idx=entity-1) 46 | if embeddings is not None: 47 | print("pre-trained embeddings") 48 | self.embedding.from_pretrained(embeddings,freeze=fix_emb) 49 | self.layers = layers 50 | self.user_num = len(kg.G['user']) 51 | self.item_num = len(kg.G['item']) 52 | self.PADDING_ID = entity-1 53 | self.device = device 54 | self.seq = seq 55 | self.gcn = gcn 56 | 57 | self.fc1 = nn.Linear(hidden_size, hidden_size) 58 | if self.seq == 'rnn': 59 | self.rnn = nn.GRU(hidden_size, hidden_size, rnn_layer, batch_first=True) 60 | elif self.seq == 'transformer': 61 | self.transformer = nn.TransformerEncoder(encoder_layer=nn.TransformerEncoderLayer(d_model=hidden_size, nhead=4, dim_feedforward=400), num_layers=rnn_layer) 62 | 63 | if self.gcn: 64 | indim, outdim = emb_size, hidden_size 65 | self.gnns = nn.ModuleList() 66 | for l in range(layers): 67 | self.gnns.append(GraphConvolution(indim, outdim)) 68 | indim = outdim 69 | else: 70 | self.fc2 = nn.Linear(emb_size, hidden_size) 71 | 72 | def forward(self, b_state): 73 | """ 74 | :param b_state [N] 75 | :return: [N x L x d] 76 | """ 77 | batch_output = [] 78 | for s in b_state: 79 | #neighbors, adj = self.get_state_graph(s) 80 | neighbors, adj = s['neighbors'].to(self.device), s['adj'].to(self.device) 81 | input_state = self.embedding(neighbors) 82 | if self.gcn: 83 | for gnn in self.gnns: 84 | output_state = gnn(input_state, adj) 85 | input_state = output_state 86 | batch_output.append(output_state) 87 | else: 88 | output_state = F.relu(self.fc2(input_state)) 89 | batch_output.append(output_state) 90 | 91 | seq_embeddings = [] 92 | for s, o in zip(b_state, batch_output): 93 | seq_embeddings.append(o[:len(s['cur_node']),:][None,:]) 94 | if len(batch_output) > 1: 95 | seq_embeddings = self.padding_seq(seq_embeddings) 96 | seq_embeddings = torch.cat(seq_embeddings, dim=0) # [N x L x d] 97 | 98 | if self.seq == 'rnn': 99 | _, h = self.rnn(seq_embeddings) 100 | seq_embeddings = h.permute(1,0,2) #[N*1*D] 101 | elif self.seq == 'transformer': 102 | seq_embeddings = torch.mean(self.transformer(seq_embeddings), dim=1, keepdim=True) 103 | elif self.seq == 'mean': 104 | seq_embeddings = torch.mean(seq_embeddings, dim=1, keepdim=True) 105 | 106 | seq_embeddings = F.relu(self.fc1(seq_embeddings)) 107 | 108 | return seq_embeddings 109 | 110 | 111 | def padding_seq(self, seq): 112 | padding_size = max([len(x[0]) for x in seq]) 113 | padded_seq = [] 114 | for s in seq: 115 | cur_size = len(s[0]) 116 | emb_size = len(s[0][0]) 117 | new_s = torch.zeros((padding_size, emb_size)).to(self.device) 118 | new_s[:cur_size,:] = s[0] 119 | padded_seq.append(new_s[None,:]) 120 | return padded_seq 121 | -------------------------------------------------------------------------------- /Graph_generate/yelp_data_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from 
easydict import EasyDict as edict 4 | 5 | class YelpDataset(object): 6 | def __init__(self, data_dir): 7 | self.data_dir = data_dir + '/Graph_generate_data' 8 | self.load_entities() 9 | self.load_relations() 10 | def get_relation(self): 11 | #Entities 12 | USER = 'user' 13 | ITEM = 'item' 14 | FEATURE = 'feature' 15 | LARGE_FEATURE = 'large_feature' 16 | 17 | #Relations 18 | INTERACT = 'interact' 19 | FRIEND = 'friends' 20 | LIKE = 'like' 21 | BELONG_TO = 'belong_to' 22 | BELONG_TO_LARGE = 'belong_to_large' # feature(second-layer tag) --> large_feature(first-layer tag) 23 | LINK_TO_FEATURE = 'link_to_feature' # large_feature(first-layer tag) --> feature(second-layer tag) 24 | relation_name = [INTERACT, FRIEND, LIKE, BELONG_TO, BELONG_TO_LARGE, LINK_TO_FEATURE] 25 | 26 | fm_relation = { 27 | USER: { 28 | INTERACT: ITEM, 29 | FRIEND: USER, 30 | LIKE: FEATURE, #There is no such relationship in yelp 31 | }, 32 | ITEM: { 33 | BELONG_TO: FEATURE, 34 | BELONG_TO_LARGE: LARGE_FEATURE, 35 | INTERACT: USER 36 | }, 37 | FEATURE: { 38 | LIKE: USER, 39 | BELONG_TO: ITEM, 40 | LINK_TO_FEATURE: LARGE_FEATURE 41 | }, 42 | LARGE_FEATURE: { 43 | LIKE: USER, 44 | BELONG_TO_LARGE: ITEM, 45 | LINK_TO_FEATURE: FEATURE 46 | 47 | } 48 | 49 | } 50 | relation_link_entity_type = { 51 | INTERACT: [USER, ITEM], 52 | FRIEND: [USER, USER], 53 | LIKE: [USER, FEATURE], 54 | BELONG_TO: [ITEM, FEATURE], 55 | BELONG_TO_LARGE: [ITEM, LARGE_FEATURE], 56 | LINK_TO_FEATURE: [LARGE_FEATURE, FEATURE] 57 | } 58 | return fm_relation, relation_name, relation_link_entity_type 59 | def load_entities(self): 60 | entity_files = edict( 61 | user='user_dict.json', 62 | item='item_dict-original_tag.json', 63 | feature='second-layer_oringinal_tag_map.json', 64 | large_feature='first-layer_merged_tag_map.json' 65 | ) 66 | for entity_name in entity_files: 67 | with open(os.path.join(self.data_dir,entity_files[entity_name]), encoding='utf-8') as f: 68 | mydict = json.load(f) 69 | if entity_name in ['feature']: 70 | entity_id = list(mydict.values()) 71 | elif entity_name in ['large_feature']: 72 | entity_id = list(map(int, list(mydict.values()))) 73 | else: 74 | entity_id = list(map(int, list(mydict.keys()))) 75 | setattr(self, entity_name, edict(id=entity_id, value_len=max(entity_id)+1)) 76 | print('Load', entity_name, 'of size', len(entity_id)) 77 | print(entity_name, 'of max id is', max(entity_id)) 78 | 79 | def load_relations(self): 80 | """ 81 | relation: head entity---> tail entity 82 | -- 83 | """ 84 | Yelp_relations = edict( 85 | interact=('user_item.json', self.user, self.item), #(filename, head_entity, tail_entity) 86 | friends=('user_dict.json', self.user, self.user), 87 | like=('user_dict.json', self.user, self.feature), 88 | belong_to=('item_dict-original_tag.json', self.item, self.feature), 89 | belong_to_large=('item_dict-merged_tag.json', self.item, self.large_feature), 90 | link_to_feature=('2-layer taxonomy.json', self.large_feature, self.feature) 91 | ) 92 | for name in Yelp_relations: 93 | # Save tail_entity 94 | relation = edict( 95 | data=[], 96 | ) 97 | knowledge = [list([]) for i in range(Yelp_relations[name][1].value_len)] 98 | # load relation files 99 | with open(os.path.join(self.data_dir, Yelp_relations[name][0]), encoding='utf-8') as f: 100 | mydict = json.load(f) 101 | if name in ['interact', 'belong_to_large']: 102 | for key, value in mydict.items(): 103 | head_id = int(key) 104 | tail_ids = value 105 | knowledge[head_id] = tail_ids 106 | elif name in ['friends', 'like']: 107 | for key in mydict.keys(): 108 | 
head_str = key 109 | head_id = int(key) 110 | tail_ids = mydict[head_str][name] 111 | knowledge[head_id] = tail_ids 112 | elif name in ['belong_to']: 113 | for key in mydict.keys(): 114 | head_str = key 115 | head_id = int(key) 116 | tail_ids = mydict[head_str]['feature_index'] 117 | knowledge[head_id] = tail_ids 118 | elif name in ['link_to_feature']: 119 | with open(os.path.join(self.data_dir, 'first-layer_merged_tag_map.json'), encoding='utf-8') as f: 120 | tag_map = json.load(f) 121 | for key, value in mydict.items(): 122 | head_id = tag_map[key] 123 | tail_ids = value 124 | knowledge[head_id] = tail_ids 125 | relation.data = knowledge 126 | setattr(self, name, relation) 127 | tuple_num = 0 128 | for i in knowledge: 129 | tuple_num += len(i) 130 | print('Load', name, 'of size', tuple_num) 131 | 132 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import random 4 | import torch 5 | import os 6 | import sys 7 | # from knowledge_graph import KnowledgeGraph 8 | # from data_process import LastFmDataset 9 | # from KG_data_generate.lastfm_small_data_process import LastFmSmallDataset 10 | # from KG_data_generate.lastfm_knowledge_graph import KnowledgeGraph 11 | #Dataset names 12 | LAST_FM = 'LAST_FM' 13 | LAST_FM_STAR = 'LAST_FM_STAR' 14 | YELP = 'YELP' 15 | YELP_STAR = 'YELP_STAR' 16 | 17 | DATA_DIR = { 18 | LAST_FM: './data/lastfm', 19 | YELP: './data/yelp', 20 | LAST_FM_STAR: './data/lastfm_star', 21 | YELP_STAR: './data/yelp', 22 | } 23 | TMP_DIR = { 24 | LAST_FM: './tmp/last_fm', 25 | YELP: './tmp/yelp', 26 | LAST_FM_STAR: './tmp/last_fm_star', 27 | YELP_STAR: './tmp/yelp_star', 28 | } 29 | def cuda_(var): 30 | return var.cuda() if torch.cuda.is_available() else var 31 | def save_dataset(dataset, dataset_obj): 32 | dataset_file = TMP_DIR[dataset] + '/dataset.pkl' 33 | with open(dataset_file, 'wb') as f: 34 | pickle.dump(dataset_obj, f) 35 | 36 | def load_dataset(dataset): 37 | dataset_file = TMP_DIR[dataset] + '/dataset.pkl' 38 | dataset_obj = pickle.load(open(dataset_file, 'rb')) 39 | return dataset_obj 40 | 41 | def save_kg(dataset, kg): 42 | kg_file = TMP_DIR[dataset] + '/kg.pkl' 43 | pickle.dump(kg, open(kg_file, 'wb')) 44 | 45 | def load_kg(dataset): 46 | kg_file = TMP_DIR[dataset] + '/kg.pkl' 47 | kg = pickle.load(open(kg_file, 'rb')) 48 | return kg 49 | 50 | def save_graph(dataset, graph): 51 | graph_file = TMP_DIR[dataset] + '/graph.pkl' 52 | pickle.dump(graph, open(graph_file, 'wb')) 53 | 54 | def load_graph(dataset): 55 | graph_file = TMP_DIR[dataset] + '/graph.pkl' 56 | graph = pickle.load(open(graph_file, 'rb')) 57 | return graph 58 | 59 | 60 | def load_embed(dataset, embed, epoch): 61 | if embed: 62 | path = TMP_DIR[dataset] + '/embeds/' + '{}.pkl'.format(embed) 63 | else: 64 | return None 65 | with open(path, 'rb') as f: 66 | embeds = pickle.load(f) 67 | print('{} Embedding load successfully!'.format(embed)) 68 | return embeds 69 | 70 | 71 | def load_rl_agent(dataset, filename, epoch_user): 72 | model_file = TMP_DIR[dataset] + '/RL-agent/' + filename + '-epoch-{}.pkl'.format(epoch_user) 73 | model_dict = torch.load(model_file) 74 | print('RL policy model load at {}'.format(model_file)) 75 | return model_dict 76 | 77 | def save_rl_agent(dataset, model, filename, epoch_user): 78 | model_file = TMP_DIR[dataset] + '/RL-agent/' + filename + '-epoch-{}.pkl'.format(epoch_user) 79 | if not 
os.path.isdir(TMP_DIR[dataset] + '/RL-agent/'): 80 | os.makedirs(TMP_DIR[dataset] + '/RL-agent/') 81 | torch.save(model, model_file) 82 | print('RL policy model saved at {}'.format(model_file)) 83 | 84 | 85 | def save_rl_mtric(dataset, filename, epoch, SR, spend_time, mode='train'): 86 | PATH = TMP_DIR[dataset] + '/RL-log-merge/' + filename + '.txt' 87 | if not os.path.isdir(TMP_DIR[dataset] + '/RL-log-merge/'): 88 | os.makedirs(TMP_DIR[dataset] + '/RL-log-merge/') 89 | if mode == 'train': 90 | with open(PATH, 'a') as f: 91 | f.write('===========Train===============\n') 92 | f.write('Starting {} user epochs\n'.format(epoch)) 93 | f.write('training SR@5: {}\n'.format(SR[0])) 94 | f.write('training SR@10: {}\n'.format(SR[1])) 95 | f.write('training SR@15: {}\n'.format(SR[2])) 96 | f.write('training Avg@T: {}\n'.format(SR[3])) 97 | f.write('training hDCG: {}\n'.format(SR[4])) 98 | f.write('Spending time: {}\n'.format(spend_time)) 99 | f.write('================================\n') 100 | # f.write('1000 loss: {}\n'.format(loss_1000)) 101 | elif mode == 'test': 102 | with open(PATH, 'a') as f: 103 | f.write('===========Test===============\n') 104 | f.write('Testing {} user tuples\n'.format(epoch)) 105 | f.write('Testing SR@5: {}\n'.format(SR[0])) 106 | f.write('Testing SR@10: {}\n'.format(SR[1])) 107 | f.write('Testing SR@15: {}\n'.format(SR[2])) 108 | f.write('Testing Avg@T: {}\n'.format(SR[3])) 109 | f.write('Testing hDCG: {}\n'.format(SR[4])) 110 | f.write('Testing time: {}\n'.format(spend_time)) 111 | f.write('================================\n') 112 | # f.write('1000 loss: {}\n'.format(loss_1000)) 113 | 114 | def save_rl_model_log(dataset, filename, epoch, epoch_loss, train_len): 115 | PATH = TMP_DIR[dataset] + '/RL-log-merge/' + filename + '.txt' 116 | if not os.path.isdir(TMP_DIR[dataset] + '/RL-log-merge/'): 117 | os.makedirs(TMP_DIR[dataset] + '/RL-log-merge/') 118 | with open(PATH, 'a') as f: 119 | f.write('Starting {} epoch\n'.format(epoch)) 120 | f.write('training loss : {}\n'.format(epoch_loss / train_len)) 121 | # f.write('1000 loss: {}\n'.format(loss_1000)) 122 | 123 | def set_random_seed(seed): 124 | random.seed(seed) 125 | np.random.seed(seed) 126 | torch.manual_seed(seed) 127 | if torch.cuda.is_available(): 128 | torch.cuda.manual_seed_all(seed) 129 | 130 | 131 | # Disable 132 | def blockPrint(): 133 | sys.stdout = open(os.devnull, 'w') 134 | 135 | # Restore 136 | def enablePrint(): 137 | sys.stdout = sys.__stdout__ 138 | 139 | 140 | def set_cuda(args): 141 | use_cuda = torch.cuda.is_available() 142 | if use_cuda: 143 | torch.cuda.manual_seed(args.seed) 144 | torch.backends.cudnn.deterministic = True 145 | devices_id = [int(device_id) for device_id in args.gpu.split()] 146 | device = ( 147 | torch.device("cuda:{}".format(str(devices_id[0]))) 148 | if use_cuda 149 | else torch.device("cpu") 150 | ) 151 | return device, devices_id -------------------------------------------------------------------------------- /RL/RL_evaluate.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | from itertools import count 4 | import torch.nn as nn 5 | import torch 6 | import math 7 | from collections import namedtuple 8 | from utils import * 9 | from RL.env_binary_question import BinaryRecommendEnv 10 | from RL.env_enumerated_question import EnumeratedRecommendEnv 11 | from tqdm import tqdm 12 | EnvDict = { 13 | LAST_FM: BinaryRecommendEnv, 14 | LAST_FM_STAR: BinaryRecommendEnv, 15 | YELP: EnumeratedRecommendEnv, 16 | 
YELP_STAR: BinaryRecommendEnv 17 | } 18 | 19 | def dqn_evaluate(args, kg, dataset, agent, filename, i_episode): 20 | test_env = EnvDict[args.data_name](kg, dataset, args.data_name, args.embed, seed=args.seed, max_turn=args.max_turn, 21 | cand_num=args.cand_num, cand_item_num=args.cand_item_num, attr_num=args.attr_num, mode='test', ask_num=args.ask_num, entropy_way=args.entropy_method, 22 | fm_epoch=args.fm_epoch) 23 | set_random_seed(args.seed) 24 | tt = time.time() 25 | start = tt 26 | 27 | SR5, SR10, SR15, AvgT, Rank, total_reward = 0, 0, 0, 0, 0, 0 28 | SR_turn_15 = [0]* args.max_turn 29 | turn_result = [] 30 | result = [] 31 | user_size = test_env.ui_array.shape[0] 32 | print('User size in UI_test: ', user_size) 33 | test_filename = 'Evaluate-epoch-{}-'.format(i_episode) + filename 34 | plot_filename = 'Evaluate-'.format(i_episode) + filename 35 | if args.data_name in [LAST_FM_STAR, LAST_FM]: 36 | if args.eval_num == 1: 37 | test_size = 500 38 | else: 39 | test_size = 4000 # Only do 4000 iteration for the sake of time 40 | user_size = test_size 41 | if args.data_name in [YELP_STAR, YELP]: 42 | if args.eval_num == 1: 43 | test_size = 500 44 | else: 45 | test_size = 2500 # Only do 2500 iteration for the sake of time 46 | user_size = test_size 47 | print('The select Test size : ', test_size) 48 | for user_num in tqdm(range(user_size)): #user_size 49 | # TODO uncommend this line to print the dialog process 50 | blockPrint() 51 | print('\n================test tuple:{}===================='.format(user_num)) 52 | if not args.fix_emb: 53 | state, cand, action_space = test_env.reset(agent.gcn_net.embedding.weight.data.cpu().detach().numpy()) # Reset environment and record the starting state 54 | else: 55 | state, cand, action_space = test_env.reset() 56 | epi_reward = 0 57 | is_last_turn = False 58 | for t in count(): # user dialog 59 | if t == 14: 60 | is_last_turn = True 61 | action, sorted_actions = agent.select_action(state, cand, action_space, is_test=True, is_last_turn=is_last_turn) 62 | next_state, next_cand, action_space, reward, done = test_env.step(action.item(), sorted_actions) 63 | epi_reward += reward 64 | reward = torch.tensor([reward], device=args.device, dtype=torch.float) 65 | if done: 66 | next_state = None 67 | state = next_state 68 | cand = next_cand 69 | if done: 70 | enablePrint() 71 | if reward.item() == 1: # recommend successfully 72 | SR_turn_15 = [v+1 if i>t else v for i, v in enumerate(SR_turn_15) ] 73 | if t < 5: 74 | SR5 += 1 75 | SR10 += 1 76 | SR15 += 1 77 | elif t < 10: 78 | SR10 += 1 79 | SR15 += 1 80 | else: 81 | SR15 += 1 82 | Rank += (1/math.log(t+3,2) + (1/math.log(t+2,2)-1/math.log(t+3,2))/math.log(done+1,2)) 83 | else: 84 | Rank += 0 85 | total_reward += epi_reward 86 | AvgT += t+1 87 | break 88 | 89 | if (user_num+1) % args.observe_num == 0 and user_num > 0: 90 | SR = [SR5/args.observe_num, SR10/args.observe_num, SR15/args.observe_num, AvgT / args.observe_num, Rank / args.observe_num, total_reward / args.observe_num] 91 | SR_TURN = [i/args.observe_num for i in SR_turn_15] 92 | print('Total evalueation epoch_uesr:{}'.format(user_num + 1)) 93 | print('Takes {} seconds to finish {}% of this task'.format(str(time.time() - start), 94 | float(user_num) * 100 / user_size)) 95 | print('SR5:{}, SR10:{}, SR15:{}, AvgT:{}, Rank:{}, reward:{} ' 96 | 'Total epoch_uesr:{}'.format(SR5 / args.observe_num, SR10 / args.observe_num, SR15 / args.observe_num, 97 | AvgT / args.observe_num, Rank / args.observe_num, total_reward / args.observe_num, user_num + 1)) 98 | 
result.append(SR) 99 | turn_result.append(SR_TURN) 100 | SR5, SR10, SR15, AvgT, Rank, total_reward = 0, 0, 0, 0, 0, 0 101 | SR_turn_15 = [0] * args.max_turn 102 | tt = time.time() 103 | enablePrint() 104 | 105 | SR5_mean = np.mean(np.array([item[0] for item in result])) 106 | SR10_mean = np.mean(np.array([item[1] for item in result])) 107 | SR15_mean = np.mean(np.array([item[2] for item in result])) 108 | AvgT_mean = np.mean(np.array([item[3] for item in result])) 109 | Rank_mean = np.mean(np.array([item[4] for item in result])) 110 | reward_mean = np.mean(np.array([item[5] for item in result])) 111 | SR_all = [SR5_mean, SR10_mean, SR15_mean, AvgT_mean, Rank_mean, reward_mean] 112 | save_rl_mtric(dataset=args.data_name, filename=filename, epoch=user_num, SR=SR_all, spend_time=time.time() - start, 113 | mode='test') 114 | save_rl_mtric(dataset=args.data_name, filename=test_filename, epoch=user_num, SR=SR_all, spend_time=time.time() - start, 115 | mode='test') # save RL SR 116 | print('save test evaluate successfully!') 117 | 118 | SRturn_all = [0] * args.max_turn 119 | for i in range(len(SRturn_all)): 120 | SRturn_all[i] = np.mean(np.array([item[i] for item in turn_result])) 121 | print('success turn:{}'.format(SRturn_all)) 122 | print('SR5:{}, SR10:{}, SR15:{}, AvgT:{}, Rank:{}, reward:{}'.format(SR5_mean, SR10_mean, SR15_mean, AvgT_mean, Rank_mean, reward_mean)) 123 | PATH = TMP_DIR[args.data_name] + '/RL-log-merge/' + test_filename + '.txt' 124 | with open(PATH, 'a') as f: 125 | f.write('Training epocch:{}\n'.format(i_episode)) 126 | f.write('===========Test Turn===============\n') 127 | f.write('Testing {} user tuples\n'.format(user_num)) 128 | for i in range(len(SRturn_all)): 129 | f.write('Testing SR-turn@{}: {}\n'.format(i, SRturn_all[i])) 130 | f.write('================================\n') 131 | PATH = TMP_DIR[args.data_name] + '/RL-log-merge/' + plot_filename + '.txt' 132 | with open(PATH, 'a') as f: 133 | f.write('{}\t{}\t{}\t{}\t{}\n'.format(i_episode, SR15_mean, AvgT_mean, Rank_mean, reward_mean)) 134 | 135 | return SR15_mean 136 | 137 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import random 4 | import numpy as np 5 | import os 6 | import sys 7 | from tqdm import tqdm 8 | # sys.path.append('..') 9 | 10 | from collections import namedtuple 11 | import argparse 12 | from itertools import count, chain 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torch.nn.functional as F 17 | from utils import * 18 | from sum_tree import SumTree 19 | 20 | #TODO select env 21 | from RL.env_binary_question import BinaryRecommendEnv 22 | from RL.env_enumerated_question import EnumeratedRecommendEnv 23 | from RL.RL_evaluate import dqn_evaluate 24 | from RL_model import Agent, ReplayMemoryPER 25 | from gcn import GraphEncoder 26 | import time 27 | import warnings 28 | 29 | warnings.filterwarnings("ignore") 30 | EnvDict = { 31 | LAST_FM: BinaryRecommendEnv, 32 | LAST_FM_STAR: BinaryRecommendEnv, 33 | YELP: EnumeratedRecommendEnv, 34 | YELP_STAR: BinaryRecommendEnv 35 | } 36 | 37 | FeatureDict = { 38 | LAST_FM: 'feature', 39 | LAST_FM_STAR: 'feature', 40 | YELP: 'large_feature', 41 | YELP_STAR: 'feature' 42 | } 43 | 44 | 45 | def evaluate(args, kg, dataset, filename): 46 | test_env = EnvDict[args.data_name](kg, dataset, args.data_name, args.embed, seed=args.seed, max_turn=args.max_turn, 47 | 
cand_num=args.cand_num, cand_item_num=args.cand_item_num, attr_num=args.attr_num, mode='test', ask_num=args.ask_num, entropy_way=args.entropy_method, 48 | fm_epoch=args.fm_epoch) 49 | set_random_seed(args.seed) 50 | memory = ReplayMemoryPER(args.memory_size) #10000 51 | embed = torch.FloatTensor(np.concatenate((test_env.ui_embeds, test_env.feature_emb, np.zeros((1,test_env.ui_embeds.shape[1]))), axis=0)) 52 | gcn_net = GraphEncoder(device=args.device, entity=embed.size(0), emb_size=embed.size(1), kg=kg, embeddings=embed, \ 53 | fix_emb=args.fix_emb, seq=args.seq, gcn=args.gcn, hidden_size=args.hidden).to(args.device) 54 | agent = Agent(device=args.device, memory=memory, state_size=args.hidden, action_size=embed.size(1), \ 55 | hidden_size=args.hidden, gcn_net=gcn_net, learning_rate=args.learning_rate, l2_norm=args.l2_norm, PADDING_ID=embed.size(0)-1) 56 | print('Staring loading rl model in epoch {}'.format(args.load_rl_epoch)) 57 | agent.load_model(data_name=args.data_name, filename=filename, epoch_user=args.load_rl_epoch) 58 | 59 | tt = time.time() 60 | start = tt 61 | 62 | SR5, SR10, SR15, AvgT, Rank = 0, 0, 0, 0, 0 63 | SR_turn_15 = [0]* args.max_turn 64 | turn_result = [] 65 | result = [] 66 | user_size = test_env.ui_array.shape[0] 67 | 68 | print('User size in UI_test: ', user_size) 69 | test_filename = 'Evaluate-epoch-{}-'.format(args.load_rl_epoch) + filename 70 | 71 | for user_num in tqdm(range(user_size)): #user_size 72 | # TODO uncommend this line to print the dialog process 73 | blockPrint() 74 | print('\n================test tuple:{}===================='.format(user_num)) 75 | state, cand, action_space = test_env.reset() # Reset environment and record the starting state 76 | is_last_turn = False 77 | for t in count(): # user dialog 78 | if t == 14: 79 | is_last_turn = True 80 | action, sorted_actions = agent.select_action(state, cand, action_space, is_test=True, is_last_turn=is_last_turn) 81 | next_state, next_cand, action_space, reward, done = test_env.step(action.item(), sorted_actions) 82 | reward = torch.tensor([reward], device=args.device, dtype=torch.float) 83 | if done: 84 | next_state = None 85 | state = next_state 86 | cand = next_cand 87 | if done: 88 | enablePrint() 89 | if reward.item() == 1: # recommend successfully 90 | SR_turn_15 = [v+1 if i>t else v for i, v in enumerate(SR_turn_15) ] 91 | if t < 5: 92 | SR5 += 1 93 | SR10 += 1 94 | SR15 += 1 95 | elif t < 10: 96 | SR10 += 1 97 | SR15 += 1 98 | else: 99 | SR15 += 1 100 | Rank += (1/math.log(t+3,2) + (1/math.log(t+2,2)-1/math.log(t+3,2))/math.log(done+1,2)) 101 | else: 102 | Rank += 0 103 | AvgT += t+1 104 | break 105 | 106 | if (user_num+1) % args.observe_num == 0 and user_num > 0: 107 | SR = [SR5/args.observe_num, SR10/args.observe_num, SR15/args.observe_num, AvgT / args.observe_num, Rank / args.observe_num] 108 | SR_TURN = [i/args.observe_num for i in SR_turn_15] 109 | print('Total evalueation epoch_uesr:{}'.format(user_num + 1)) 110 | print('Takes {} seconds to finish {}% of this task'.format(str(time.time() - start), 111 | float(user_num) * 100 / user_size)) 112 | print('SR5:{}, SR10:{}, SR15:{}, AvgT:{}, Rank:{} ' 113 | 'Total epoch_uesr:{}'.format(SR5 / args.observe_num, SR10 / args.observe_num, SR15 / args.observe_num, 114 | AvgT / args.observe_num, Rank / args.observe_num, user_num + 1)) 115 | result.append(SR) 116 | turn_result.append(SR_TURN) 117 | SR5, SR10, SR15, AvgT, Rank = 0, 0, 0, 0, 0 118 | SR_turn_15 = [0] * args.max_turn 119 | tt = time.time() 120 | enablePrint() 121 | 122 | SR5_mean = 
np.mean(np.array([item[0] for item in result])) 123 | SR10_mean = np.mean(np.array([item[1] for item in result])) 124 | SR15_mean = np.mean(np.array([item[2] for item in result])) 125 | AvgT_mean = np.mean(np.array([item[3] for item in result])) 126 | Rank_mean = np.mean(np.array([item[4] for item in result])) 127 | SR_all = [SR5_mean, SR10_mean, SR15_mean, AvgT_mean, Rank_mean] 128 | save_rl_mtric(dataset=args.data_name, filename=filename, epoch=user_num, SR=SR_all, spend_time=time.time() - start, 129 | mode='test') 130 | save_rl_mtric(dataset=args.data_name, filename=test_filename, epoch=user_num, SR=SR_all, spend_time=time.time() - start, 131 | mode='test') # save RL SR 132 | print('save test evaluate successfully!') 133 | 134 | SRturn_all = [0] * args.max_turn 135 | for i in range(len(SRturn_all)): 136 | SRturn_all[i] = np.mean(np.array([item[i] for item in turn_result])) 137 | print('success turn:{}'.format(SRturn_all)) 138 | print('SR5:{}, SR10:{}, SR15:{}, AvgT:{}, Rank:{}'.format(SR5_mean, SR10_mean, SR15_mean, AvgT_mean, Rank_mean)) 139 | PATH = TMP_DIR[args.data_name] + '/RL-log-merge/' + test_filename + '.txt' 140 | with open(PATH, 'a') as f: 141 | #f.write('Training epocch:{}\n'.format(i_episode)) 142 | f.write('===========Test Turn===============\n') 143 | f.write('Testing {} user tuples\n'.format(user_num)) 144 | for i in range(len(SRturn_all)): 145 | f.write('Testing SR-turn@{}: {}\n'.format(i, SRturn_all[i])) 146 | f.write('================================\n') 147 | 148 | 149 | def main(): 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument('--seed', '-seed', type=int, default=1, help='random seed.') 152 | parser.add_argument('--gpu', type=str, default='0', help='gpu device.') 153 | parser.add_argument('--epochs', '-me', type=int, default=50000, help='the number of RL train epoch') 154 | parser.add_argument('--fm_epoch', type=int, default=0, help='the epoch of FM embedding') 155 | parser.add_argument('--batch_size', type=int, default=128, help='batch size.') 156 | parser.add_argument('--gamma', type=float, default=0.999, help='reward discount factor.') 157 | parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate.') 158 | parser.add_argument('--l2_norm', type=float, default=1e-6, help='l2 regularization.') 159 | parser.add_argument('--hidden', type=int, default=100, help='number of samples') 160 | parser.add_argument('--memory_size', type=int, default=50000, help='size of memory ') 161 | 162 | parser.add_argument('--data_name', type=str, default=LAST_FM, choices=[LAST_FM, LAST_FM_STAR, YELP, YELP_STAR], 163 | help='One of {LAST_FM, LAST_FM_STAR, YELP, YELP_STAR}.') 164 | parser.add_argument('--entropy_method', type=str, default='weight_entropy', help='entropy_method is one of {entropy, weight entropy}') 165 | # Although the performance of 'weighted entropy' is better, 'entropy' is an alternative method considering the time cost. 
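    # Note: the checkpoint that evaluate.py reloads is located by a filename
    # composed from data_name, cand_num, cand_item_num, embed, seq and gcn (see
    # the filename template at the bottom of main()), so those flags, together
    # with --load_rl_epoch, should match the values used when the model was
    # trained and saved by RL_model.py.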
166 | parser.add_argument('--max_turn', type=int, default=15, help='max conversation turn') 167 | parser.add_argument('--cand_len_size', type=int, default=20, help='binary state size for the length of candidate items') 168 | parser.add_argument('--attr_num', type=int, help='the number of attributes') 169 | parser.add_argument('--mode', type=str, default='train', help='the mode in [train, test]') 170 | parser.add_argument('--ask_num', type=int, default=1, help='the number of features asked in a turn') 171 | parser.add_argument('--observe_num', type=int, default=500, help='the number of epochs to save RL model and metric') 172 | parser.add_argument('--load_rl_epoch', type=int, default=0, help='the epoch of loading RL model') 173 | 174 | parser.add_argument('--sample_times', type=int, default=100, help='the epoch of sampling') 175 | parser.add_argument('--max_steps', type=int, default=100, help='max training steps') 176 | parser.add_argument('--eval_num', type=int, default=10, help='the number of epochs to save RL model and metric') 177 | parser.add_argument('--cand_num', type=int, default=10, help='candidate sampling number') 178 | parser.add_argument('--cand_item_num', type=int, default=10, help='candidate item sampling number') 179 | parser.add_argument('--fix_emb', type=bool, default=True, help='fix embedding or not') 180 | parser.add_argument('--embed', type=str, default='transe', help='pretrained embeddings') 181 | parser.add_argument('--seq', type=str, default='transformer', choices=['rnn', 'transformer', 'mean'], help='sequential learning method') 182 | parser.add_argument('--gcn', action='store_false', help='use GCN or not') 183 | 184 | 185 | args = parser.parse_args() 186 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 187 | args.device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 188 | print(args.device) 189 | print('data_set:{}'.format(args.data_name)) 190 | kg = load_kg(args.data_name) 191 | #reset attr_num 192 | feature_name = FeatureDict[args.data_name] 193 | feature_length = len(kg.G[feature_name].keys()) 194 | print('dataset:{}, feature_length:{}'.format(args.data_name, feature_length)) 195 | args.attr_num = feature_length # set attr_num = feature_length 196 | print('args.attr_num:', args.attr_num) 197 | print('args.entropy_method:', args.entropy_method) 198 | 199 | dataset = load_dataset(args.data_name) 200 | filename = 'train-data-{}-RL-cand_num-{}-cand_item_num-{}-embed-{}-seq-{}-gcn-{}'.format( 201 | args.data_name, args.cand_num, args.cand_item_num, args.embed, args.seq, args.gcn) 202 | evaluate(args, kg, dataset, filename) 203 | 204 | if __name__ == '__main__': 205 | main() -------------------------------------------------------------------------------- /RL_model.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import random 4 | import numpy as np 5 | import os 6 | import sys 7 | from tqdm import tqdm 8 | # sys.path.append('..') 9 | 10 | from collections import namedtuple 11 | import argparse 12 | from itertools import count, chain 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torch.nn.functional as F 17 | from utils import * 18 | from sum_tree import SumTree 19 | 20 | #TODO select env 21 | from RL.env_binary_question import BinaryRecommendEnv 22 | from RL.env_enumerated_question import EnumeratedRecommendEnv 23 | from RL.RL_evaluate import dqn_evaluate 24 | from gcn import GraphEncoder 25 | import time 26 | import warnings 27 | 28 | 
warnings.filterwarnings("ignore") 29 | EnvDict = { 30 | LAST_FM: BinaryRecommendEnv, 31 | LAST_FM_STAR: BinaryRecommendEnv, 32 | YELP: EnumeratedRecommendEnv, 33 | YELP_STAR: BinaryRecommendEnv 34 | } 35 | FeatureDict = { 36 | LAST_FM: 'feature', 37 | LAST_FM_STAR: 'feature', 38 | YELP: 'large_feature', 39 | YELP_STAR: 'feature' 40 | } 41 | 42 | Transition = namedtuple('Transition', 43 | ('state', 'action', 'next_state', 'reward', 'next_cand')) 44 | 45 | class ReplayMemory(object): 46 | 47 | def __init__(self, capacity): 48 | self.capacity = capacity 49 | self.memory = [] 50 | self.position = 0 51 | 52 | def push(self, *args): 53 | if len(self.memory) < self.capacity: 54 | self.memory.append(None) 55 | self.memory[self.position] = Transition(*args) 56 | self.position = (self.position + 1) % self.capacity 57 | 58 | def sample(self, batch_size): 59 | return random.sample(self.memory, batch_size) 60 | 61 | def __len__(self): 62 | return len(self.memory) 63 | 64 | 65 | class ReplayMemoryPER(object): 66 | # stored as ( s, a, r, s_ ) in SumTree 67 | def __init__(self, capacity, a = 0.6, e = 0.01): 68 | self.tree = SumTree(capacity) 69 | self.capacity = capacity 70 | self.prio_max = 0.1 71 | self.a = a 72 | self.e = e 73 | self.beta = 0.4 74 | self.beta_increment_per_sampling = 0.001 75 | 76 | def push(self, *args): 77 | data = Transition(*args) 78 | p = (np.abs(self.prio_max) + self.e) ** self.a # proportional priority 79 | self.tree.add(p, data) 80 | 81 | def sample(self, batch_size): 82 | batch_data = [] 83 | idxs = [] 84 | segment = self.tree.total() / batch_size 85 | priorities = [] 86 | 87 | for i in range(batch_size): 88 | a = segment * i 89 | b = segment * (i + 1) 90 | s = random.uniform(a, b) 91 | idx, p, data = self.tree.get(s) 92 | 93 | batch_data.append(data) 94 | priorities.append(p) 95 | idxs.append(idx) 96 | 97 | self.beta = np.min([1., self.beta + self.beta_increment_per_sampling]) 98 | 99 | sampling_probabilities = priorities / self.tree.total() 100 | is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta) 101 | is_weight /= is_weight.max() 102 | 103 | return idxs, batch_data, is_weight 104 | 105 | def update(self, idxs, errors): 106 | self.prio_max = max(self.prio_max, max(np.abs(errors))) 107 | for i, idx in enumerate(idxs): 108 | p = (np.abs(errors[i]) + self.e) ** self.a 109 | self.tree.update(idx, p) 110 | 111 | def __len__(self): 112 | return self.tree.n_entries 113 | 114 | 115 | class DQN(nn.Module): 116 | def __init__(self, state_size, action_size, hidden_size=100): 117 | super(DQN, self).__init__() 118 | # V(s) 119 | self.fc2_value = nn.Linear(hidden_size, hidden_size) 120 | self.out_value = nn.Linear(hidden_size, 1) 121 | # Q(s,a) 122 | self.fc2_advantage = nn.Linear(hidden_size + action_size, hidden_size) 123 | self.out_advantage = nn.Linear(hidden_size, 1) 124 | 125 | def forward(self, x, y, choose_action=True): 126 | """ 127 | :param x: encode history [N*L*D]; y: action embedding [N*K*D] 128 | :return: v: action score [N*K] 129 | """ 130 | # V(s) 131 | value = self.out_value(F.relu(self.fc2_value(x))).squeeze(dim=2) #[N*1*1] 132 | # Q(s,a) 133 | if choose_action: 134 | x = x.repeat(1, y.size(1), 1) 135 | state_cat_action = torch.cat((x,y),dim=2) 136 | advantage = self.out_advantage(F.relu(self.fc2_advantage(state_cat_action))).squeeze(dim=2) #[N*K] 137 | 138 | if choose_action: 139 | qsa = advantage + value - advantage.mean(dim=1, keepdim=True) 140 | else: 141 | qsa = advantage + value 142 | 143 | return qsa 144 | 145 | 146 | class 
Agent(object): 147 | def __init__(self, device, memory, state_size, action_size, hidden_size, gcn_net, learning_rate, l2_norm, PADDING_ID, EPS_START = 0.9, EPS_END = 0.1, EPS_DECAY = 0.0001, tau=0.01): 148 | self.EPS_START = EPS_START 149 | self.EPS_END = EPS_END 150 | self.EPS_DECAY = EPS_DECAY 151 | self.steps_done = 0 152 | self.device = device 153 | self.gcn_net = gcn_net 154 | self.policy_net = DQN(state_size, action_size, hidden_size).to(device) 155 | self.target_net = DQN(state_size, action_size, hidden_size).to(device) 156 | self.target_net.load_state_dict(self.policy_net.state_dict()) 157 | self.target_net.eval() 158 | self.optimizer = optim.Adam(chain(self.policy_net.parameters(),self.gcn_net.parameters()), lr=learning_rate, weight_decay = l2_norm) 159 | self.memory = memory 160 | self.loss_func = nn.MSELoss() 161 | self.PADDING_ID = PADDING_ID 162 | self.tau = tau 163 | 164 | 165 | def select_action(self, state, cand, action_space, is_test=False, is_last_turn=False): 166 | state_emb = self.gcn_net([state]) 167 | cand = torch.LongTensor([cand]).to(self.device) 168 | cand_emb = self.gcn_net.embedding(cand) 169 | sample = random.random() 170 | eps_threshold = self.EPS_END + (self.EPS_START - self.EPS_END) * \ 171 | math.exp(-1. * self.steps_done / self.EPS_DECAY) 172 | self.steps_done += 1 173 | if is_test or sample > eps_threshold: 174 | if is_test and (len(action_space[1]) <= 10 or is_last_turn): 175 | return torch.tensor(action_space[1][0], device=self.device, dtype=torch.long), action_space[1] 176 | with torch.no_grad(): 177 | actions_value = self.policy_net(state_emb, cand_emb) 178 | print(sorted(list(zip(cand[0].tolist(), actions_value[0].tolist())), key=lambda x: x[1], reverse=True)) 179 | action = cand[0][actions_value.argmax().item()] 180 | sorted_actions = cand[0][actions_value.sort(1, True)[1].tolist()] 181 | return action, sorted_actions.tolist() 182 | else: 183 | shuffled_cand = action_space[0]+action_space[1] 184 | random.shuffle(shuffled_cand) 185 | return torch.tensor(shuffled_cand[0], device=self.device, dtype=torch.long), shuffled_cand 186 | 187 | def update_target_model(self): 188 | #soft assign 189 | for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()): 190 | target_param.data.copy_(self.tau * param.data + target_param.data * (1.0 - self.tau)) 191 | 192 | def optimize_model(self, BATCH_SIZE, GAMMA): 193 | if len(self.memory) < BATCH_SIZE: 194 | return 195 | 196 | self.update_target_model() 197 | 198 | idxs, transitions, is_weights = self.memory.sample(BATCH_SIZE) 199 | batch = Transition(*zip(*transitions)) 200 | 201 | state_emb_batch = self.gcn_net(list(batch.state)) 202 | action_batch = torch.LongTensor(np.array(batch.action).astype(int).reshape(-1, 1)).to(self.device) #[N*1] 203 | action_emb_batch = self.gcn_net.embedding(action_batch) 204 | reward_batch = torch.FloatTensor(np.array(batch.reward).astype(float).reshape(-1, 1)).to(self.device) 205 | non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, 206 | batch.next_state)), device=self.device, dtype=torch.uint8) 207 | n_states = [] 208 | n_cands = [] 209 | for s, c in zip(batch.next_state, batch.next_cand): 210 | if s is not None: 211 | n_states.append(s) 212 | n_cands.append(c) 213 | next_state_emb_batch = self.gcn_net(n_states) 214 | next_cand_batch = self.padding(n_cands) 215 | next_cand_emb_batch = self.gcn_net.embedding(next_cand_batch) 216 | 217 | q_eval = self.policy_net(state_emb_batch, action_emb_batch, choose_action=False) 218 | 219 | # Double DQN 
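        # Action selection uses the online policy_net (argmax over each
        # non-terminal next state's candidate actions), while action evaluation
        # uses the frozen target_net, i.e.
        #   y = r + gamma * Q_target(s', argmax_a Q_policy(s', a))
        # for non-terminal transitions; terminal transitions keep q_target = 0,
        # so y = r.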
220 | best_actions = torch.gather(input=next_cand_batch, dim=1, index=self.policy_net(next_state_emb_batch, next_cand_emb_batch).argmax(dim=1).view(len(n_states),1).to(self.device)) 221 | best_actions_emb = self.gcn_net.embedding(best_actions) 222 | q_target = torch.zeros((BATCH_SIZE,1), device=self.device) 223 | q_target[non_final_mask] = self.target_net(next_state_emb_batch,best_actions_emb,choose_action=False).detach() 224 | q_target = reward_batch + GAMMA * q_target 225 | 226 | # prioritized experience replay 227 | errors = (q_eval - q_target).detach().cpu().squeeze().tolist() 228 | self.memory.update(idxs, errors) 229 | 230 | # mean squared error loss to minimize 231 | loss = (torch.FloatTensor(is_weights).to(self.device) * self.loss_func(q_eval, q_target)).mean() 232 | self.optimizer.zero_grad() 233 | loss.backward() 234 | for param in self.policy_net.parameters(): 235 | param.grad.data.clamp_(-1, 1) 236 | self.optimizer.step() 237 | 238 | return loss.data 239 | 240 | def save_model(self, data_name, filename, epoch_user): 241 | save_rl_agent(dataset=data_name, model={'policy': self.policy_net.state_dict(), 'gcn': self.gcn_net.state_dict()}, filename=filename, epoch_user=epoch_user) 242 | def load_model(self, data_name, filename, epoch_user): 243 | model_dict = load_rl_agent(dataset=data_name, filename=filename, epoch_user=epoch_user) 244 | self.policy_net.load_state_dict(model_dict['policy']) 245 | self.gcn_net.load_state_dict(model_dict['gcn']) 246 | 247 | def padding(self, cand): 248 | pad_size = max([len(c) for c in cand]) 249 | padded_cand = [] 250 | for c in cand: 251 | cur_size = len(c) 252 | new_c = np.ones((pad_size)) * self.PADDING_ID 253 | new_c[:cur_size] = c 254 | padded_cand.append(new_c) 255 | return torch.LongTensor(padded_cand).to(self.device) 256 | 257 | 258 | def train(args, kg, dataset, filename): 259 | env = EnvDict[args.data_name](kg, dataset, args.data_name, args.embed, seed=args.seed, max_turn=args.max_turn, cand_num=args.cand_num, cand_item_num=args.cand_item_num, 260 | attr_num=args.attr_num, mode='train', ask_num=args.ask_num, entropy_way=args.entropy_method, fm_epoch=args.fm_epoch) 261 | set_random_seed(args.seed) 262 | memory = ReplayMemoryPER(args.memory_size) #50000 263 | embed = torch.FloatTensor(np.concatenate((env.ui_embeds, env.feature_emb, np.zeros((1,env.ui_embeds.shape[1]))), axis=0)) 264 | gcn_net = GraphEncoder(device=args.device, entity=embed.size(0), emb_size=embed.size(1), kg=kg, embeddings=embed, \ 265 | fix_emb=args.fix_emb, seq=args.seq, gcn=args.gcn, hidden_size=args.hidden).to(args.device) 266 | agent = Agent(device=args.device, memory=memory, state_size=args.hidden, action_size=embed.size(1), \ 267 | hidden_size=args.hidden, gcn_net=gcn_net, learning_rate=args.learning_rate, l2_norm=args.l2_norm, PADDING_ID=embed.size(0)-1) 268 | # self.reward_dict = { 269 | # 'ask_suc': 0.01, 270 | # 'ask_fail': -0.1, 271 | # 'rec_suc': 1, 272 | # 'rec_fail': -0.1, 273 | # 'until_T': -0.3, # until MAX_Turn 274 | # } 275 | #ealuation metric ST@T 276 | #agent load policy parameters 277 | if args.load_rl_epoch != 0 : 278 | print('Staring loading rl model in epoch {}'.format(args.load_rl_epoch)) 279 | agent.load_model(data_name=args.data_name, filename=filename, epoch_user=args.load_rl_epoch) 280 | 281 | test_performance = [] 282 | if args.eval_num == 1: 283 | SR15_mean = dqn_evaluate(args, kg, dataset, agent, filename, 0) 284 | test_performance.append(SR15_mean) 285 | for train_step in range(1, args.max_steps+1): 286 | SR5, SR10, SR15, AvgT, Rank, 
total_reward = 0., 0., 0., 0., 0., 0. 287 | loss = torch.tensor(0, dtype=torch.float, device=args.device) 288 | for i_episode in tqdm(range(args.sample_times),desc='sampling'): 289 | #blockPrint() 290 | print('\n================new tuple:{}===================='.format(i_episode)) 291 | if not args.fix_emb: 292 | state, cand, action_space = env.reset(agent.gcn_net.embedding.weight.data.cpu().detach().numpy()) # Reset environment and record the starting state 293 | else: 294 | state, cand, action_space = env.reset() 295 | #state = torch.unsqueeze(torch.FloatTensor(state), 0).to(args.device) 296 | epi_reward = 0 297 | is_last_turn = False 298 | for t in count(): # user dialog 299 | if t == 14: # the last turn when max_turn is 15 300 | is_last_turn = True 301 | action, sorted_actions = agent.select_action(state, cand, action_space, is_last_turn=is_last_turn) 302 | if not args.fix_emb: 303 | next_state, next_cand, action_space, reward, done = env.step(action.item(), sorted_actions, agent.gcn_net.embedding.weight.data.cpu().detach().numpy()) 304 | else: 305 | next_state, next_cand, action_space, reward, done = env.step(action.item(), sorted_actions) 306 | epi_reward += reward 307 | reward = torch.tensor([reward], device=args.device, dtype=torch.float) 308 | if done: 309 | next_state = None 310 | 311 | agent.memory.push(state, action, next_state, reward, next_cand) 312 | state = next_state 313 | cand = next_cand 314 | 315 | newloss = agent.optimize_model(args.batch_size, args.gamma) 316 | if newloss is not None: 317 | loss += newloss 318 | 319 | if done: 320 | # episode finished: update the success-rate and ranking metrics 321 | if reward.item() == 1: # recommend successfully 322 | if t < 5: 323 | SR5 += 1 324 | SR10 += 1 325 | SR15 += 1 326 | elif t < 10: 327 | SR10 += 1 328 | SR15 += 1 329 | else: 330 | SR15 += 1 331 | Rank += (1/math.log(t+3,2) + (1/math.log(t+2,2)-1/math.log(t+3,2))/math.log(done+1,2)) 332 | else: 333 | Rank += 0 334 | AvgT += t+1 335 | total_reward += epi_reward 336 | break 337 | enablePrint() # Enable print function 338 | print('loss : {} averaged over {} sampled episodes'.format(loss.item()/args.sample_times, args.sample_times)) 339 | print('SR5:{}, SR10:{}, SR15:{}, AvgT:{}, Rank:{}, rewards:{} ' 340 | 'Total episodes:{}'.format(SR5 / args.sample_times, SR10 / args.sample_times, SR15 / args.sample_times, 341 | AvgT / args.sample_times, Rank / args.sample_times, total_reward / args.sample_times, args.sample_times)) 342 | 343 | if train_step % args.eval_num == 0: 344 | SR15_mean = dqn_evaluate(args, kg, dataset, agent, filename, train_step) 345 | test_performance.append(SR15_mean) 346 | if train_step % args.save_num == 0: 347 | agent.save_model(data_name=args.data_name, filename=filename, epoch_user=train_step) 348 | print(test_performance) 349 | 350 | 351 | def main(): 352 | parser = argparse.ArgumentParser() 353 | parser.add_argument('--seed', '-seed', type=int, default=1, help='random seed.') 354 | parser.add_argument('--gpu', type=str, default='0', help='gpu device.') 355 | parser.add_argument('--epochs', '-me', type=int, default=50000, help='the number of RL training epochs') 356 | parser.add_argument('--fm_epoch', type=int, default=0, help='the epoch of the pretrained FM embedding to load') 357 | parser.add_argument('--batch_size', type=int, default=128, help='batch size.') 358 | parser.add_argument('--gamma', type=float, default=0.999, help='reward discount factor.') 359 | parser.add_argument('--learning_rate', type=float, default=1e-4, help='learning rate.') 360 | parser.add_argument('--l2_norm', type=float, default=1e-6, help='l2
regularization.') 361 | parser.add_argument('--hidden', type=int, default=100, help='hidden layer size') 362 | parser.add_argument('--memory_size', type=int, default=50000, help='size of the replay memory') 363 | 364 | parser.add_argument('--data_name', type=str, default=LAST_FM, choices=[LAST_FM, LAST_FM_STAR, YELP, YELP_STAR], 365 | help='One of {LAST_FM, LAST_FM_STAR, YELP, YELP_STAR}.') 366 | parser.add_argument('--entropy_method', type=str, default='weight_entropy', help='entropy_method is one of {entropy, weight_entropy}') 367 | # Although 'weight_entropy' performs better, 'entropy' is a faster alternative. 368 | parser.add_argument('--max_turn', type=int, default=15, help='maximum number of conversation turns') 369 | parser.add_argument('--attr_num', type=int, help='the number of attributes') 370 | parser.add_argument('--mode', type=str, default='train', help='the mode in [train, test]') 371 | parser.add_argument('--ask_num', type=int, default=1, help='the number of features asked in a turn') 372 | parser.add_argument('--load_rl_epoch', type=int, default=0, help='the epoch of the saved RL model to load (0 = train from scratch)') 373 | 374 | parser.add_argument('--sample_times', type=int, default=100, help='the number of episodes sampled per training step') 375 | parser.add_argument('--max_steps', type=int, default=100, help='max training steps') 376 | parser.add_argument('--eval_num', type=int, default=10, help='evaluate the RL model every this many training steps') 377 | parser.add_argument('--save_num', type=int, default=10, help='save the RL model every this many training steps') 378 | parser.add_argument('--observe_num', type=int, default=500, help='the number of steps between metric printouts') 379 | parser.add_argument('--cand_num', type=int, default=10, help='number of candidate features in the action space') 380 | parser.add_argument('--cand_item_num', type=int, default=10, help='number of candidate items in the action space') 381 | parser.add_argument('--fix_emb', action='store_false', help='fix the pretrained embeddings (pass this flag to fine-tune them instead)') 382 | parser.add_argument('--embed', type=str, default='transe', help='pretrained embeddings') 383 | parser.add_argument('--seq', type=str, default='transformer', choices=['rnn', 'transformer', 'mean'], help='sequential learning method') 384 | parser.add_argument('--gcn', action='store_false', help='use the GCN layers (pass this flag to disable them)') 385 | 386 | 387 | args = parser.parse_args() 388 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 389 | args.device = torch.device('cuda') if torch.cuda.is_available() else 'cpu' 390 | print(args.device) 391 | print('data_set:{}'.format(args.data_name)) 392 | kg = load_kg(args.data_name) 393 | #reset attr_num 394 | feature_name = FeatureDict[args.data_name] 395 | feature_length = len(kg.G[feature_name].keys()) 396 | print('dataset:{}, feature_length:{}'.format(args.data_name, feature_length)) 397 | args.attr_num = feature_length # set attr_num = feature_length 398 | print('args.attr_num:', args.attr_num) 399 | print('args.entropy_method:', args.entropy_method) 400 | 401 | dataset = load_dataset(args.data_name) 402 | filename = 'train-data-{}-RL-cand_num-{}-cand_item_num-{}-embed-{}-seq-{}-gcn-{}'.format( 403 | args.data_name, args.cand_num, args.cand_item_num, args.embed, args.seq, args.gcn) 404 | train(args, kg, dataset, filename) 405 | 406 | if __name__ == '__main__': 407 | main() -------------------------------------------------------------------------------- /RL/env_binary_question.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import numpy as np 4 | import os 5 | import random 6
| from utils import * 7 | from torch import nn 8 | 9 | from tkinter import _flatten 10 | from collections import Counter 11 | class BinaryRecommendEnv(object): 12 | def __init__(self, kg, dataset, data_name, embed, seed=1, max_turn=15, cand_num=10, cand_item_num=10, attr_num=20, mode='train', ask_num=1, entropy_way='weight entropy', fm_epoch=0): 13 | self.data_name = data_name 14 | self.mode = mode 15 | self.seed = seed 16 | self.max_turn = max_turn #MAX_TURN 17 | self.attr_state_num = attr_num 18 | self.kg = kg 19 | self.dataset = dataset 20 | self.feature_length = getattr(self.dataset, 'feature').value_len 21 | self.user_length = getattr(self.dataset, 'user').value_len 22 | self.item_length = getattr(self.dataset, 'item').value_len 23 | 24 | # action parameters 25 | self.ask_num = ask_num 26 | self.rec_num = 10 27 | self.random_sample_feature = False 28 | self.random_sample_item = False 29 | if cand_num == 0: 30 | self.cand_num = 10 31 | self.random_sample_feature = True 32 | else: 33 | self.cand_num = cand_num 34 | if cand_item_num == 0: 35 | self.cand_item_num = 10 36 | self.random_sample_item = True 37 | else: 38 | self.cand_item_num = cand_item_num 39 | # entropy or weight entropy 40 | self.ent_way = entropy_way 41 | 42 | # user's profile 43 | self.reachable_feature = [] # user reachable feature 44 | self.user_acc_feature = [] # user accepted feature which asked by agent 45 | self.user_rej_feature = [] # user rejected feature which asked by agent 46 | self.cand_items = [] # candidate items 47 | self.item_feature_pair = {} 48 | self.cand_item_score = [] 49 | 50 | #user_id item_id cur_step cur_node_set 51 | self.user_id = None 52 | self.target_item = None 53 | self.cur_conver_step = 0 # the number of conversation in current step 54 | self.cur_node_set = [] # maybe a node or a node set / normally save feature node 55 | # state veactor 56 | self.user_embed = None 57 | self.conver_his = [] #conversation_history 58 | self.attr_ent = [] # attribute entropy 59 | 60 | self.ui_dict = self.__load_rl_data__(data_name, mode=mode) # np.array [ u i weight] 61 | self.user_weight_dict = dict() 62 | self.user_items_dict = dict() 63 | 64 | #init seed & init user_dict 65 | set_random_seed(self.seed) # set random seed 66 | if mode == 'train': 67 | self.__user_dict_init__() # init self.user_weight_dict and self.user_items_dict 68 | elif mode == 'test': 69 | self.ui_array = None # u-i array [ [userID1, itemID1], ...,[userID2, itemID2]] 70 | self.__test_tuple_generate__() 71 | self.test_num = 0 72 | # embeds = { 73 | # 'ui_emb': ui_emb, 74 | # 'feature_emb': feature_emb 75 | # } 76 | # load fm epoch 77 | embeds = load_embed(data_name, embed, epoch=fm_epoch) 78 | if embeds: 79 | self.ui_embeds =embeds['ui_emb'] 80 | self.feature_emb = embeds['feature_emb'] 81 | else: 82 | self.ui_embeds = nn.Embedding(self.user_length+self.item_length, 64).weight.data.numpy() 83 | self.feature_emb = nn.Embedding(self.feature_length, 64).weight.data.numpy() 84 | # self.feature_length = self.feature_emb.shape[0]-1 85 | 86 | self.action_space = 2 87 | 88 | self.reward_dict = { 89 | 'ask_suc': 0.01, 90 | 'ask_fail': -0.1, 91 | 'rec_suc': 1, 92 | 'rec_fail': -0.1, 93 | 'until_T': -0.3, # MAX_Turn 94 | 'cand_none': -0.1 95 | } 96 | self.history_dict = { 97 | 'ask_suc': 1, 98 | 'ask_fail': -1, 99 | 'rec_scu': 2, 100 | 'rec_fail': -2, 101 | 'until_T': 0 102 | } 103 | self.attr_count_dict = dict() # This dict is used to calculate entropy 104 | 105 | def __load_rl_data__(self, data_name, mode): 106 | if mode == 'train': 107 | with 
open(os.path.join(DATA_DIR[data_name], 'UI_Interaction_data/review_dict_valid.json'), encoding='utf-8') as f: 108 | print('train_data: load RL valid data') 109 | mydict = json.load(f) 110 | elif mode == 'test': 111 | with open(os.path.join(DATA_DIR[data_name], 'UI_Interaction_data/review_dict_test.json'), encoding='utf-8') as f: 112 | print('test_data: load RL test data') 113 | mydict = json.load(f) 114 | return mydict 115 | 116 | 117 | def __user_dict_init__(self): #Calculate the weight of the number of interactions per user 118 | ui_nums = 0 119 | for items in self.ui_dict.values(): 120 | ui_nums += len(items) 121 | for user_str in self.ui_dict.keys(): 122 | user_id = int(user_str) 123 | self.user_weight_dict[user_id] = len(self.ui_dict[user_str])/ui_nums 124 | print('user_dict init successfully!') 125 | 126 | def __test_tuple_generate__(self): 127 | ui_list = [] 128 | for user_str, items in self.ui_dict.items(): 129 | user_id = int(user_str) 130 | for item_id in items: 131 | ui_list.append([user_id, item_id]) 132 | self.ui_array = np.array(ui_list) 133 | np.random.shuffle(self.ui_array) 134 | 135 | def reset(self, embed=None): 136 | if embed is not None: 137 | self.ui_embeds = embed[:self.user_length+self.item_length] 138 | self.feature_emb = embed[self.user_length+self.item_length:] 139 | #init user_id item_id cur_step cur_node_set 140 | self.cur_conver_step = 0 #reset cur_conversation step 141 | self.cur_node_set = [] 142 | if self.mode == 'train': 143 | users = list(self.user_weight_dict.keys()) 144 | # self.user_id = np.random.choice(users, p=list(self.user_weight_dict.values())) # select user according to user weights 145 | self.user_id = np.random.choice(users) 146 | self.target_item = np.random.choice(self.ui_dict[str(self.user_id)]) 147 | elif self.mode == 'test': 148 | self.user_id = self.ui_array[self.test_num, 0] 149 | self.target_item = self.ui_array[self.test_num, 1] 150 | self.test_num += 1 151 | 152 | # init user's profile 153 | print('-----------reset state vector------------') 154 | print('user_id:{}, target_item:{}'.format(self.user_id, self.target_item)) 155 | self.reachable_feature = [] # user reachable feature in cur_step 156 | self.user_acc_feature = [] # user accepted feature which asked by agent 157 | self.user_rej_feature = [] # user rejected feature which asked by agent 158 | self.cand_items = list(range(self.item_length)) 159 | 160 | # init state vector 161 | self.user_embed = self.ui_embeds[self.user_id].tolist() # init user_embed np.array---list 162 | self.conver_his = [0] * self.max_turn # conversation_history 163 | self.attr_ent = [0] * self.attr_state_num # attribute entropy 164 | 165 | # initialize dialog by randomly asked a question from ui interaction 166 | user_like_random_fea = random.choice(self.kg.G['item'][self.target_item]['belong_to']) 167 | self.user_acc_feature.append(user_like_random_fea) #update user acc_fea 168 | self.cur_node_set.append(user_like_random_fea) 169 | self._update_cand_items(user_like_random_fea, acc_rej=True) 170 | self._updata_reachable_feature() # self.reachable_feature = [] 171 | self.conver_his[self.cur_conver_step] = self.history_dict['ask_suc'] 172 | self.cur_conver_step += 1 173 | 174 | print('=== init user prefer feature: {}'.format(self.cur_node_set)) 175 | self._update_feature_entropy() #update entropy 176 | print('reset_reachable_feature num: {}'.format(len(self.reachable_feature))) 177 | 178 | # Sort reachable features according to the entropy of features 179 | reach_fea_score = self._feature_score() 180 | 
max_ind_list = [] 181 | for k in range(self.cand_num): 182 | max_score = max(reach_fea_score) 183 | max_ind = reach_fea_score.index(max_score) 184 | reach_fea_score[max_ind] = 0 185 | if max_ind in max_ind_list: 186 | break 187 | max_ind_list.append(max_ind) 188 | 189 | max_fea_id = [self.reachable_feature[i] for i in max_ind_list] 190 | [self.reachable_feature.remove(v) for v in max_fea_id] 191 | [self.reachable_feature.insert(0, v) for v in max_fea_id[::-1]] 192 | 193 | return self._get_state(), self._get_cand(), self._get_action_space() 194 | 195 | def _get_cand(self): 196 | if self.random_sample_feature: 197 | cand_feature = self._map_to_all_id(random.sample(self.reachable_feature, min(len(self.reachable_feature),self.cand_num)),'feature') 198 | else: 199 | cand_feature = self._map_to_all_id(self.reachable_feature[:self.cand_num],'feature') 200 | if self.random_sample_item: 201 | cand_item = self._map_to_all_id(random.sample(self.cand_items, min(len(self.cand_items),self.cand_item_num)),'item') 202 | else: 203 | cand_item = self._map_to_all_id(self.cand_items[:self.cand_item_num],'item') 204 | cand = cand_feature + cand_item 205 | return cand 206 | 207 | def _get_action_space(self): 208 | action_space = [self._map_to_all_id(self.reachable_feature,'feature'), self._map_to_all_id(self.cand_items,'item')] 209 | return action_space 210 | 211 | def _get_state(self): 212 | if self.data_name in ['YELP_STAR']: 213 | self_cand_items = self.cand_items[:5000] 214 | set_cand_items = set(self_cand_items) 215 | else: 216 | self_cand_items = self.cand_items 217 | user = [self.user_id] 218 | cur_node = [x + self.user_length + self.item_length for x in self.cur_node_set] 219 | cand_items = [x + self.user_length for x in self_cand_items] 220 | reachable_feature = [x + self.user_length + self.item_length for x in self.reachable_feature] 221 | neighbors = cur_node + user + cand_items + reachable_feature 222 | 223 | idx = dict(enumerate(neighbors)) 224 | idx = {v: k for k, v in idx.items()} 225 | 226 | i = [] 227 | v = [] 228 | for item in self_cand_items: 229 | item_idx = item + self.user_length 230 | for fea in self.item_feature_pair[item]: 231 | fea_idx = fea + self.user_length + self.item_length 232 | i.append([idx[item_idx], idx[fea_idx]]) 233 | i.append([idx[fea_idx], idx[item_idx]]) 234 | v.append(1) 235 | v.append(1) 236 | 237 | user_idx = len(cur_node) 238 | cand_item_score = self.sigmoid(self.cand_item_score) 239 | for item, score in zip(self.cand_items, cand_item_score): 240 | if self.data_name in ['YELP_STAR']: 241 | if item not in set_cand_items: 242 | continue 243 | item_idx = item + self.user_length 244 | i.append([user_idx, idx[item_idx]]) 245 | i.append([idx[item_idx], user_idx]) 246 | v.append(score) 247 | v.append(score) 248 | 249 | i = torch.LongTensor(i) 250 | v = torch.FloatTensor(v) 251 | neighbors = torch.LongTensor(neighbors) 252 | adj = torch.sparse.FloatTensor(i.t(), v, torch.Size([len(neighbors),len(neighbors)])) 253 | 254 | state = {'cur_node': cur_node, 255 | 'neighbors': neighbors, 256 | 'adj': adj} 257 | return state 258 | 259 | def step(self, action, sorted_actions, embed=None): 260 | if embed is not None: 261 | self.ui_embeds = embed[:self.user_length+self.item_length] 262 | self.feature_emb = embed[self.user_length+self.item_length:] 263 | 264 | done = 0 265 | print('---------------step:{}-------------'.format(self.cur_conver_step)) 266 | 267 | if self.cur_conver_step == self.max_turn: 268 | reward = self.reward_dict['until_T'] 269 | 
self.conver_his[self.cur_conver_step-1] = self.history_dict['until_T'] 270 | print('--> Maximum number of turns reached !') 271 | done = 1 272 | elif action >= self.user_length + self.item_length: #ask feature 273 | asked_feature = self._map_to_old_id(action) 274 | print('-->action: ask features {}, max entropy feature {}'.format(asked_feature, self.reachable_feature[:self.cand_num])) 275 | reward, done, acc_rej = self._ask_update(asked_feature) #update user's profile: user_acc_feature & user_rej_feature 276 | self._update_cand_items(asked_feature, acc_rej) #update cand_items 277 | else: #recommend items 278 | 279 | #===================== rec update========= 280 | recom_items = [] 281 | for act in sorted_actions: 282 | if act < self.user_length + self.item_length: 283 | recom_items.append(self._map_to_old_id(act)) 284 | if len(recom_items) == self.rec_num: 285 | break 286 | reward, done = self._recommend_update(recom_items) 287 | #======================================== 288 | if reward > 0: 289 | print('-->Recommend successfully!') 290 | else: 291 | print('-->Recommend fail !') 292 | 293 | self._updata_reachable_feature() # update user's profile: reachable_feature 294 | 295 | print('reachable_feature num: {}'.format(len(self.reachable_feature))) 296 | print('cand_item num: {}'.format(len(self.cand_items))) 297 | 298 | self._update_feature_entropy() 299 | if len(self.reachable_feature) != 0: # if reachable_feature == 0 :cand_item= 1 300 | reach_fea_score = self._feature_score() # compute feature score 301 | 302 | max_ind_list = [] 303 | for k in range(self.cand_num): 304 | max_score = max(reach_fea_score) 305 | max_ind = reach_fea_score.index(max_score) 306 | reach_fea_score[max_ind] = 0 307 | if max_ind in max_ind_list: 308 | break 309 | max_ind_list.append(max_ind) 310 | max_fea_id = [self.reachable_feature[i] for i in max_ind_list] 311 | [self.reachable_feature.remove(v) for v in max_fea_id] 312 | [self.reachable_feature.insert(0, v) for v in max_fea_id[::-1]] 313 | 314 | self.cur_conver_step += 1 315 | return self._get_state(), self._get_cand(), self._get_action_space(), reward, done 316 | 317 | 318 | def _updata_reachable_feature(self): 319 | next_reachable_feature = [] 320 | reachable_item_feature_pair = {} 321 | for cand in self.cand_items: 322 | fea_belong_items = list(self.kg.G['item'][cand]['belong_to']) # A-I 323 | next_reachable_feature.extend(fea_belong_items) 324 | reachable_item_feature_pair[cand] = list(set(fea_belong_items) - set(self.user_rej_feature)) 325 | next_reachable_feature = list(set(next_reachable_feature)) 326 | self.reachable_feature = list(set(next_reachable_feature) - set(self.user_acc_feature) - set(self.user_rej_feature)) 327 | self.item_feature_pair = reachable_item_feature_pair 328 | 329 | def _feature_score(self): 330 | reach_fea_score = [] 331 | for feature_id in self.reachable_feature: 332 | ''' 333 | score = self.attr_ent[feature_id] 334 | reach_fea_score.append(score) 335 | ''' 336 | feature_embed = self.feature_emb[feature_id] 337 | score = 0 338 | score += np.inner(np.array(self.user_embed), feature_embed) 339 | prefer_embed = self.feature_emb[self.user_acc_feature, :] #np.array (x*64) 340 | for i in range(len(self.user_acc_feature)): 341 | score += np.inner(prefer_embed[i], feature_embed) 342 | if feature_id in self.user_rej_feature: 343 | score -= self.sigmoid([feature_embed, feature_embed])[0] 344 | reach_fea_score.append(score) 345 | 346 | return reach_fea_score 347 | 348 | 349 | def _item_score(self): 350 | cand_item_score = [] 351 | for 
item_id in self.cand_items: 352 | item_embed = self.ui_embeds[self.user_length + item_id] 353 | score = 0 354 | score += np.inner(np.array(self.user_embed), item_embed) 355 | prefer_embed = self.feature_emb[self.user_acc_feature, :] #np.array (x*64) 356 | unprefer_feature = list(set(self.user_rej_feature) & set(self.kg.G['item'][item_id]['belong_to'])) 357 | unprefer_embed = self.feature_emb[unprefer_feature, :] #np.array (x*64) 358 | for i in range(len(self.user_acc_feature)): 359 | score += np.inner(prefer_embed[i], item_embed) 360 | for i in range(len(unprefer_feature)): 361 | score -= self.sigmoid([np.inner(unprefer_embed[i], item_embed)])[0] 362 | #score -= np.inner(unprefer_embed[i], item_embed) 363 | cand_item_score.append(score) 364 | return cand_item_score 365 | 366 | 367 | def _ask_update(self, asked_feature): 368 | ''' 369 | :return: reward, acc_feature, rej_feature 370 | ''' 371 | done = 0 372 | # TODO datafram! groundTruth == target_item features 373 | feature_groundtrue = self.kg.G['item'][self.target_item]['belong_to'] 374 | 375 | if asked_feature in feature_groundtrue: 376 | acc_rej = True 377 | self.user_acc_feature.append(asked_feature) 378 | self.cur_node_set.append(asked_feature) 379 | reward = self.reward_dict['ask_suc'] 380 | self.conver_his[self.cur_conver_step] = self.history_dict['ask_suc'] #update conver_his 381 | else: 382 | acc_rej = False 383 | self.user_rej_feature.append(asked_feature) 384 | reward = self.reward_dict['ask_fail'] 385 | self.conver_his[self.cur_conver_step] = self.history_dict['ask_fail'] #update conver_his 386 | 387 | if self.cand_items == []: #candidate items is empty 388 | done = 1 389 | reward = self.reward_dict['cand_none'] 390 | 391 | return reward, done, acc_rej 392 | 393 | def _update_cand_items(self, asked_feature, acc_rej): 394 | if acc_rej: # accept feature 395 | print('=== ask acc: update cand_items') 396 | feature_items = self.kg.G['feature'][asked_feature]['belong_to'] 397 | self.cand_items = set(self.cand_items) & set(feature_items) # itersection 398 | self.cand_items = list(self.cand_items) 399 | 400 | else: # reject feature 401 | print('=== ask rej: update cand_items') 402 | 403 | #select topk candidate items to recommend 404 | cand_item_score = self._item_score() 405 | item_score_tuple = list(zip(self.cand_items, cand_item_score)) 406 | sort_tuple = sorted(item_score_tuple, key=lambda x: x[1], reverse=True) 407 | self.cand_items, self.cand_item_score = zip(*sort_tuple) 408 | 409 | def _recommend_update(self, recom_items): 410 | print('-->action: recommend items') 411 | print(set(recom_items) - set(self.cand_items[: self.rec_num])) 412 | self.cand_items = list(self.cand_items) 413 | self.cand_item_score = list(self.cand_item_score) 414 | #recom_items = self.cand_items[: self.rec_num] # TOP k item to recommend 415 | if self.target_item in recom_items: 416 | reward = self.reward_dict['rec_suc'] 417 | self.conver_his[self.cur_conver_step] = self.history_dict['rec_scu'] #update state vector: conver_his 418 | tmp_score = [] 419 | for item in recom_items: 420 | idx = self.cand_items.index(item) 421 | tmp_score.append(self.cand_item_score[idx]) 422 | self.cand_items = recom_items 423 | self.cand_item_score = tmp_score 424 | done = recom_items.index(self.target_item) + 1 425 | else: 426 | reward = self.reward_dict['rec_fail'] 427 | self.conver_his[self.cur_conver_step] = self.history_dict['rec_fail'] #update state vector: conver_his 428 | if len(self.cand_items) > self.rec_num: 429 | for item in recom_items: 430 | del 
self.item_feature_pair[item] 431 | idx = self.cand_items.index(item) 432 | self.cand_items.pop(idx) 433 | self.cand_item_score.pop(idx) 434 | #self.cand_items = self.cand_items[self.rec_num:] #update candidate items 435 | done = 0 436 | return reward, done 437 | 438 | def _update_feature_entropy(self): 439 | if self.ent_way == 'entropy': 440 | cand_items_fea_list = [] 441 | for item_id in self.cand_items: 442 | cand_items_fea_list.append(list(self.kg.G['item'][item_id]['belong_to'])) 443 | cand_items_fea_list = list(_flatten(cand_items_fea_list)) 444 | self.attr_count_dict = dict(Counter(cand_items_fea_list)) 445 | self.attr_ent = [0] * self.attr_state_num # reset attr_ent 446 | real_ask_able = list(set(self.reachable_feature) & set(self.attr_count_dict.keys())) 447 | for fea_id in real_ask_able: 448 | p1 = float(self.attr_count_dict[fea_id]) / len(self.cand_items) 449 | p2 = 1.0 - p1 450 | if p1 == 1: 451 | self.attr_ent[fea_id] = 0 452 | else: 453 | ent = (- p1 * np.log2(p1) - p2 * np.log2(p2)) 454 | self.attr_ent[fea_id] = ent 455 | elif self.ent_way == 'weight_entropy': 456 | cand_items_fea_list = [] 457 | self.attr_count_dict = {} 458 | #cand_item_score = self._item_score() 459 | cand_item_score_sig = self.sigmoid(self.cand_item_score) # sigmoid(score) 460 | for score_ind, item_id in enumerate(self.cand_items): 461 | cand_items_fea_list = list(self.kg.G['item'][item_id]['belong_to']) 462 | for fea_id in cand_items_fea_list: 463 | if self.attr_count_dict.get(fea_id) == None: 464 | self.attr_count_dict[fea_id] = 0 465 | self.attr_count_dict[fea_id] += cand_item_score_sig[score_ind] 466 | 467 | self.attr_ent = [0] * self.attr_state_num # reset attr_ent 468 | real_ask_able = list(set(self.reachable_feature) & set(self.attr_count_dict.keys())) 469 | sum_score_sig = sum(cand_item_score_sig) 470 | 471 | for fea_id in real_ask_able: 472 | p1 = float(self.attr_count_dict[fea_id]) / sum_score_sig 473 | p2 = 1.0 - p1 474 | if p1 == 1 or p1 <= 0: 475 | self.attr_ent[fea_id] = 0 476 | else: 477 | ent = (- p1 * np.log2(p1) - p2 * np.log2(p2)) 478 | self.attr_ent[fea_id] = ent 479 | 480 | def sigmoid(self, x_list): 481 | x_np = np.array(x_list) 482 | s = 1 / (1 + np.exp(-x_np)) 483 | return s.tolist() 484 | 485 | def _map_to_all_id(self, x_list, old_type): 486 | if old_type == 'item': 487 | return [x + self.user_length for x in x_list] 488 | elif old_type == 'feature': 489 | return [x + self.user_length + self.item_length for x in x_list] 490 | else: 491 | return x_list 492 | 493 | def _map_to_old_id(self, x): 494 | if x >= self.user_length + self.item_length: 495 | x -= (self.user_length + self.item_length) 496 | elif x >= self.user_length: 497 | x -= self.user_length 498 | return x 499 | 500 | -------------------------------------------------------------------------------- /RL/env_enumerated_question.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import numpy as np 4 | import itertools 5 | import os 6 | import random 7 | from utils import * 8 | from torch import nn 9 | 10 | from tkinter import _flatten 11 | from collections import Counter 12 | class EnumeratedRecommendEnv(object): 13 | def __init__(self, kg, dataset, data_name, embed, seed=1, max_turn=15, cand_num=10, cand_item_num=10, attr_num=20, mode='train', ask_num=1, entropy_way='weight entropy', fm_epoch=0): 14 | self.data_name = data_name 15 | self.mode = mode 16 | self.seed = seed 17 | self.max_turn = max_turn #MAX_TURN 18 | self.attr_state_num = attr_num 19 | self.kg = kg 20 | 
self.dataset = dataset 21 | self.feature_length = getattr(self.dataset, 'large_feature').value_len 22 | self.user_length = getattr(self.dataset, 'user').value_len 23 | self.item_length = getattr(self.dataset, 'item').value_len 24 | self.small_feature_length = getattr(self.dataset, 'feature').value_len 25 | 26 | # action parameters 27 | self.ask_num = ask_num 28 | self.rec_num = 10 29 | self.random_sample_feature = False 30 | self.random_sample_item = False 31 | if cand_num == 0: 32 | self.cand_num = 10 33 | self.random_sample_feature = True 34 | else: 35 | self.cand_num = cand_num 36 | if cand_item_num == 0: 37 | self.cand_item_num = 10 38 | self.random_sample_item = True 39 | else: 40 | self.cand_item_num = cand_item_num 41 | # entropy or weight entropy 42 | self.ent_way = entropy_way 43 | 44 | # user's profile 45 | self.reachable_feature = [] # user reachable large_feature 46 | self.reachable_small_feature = [] 47 | self.user_acc_feature = [] # user accepted large_feature which asked by agent 48 | self.user_rej_feature = [] # user rejected large_feature which asked by agent 49 | self.acc_small_fea = [] 50 | self.rej_small_fea = [] 51 | self.cand_items = [] # candidate items 52 | self.item_feature_pair = {} 53 | self.cand_item_score = [] 54 | 55 | self.invalid_small_feature = [] 56 | self.valid_small_feature = [] 57 | for fea in self.kg.G['large_feature']: 58 | self.valid_small_feature.extend(self.kg.G['large_feature'][fea]['link_to_feature']) 59 | self.invalid_small_feature = set(self.kg.G['feature'].keys()) - set(self.valid_small_feature) 60 | print(self.invalid_small_feature) 61 | 62 | 63 | #user_id item_id cur_step cur_node_set 64 | self.user_id = None 65 | self.target_item = None 66 | self.cur_conver_step = 0 # the number of conversation in current step 67 | self.cur_node_set = [] #maybe a node or a node set / normally save large_feature node 68 | # state veactor 69 | self.user_embed = None 70 | self.conver_his = [] #conversation_history 71 | self.attr_ent = [] # attribute entropy 72 | 73 | self.ui_dict = self.__load_rl_data__(data_name, mode=mode) # np.array [ u i weight] 74 | self.user_weight_dict = dict() 75 | self.user_items_dict = dict() 76 | 77 | #init seed & init user_dict 78 | set_random_seed(self.seed) # set random seed 79 | if mode == 'train': 80 | self.__user_dict_init__() # init self.user_weight_dict and self.user_items_dict 81 | elif mode == 'test': 82 | self.ui_array = None # u-i array [ [userID1, itemID1], ..., [userID2, itemID2]] 83 | self.__test_tuple_generate__() 84 | self.test_num = 0 85 | # embeds = { 86 | # 'ui_emb': ui_emb, 87 | # 'feature_emb': feature_emb 88 | # } 89 | #load fm epoch 90 | embeds = load_embed(data_name, embed, epoch=fm_epoch) 91 | if embeds: 92 | self.ui_embeds =embeds['ui_emb'] 93 | self.feature_emb = embeds['feature_emb'] 94 | else: 95 | self.ui_embeds = nn.Embedding(self.user_length+self.item_length, 64).weight.data.numpy() 96 | self.feature_emb = nn.Embedding(self.small_feature_length, 64).weight.data.numpy() 97 | # self.feature_length = self.feature_emb.shape[0]-1 98 | 99 | self.action_space = 2 100 | 101 | self.reward_dict = { 102 | 'ask_suc': 0.01, 103 | 'ask_fail': -0.1, 104 | 'rec_suc': 1, 105 | 'rec_fail': -0.1, 106 | 'until_T': -0.3, # MAX_Turn 107 | 'cand_none': -0.1 108 | } 109 | self.history_dict = { 110 | 'ask_suc': 1, 111 | 'ask_fail': -1, 112 | 'rec_scu': 2, 113 | 'rec_fail': -2, 114 | 'until_T': 0 115 | } 116 | self.attr_count_dict = dict() # This dict is used to calculate entropy 117 | 118 | def __load_rl_data__(self, 
data_name, mode): 119 | if mode == 'train': 120 | with open(os.path.join(DATA_DIR[data_name], 'UI_Interaction_data/review_dict_valid.json'), encoding='utf-8') as f: 121 | print('train_data: load RL valid data') 122 | mydict = json.load(f) 123 | elif mode == 'test': 124 | with open(os.path.join(DATA_DIR[data_name], 'UI_Interaction_data/review_dict_test.json'), encoding='utf-8') as f: 125 | mydict = json.load(f) 126 | return mydict 127 | 128 | 129 | def __user_dict_init__(self): #Calculate the weight of the number of interactions per user 130 | ui_nums = 0 131 | for items in self.ui_dict.values(): 132 | ui_nums += len(items) 133 | for user_str in self.ui_dict.keys(): 134 | user_id = int(user_str) 135 | self.user_weight_dict[user_id] = len(self.ui_dict[user_str])/ui_nums 136 | print('user_dict init successfully!') 137 | 138 | def __test_tuple_generate__(self): 139 | ui_list = [] 140 | for user_str, items in self.ui_dict.items(): 141 | user_id = int(user_str) 142 | for item_id in items: 143 | ui_list.append([user_id, item_id]) 144 | self.ui_array = np.array(ui_list) 145 | np.random.shuffle(self.ui_array) 146 | 147 | 148 | 149 | def reset(self, embed=None): 150 | if embed is not None: 151 | self.ui_embeds = embed[:self.user_length+self.item_length] 152 | self.feature_emb = embed[self.user_length+self.item_length:] 153 | #init user_id item_id 154 | self.cur_conver_step = 0 #reset cur_conversation step 155 | self.cur_node_set = [] 156 | if self.mode == 'train': 157 | users = list(self.user_weight_dict.keys()) 158 | #TODO select user by weight? 159 | #self.user_id = np.random.choice(users, p=list(self.user_weight_dict.values())) # select user according to user weights 160 | self.user_id = np.random.choice(users) 161 | self.target_item = np.random.choice(self.ui_dict[str(self.user_id)]) 162 | elif self.mode == 'test': 163 | self.user_id = self.ui_array[self.test_num, 0] 164 | self.target_item = self.ui_array[self.test_num, 1] 165 | self.test_num += 1 166 | # init user's profile 167 | print('-----------reset state vector------------') 168 | self.user_acc_feature = [] # user accepted large_feature which asked by agent 169 | self.user_rej_feature = [] # user rejected large_feature which asked by agent 170 | self.acc_small_fea = [] 171 | self.rej_small_fea = [] 172 | self.cand_items = list(range(self.item_length)) 173 | print('user_id:{}, target_item:{}'.format(self.user_id, self.target_item)) 174 | 175 | # init state vector 176 | self.user_embed = self.ui_embeds[self.user_id].tolist() # init user_embed np.array---list 177 | self.conver_his = [0] * self.max_turn # conversation_history 178 | self.attr_ent = [0] * self.attr_state_num # attribute entropy 179 | 180 | # initialize dialog by randomly asked a question from ui interaction 181 | user_like_random_fea = random.choice(self.kg.G['item'][self.target_item]['belong_to_large']) 182 | self.user_acc_feature.append(user_like_random_fea) #update user acc_fea 183 | self.cur_node_set.append(user_like_random_fea) 184 | self._update_cand_items(user_like_random_fea, acc_rej=True) 185 | self._updata_reachable_feature() # self.reachable_feature = [] 186 | self.conver_his[self.cur_conver_step] = self.history_dict['ask_suc'] 187 | self.cur_conver_step += 1 188 | 189 | print('=== init user prefer large_feature: {}'.format(self.cur_node_set)) 190 | #self._update_cand_items(acc_feature=self.cur_node_set, rej_feature=[]) 191 | self._update_feature_entropy() # update entropy 192 | print('reset_reachable_feature num: {}'.format(len(self.reachable_feature))) 193 | 194 | 
195 | 196 | #Sort reachable features according to the entropy of features 197 | reach_fea_score = self._feature_score() 198 | max_ind_list = [] 199 | for k in range(self.cand_num): 200 | max_score = max(reach_fea_score) 201 | max_ind = reach_fea_score.index(max_score) 202 | if max_ind in max_ind_list: 203 | break 204 | reach_fea_score[max_ind] = 0 205 | max_ind_list.append(max_ind) 206 | max_fea_id = [self.reachable_feature[i] for i in max_ind_list] 207 | [self.reachable_feature.remove(v) for v in max_fea_id] 208 | [self.reachable_feature.insert(0, v) for v in max_fea_id[::-1]] 209 | 210 | return self._get_state(), self._get_cand(), self._get_action_space() 211 | 212 | def _get_cand(self): 213 | if self.random_sample_feature: 214 | cand_feature = self._map_to_all_id(random.sample(self.reachable_feature, min(len(self.reachable_feature),self.cand_num)),'feature') 215 | else: 216 | cand_feature = self._map_to_all_id(self._cand_small_feature(self.reachable_feature[:self.cand_num]),'feature') 217 | if self.random_sample_item: 218 | cand_item = self._map_to_all_id(random.sample(self.cand_items, min(len(self.cand_items),self.cand_item_num)),'item') 219 | else: 220 | cand_item = self._map_to_all_id(self.cand_items[:self.cand_item_num],'item') 221 | cand = cand_feature + cand_item 222 | return cand 223 | 224 | def _get_action_space(self): 225 | action_space = [self._map_to_all_id(self.reachable_small_feature,'feature'), self._map_to_all_id(self.cand_items,'item')] 226 | return action_space 227 | 228 | def _get_state(self): 229 | user = [self.user_id] 230 | cur_node = [x + self.user_length + self.item_length for x in self.acc_small_fea] 231 | cand_items = [x + self.user_length for x in self.cand_items] 232 | reachable_feature = [x + self.user_length + self.item_length for x in self.reachable_small_feature] 233 | neighbors = cur_node + user + cand_items + reachable_feature 234 | 235 | idx = dict(enumerate(neighbors)) 236 | idx = {v: k for k, v in idx.items()} 237 | 238 | i = [] 239 | v = [] 240 | for item in self.item_feature_pair: 241 | item_idx = item + self.user_length 242 | for fea in self.item_feature_pair[item]: 243 | fea_idx = fea + self.user_length + self.item_length 244 | i.append([idx[item_idx], idx[fea_idx]]) 245 | i.append([idx[fea_idx], idx[item_idx]]) 246 | v.append(1) 247 | v.append(1) 248 | 249 | user_idx = len(cur_node) 250 | cand_item_score = self.sigmoid(self.cand_item_score) 251 | for item, score in zip(self.cand_items, cand_item_score): 252 | item_idx = item + self.user_length 253 | i.append([user_idx, idx[item_idx]]) 254 | i.append([idx[item_idx], user_idx]) 255 | v.append(score) 256 | v.append(score) 257 | 258 | i = torch.LongTensor(i) 259 | v = torch.FloatTensor(v) 260 | neighbors = torch.LongTensor(neighbors) 261 | adj = torch.sparse.FloatTensor(i.t(), v, torch.Size([len(neighbors),len(neighbors)])) 262 | 263 | state = {'cur_node': cur_node, 264 | 'neighbors': neighbors, 265 | 'adj': adj} 266 | return state 267 | 268 | def step(self, action, sorted_actions, embed=None): 269 | if embed is not None: 270 | self.ui_embeds = embed[:self.user_length+self.item_length] 271 | self.feature_emb = embed[self.user_length+self.item_length:] 272 | 273 | done = 0 274 | print('---------------step:{}-------------'.format(self.cur_conver_step)) 275 | 276 | if self.cur_conver_step == self.max_turn: 277 | reward = self.reward_dict['until_T'] 278 | self.conver_his[self.cur_conver_step-1] = self.history_dict['until_T'] 279 | print('--> Maximum number of turns reached !') 280 | done = 1 281 | elif 
action >= self.user_length + self.item_length: #ask large_feature 282 | print(self._map_to_old_id(action)) 283 | asked_feature = self.kg.G['feature'][self._map_to_old_id(action)]['link_to_feature'][0] 284 | print('-->action: ask features {}, max entropy feature {}'.format(asked_feature, self.reachable_feature[:self.cand_num])) 285 | reward, done, acc_rej = self._ask_update(asked_feature) #update user's profile: user_acc_feature & user_rej_feature 286 | self._update_cand_items(asked_feature, acc_rej) #update cand_items 287 | else: #recommend items 288 | 289 | #===================== rec update========= 290 | recom_items = [] 291 | for act in sorted_actions: 292 | if act < self.user_length + self.item_length: 293 | recom_items.append(self._map_to_old_id(act)) 294 | if len(recom_items) == self.rec_num: 295 | break 296 | reward, done = self._recommend_update(recom_items) 297 | #======================================== 298 | if reward > 0: 299 | print('-->Recommend successfully!') 300 | else: 301 | print('-->Recommend fail !') 302 | 303 | self._updata_reachable_feature() # update user's profile: reachable_feature 304 | 305 | print('reachable_feature num: {}'.format(len(self.reachable_feature))) 306 | print('cand_item num: {}'.format(len(self.cand_items))) 307 | 308 | self._update_feature_entropy() 309 | if len(self.reachable_feature) != 0: # if reachable_feature == 0 :cand_item= 1 310 | reach_fea_score = self._feature_score() # compute feature score 311 | 312 | max_ind_list = [] 313 | for k in range(self.cand_num): 314 | max_score = max(reach_fea_score) 315 | max_ind = reach_fea_score.index(max_score) 316 | if max_ind in max_ind_list: 317 | break 318 | reach_fea_score[max_ind] = 0 319 | max_ind_list.append(max_ind) 320 | max_fea_id = [self.reachable_feature[i] for i in max_ind_list] 321 | [self.reachable_feature.remove(v) for v in max_fea_id] 322 | [self.reachable_feature.insert(0, v) for v in max_fea_id[::-1]] 323 | 324 | self.cur_conver_step += 1 325 | return self._get_state(), self._get_cand(), self._get_action_space(), reward, done 326 | 327 | def _updata_reachable_feature(self, start='large_feature'): 328 | next_reachable_feature = [] 329 | next_reachable_small_feature = [] 330 | reachable_item_small_feature_pair = {} 331 | 332 | filter_small_feature = [] 333 | for fea in self.user_acc_feature+self.user_rej_feature: 334 | filter_small_feature.extend(self.kg.G['large_feature'][fea]['link_to_feature']) 335 | filter_small_feature = set(filter_small_feature) 336 | 337 | for cand in self.cand_items: 338 | fea_belong_items = list(self.kg.G['item'][cand]['belong_to_large']) # A-I 339 | small_fea_belong_items = list(self.kg.G['item'][cand]['belong_to']) 340 | next_reachable_feature.extend(fea_belong_items) 341 | next_reachable_small_feature.extend(small_fea_belong_items) 342 | reachable_item_small_feature_pair[cand] = list(set(small_fea_belong_items) - filter_small_feature - self.invalid_small_feature) + self.acc_small_fea 343 | next_reachable_feature = list(set(next_reachable_feature)) 344 | next_reachable_small_feature = list(set(next_reachable_small_feature)) 345 | self.reachable_feature = list(set(next_reachable_feature) - set(self.user_acc_feature)) 346 | self.reachable_small_feature = list(set(next_reachable_small_feature) - filter_small_feature - self.invalid_small_feature) 347 | self.item_feature_pair = reachable_item_small_feature_pair 348 | 349 | def _feature_score(self): 350 | reach_fea_score = [] 351 | for feature_id in self.reachable_feature: 352 | score = self.attr_ent[feature_id] 353 
| reach_fea_score.append(score) 354 | return reach_fea_score 355 | 356 | def _item_score(self): 357 | cand_item_score = [] 358 | for item_id in self.cand_items: 359 | item_embed = self.ui_embeds[self.user_length + item_id] 360 | score = 0 361 | score += np.inner(np.array(self.user_embed), item_embed) 362 | prefer_embed = self.feature_emb[self.acc_small_fea, :] #np.array (x*64), small_feature 363 | unprefer_feature = list(set(self.rej_small_fea) & set(self.kg.G['item'][item_id]['belong_to'])) 364 | unprefer_embed = self.feature_emb[unprefer_feature, :] #np.array (x*64) 365 | for i in range(len(self.acc_small_fea)): 366 | score += np.inner(prefer_embed[i], item_embed) 367 | for i in range(len(unprefer_feature)): 368 | score -= self.sigmoid([np.inner(unprefer_embed[i], item_embed)])[0] 369 | cand_item_score.append(score) 370 | return cand_item_score 371 | 372 | 373 | def _ask_update(self, asked_feature): 374 | ''' 375 | :return: reward, acc_feature, rej_feature 376 | ''' 377 | done = 0 378 | # TODO datafram! groundTruth == target_item features 379 | feature_groundtrue = self.kg.G['item'][self.target_item]['belong_to_large'] 380 | print(self.target_item, asked_feature) 381 | 382 | if asked_feature in feature_groundtrue: 383 | acc_rej = True 384 | self.user_acc_feature.append(asked_feature) 385 | self.cur_node_set.append(asked_feature) 386 | reward = self.reward_dict['ask_suc'] 387 | self.conver_his[self.cur_conver_step] = self.history_dict['ask_suc'] #update conver_his 388 | else: 389 | acc_rej = False 390 | self.user_rej_feature.append(asked_feature) 391 | reward = self.reward_dict['ask_fail'] 392 | self.conver_his[self.cur_conver_step] = self.history_dict['ask_fail'] #update conver_his 393 | 394 | if self.cand_items == []: #candidate items is empty 395 | done = 1 396 | reward = self.reward_dict['cand_none'] 397 | 398 | return reward, done, acc_rej 399 | 400 | def _update_cand_items(self, asked_feature, acc_rej): 401 | small_feature_groundtrue = self.kg.G['item'][self.target_item]['belong_to'] # TODO small_ground truth 402 | print(small_feature_groundtrue) 403 | assert self.target_item in self.cand_items 404 | if acc_rej: # accept feature 405 | print('=== ask acc: update cand_items') 406 | feature_small_ids = self.kg.G['large_feature'][asked_feature]['link_to_feature'] 407 | print(set(feature_small_ids) & set(small_feature_groundtrue)) 408 | for small_id in feature_small_ids: 409 | if small_id in small_feature_groundtrue: # user_accept small_tag 410 | print(small_id) 411 | self.acc_small_fea.append(small_id) 412 | feature_items = self.kg.G['feature'][small_id]['belong_to'] 413 | self.cand_items = set(self.cand_items) & set(feature_items) # itersection 414 | print(self.cand_items) 415 | else: #uesr reject small_tag 416 | self.rej_small_fea.append(small_id) #reject no update 417 | self.cand_items = list(self.cand_items) 418 | else: # reject feature 419 | print('=== ask rej: update cand_items') 420 | feature_small_ids = self.kg.G['large_feature'][asked_feature]['link_to_feature'] 421 | for small_id in feature_small_ids: 422 | self.rej_small_fea.append(small_id) #reject no update 423 | 424 | #select topk candidate items to recommend 425 | cand_item_score = self._item_score() 426 | item_score_tuple = list(zip(self.cand_items, cand_item_score)) 427 | sort_tuple = sorted(item_score_tuple, key=lambda x: x[1], reverse=True) 428 | self.cand_items, self.cand_item_score = zip(*sort_tuple) 429 | 430 | 431 | def _recommend_update(self, recom_items): 432 | print('-->action: recommend items') 433 | 
print(set(recom_items) - set(self.cand_items[: self.rec_num])) 434 | self.cand_items = list(self.cand_items) 435 | self.cand_item_score = list(self.cand_item_score) 436 | #recom_items = self.cand_items[: self.rec_num] # TOP k item to recommend 437 | if self.target_item in recom_items: 438 | reward = self.reward_dict['rec_suc'] 439 | self.conver_his[self.cur_conver_step] = self.history_dict['rec_scu'] #update state vector: conver_his 440 | tmp_score = [] 441 | for item in recom_items: 442 | idx = self.cand_items.index(item) 443 | tmp_score.append(self.cand_item_score[idx]) 444 | self.cand_items = recom_items 445 | self.cand_item_score = tmp_score 446 | done = recom_items.index(self.target_item) + 1 447 | else: 448 | reward = self.reward_dict['rec_fail'] 449 | self.conver_his[self.cur_conver_step] = self.history_dict['rec_fail'] #update state vector: conver_his 450 | if len(self.cand_items) > self.rec_num: 451 | for item in recom_items: 452 | del self.item_feature_pair[item] 453 | idx = self.cand_items.index(item) 454 | self.cand_items.pop(idx) 455 | self.cand_item_score.pop(idx) 456 | #self.cand_items = self.cand_items[self.rec_num:] #update candidate items 457 | done = 0 458 | return reward, done 459 | 460 | def _update_feature_entropy(self): 461 | if self.ent_way == 'entropy': 462 | cand_items_fea_list = [] 463 | #TODO Dataframe 464 | for item_id in self.cand_items: 465 | cand_items_fea_list.append(list(self.kg.G['item'][item_id]['belong_to'])) 466 | cand_items_fea_list = list(_flatten(cand_items_fea_list)) 467 | self.attr_count_dict = dict(Counter(cand_items_fea_list)) 468 | 469 | self.attr_ent = [0] * self.attr_state_num # reset attr_ent 470 | 471 | 472 | real_ask_able_large_fea = self.reachable_feature 473 | for large_fea_id in real_ask_able_large_fea: 474 | large_ent = 0 475 | small_feature = list(self.kg.G['large_feature'][large_fea_id]['link_to_feature']) 476 | small_feature_in_cand = list(set(small_feature) & set(self.attr_count_dict.keys())) 477 | 478 | for fea_id in small_feature_in_cand: 479 | p1 = float(self.attr_count_dict[fea_id]) / len(self.cand_items) 480 | p2 = 1.0 - p1 481 | if p1 == 1: 482 | large_ent += 0 483 | else: 484 | ent = (- p1 * np.log2(p1) - p2 * np.log2(p2)) 485 | large_ent += ent 486 | self.attr_ent[large_fea_id] =large_ent 487 | elif self.ent_way == 'weight_entropy': 488 | cand_items_fea_list = [] 489 | self.attr_count_dict = {} 490 | #cand_item_score = self._item_score() 491 | cand_item_score_sig = self.sigmoid(self.cand_item_score) # sigmoid(score) 492 | 493 | for score_ind, item_id in enumerate(self.cand_items): 494 | cand_items_fea_list = list(self.kg.G['item'][item_id]['belong_to']) 495 | for fea_id in cand_items_fea_list: 496 | if self.attr_count_dict.get(fea_id) == None: 497 | self.attr_count_dict[fea_id] = 0 498 | self.attr_count_dict[fea_id] += cand_item_score_sig[score_ind] 499 | 500 | self.attr_ent = [0] * self.attr_state_num # reset attr_ent 501 | real_ask_able_large_fea = self.reachable_feature 502 | sum_score_sig = sum(cand_item_score_sig) 503 | for large_fea_id in real_ask_able_large_fea: 504 | large_ent = 0 505 | small_feature = list(self.kg.G['large_feature'][large_fea_id]['link_to_feature']) 506 | small_feature_in_cand = list(set(small_feature) & set(self.attr_count_dict.keys())) 507 | 508 | for fea_id in small_feature_in_cand: 509 | p1 = float(self.attr_count_dict[fea_id]) / sum_score_sig 510 | p2 = 1.0 - p1 511 | if p1 == 1 or p1 <= 0: 512 | large_ent += 0 513 | else: 514 | ent = (- p1 * np.log2(p1) - p2 * np.log2(p2)) 515 | large_ent += 
ent 516 | self.attr_ent[large_fea_id] = large_ent 517 | def sigmoid(self, x_list): 518 | x_np = np.array(x_list) 519 | s = 1 / (1 + np.exp(-x_np)) 520 | return s.tolist() 521 | 522 | def _map_to_all_id(self, x_list, old_type): 523 | if old_type == 'item': 524 | return [x + self.user_length for x in x_list] 525 | elif old_type == 'feature': 526 | return [x + self.user_length + self.item_length for x in x_list] 527 | else: 528 | return x_list 529 | 530 | def _map_to_old_id(self, x): 531 | if x >= self.user_length + self.item_length: 532 | x -= (self.user_length + self.item_length) 533 | elif x >= self.user_length: 534 | x -= self.user_length 535 | return x 536 | 537 | def _cand_small_feature(self, cand_features): 538 | cand_small_features = [] 539 | for fea in cand_features: 540 | cand_small_features.extend(self.kg.G['large_feature'][fea]['link_to_feature']) 541 | print(list(set(cand_small_features) & set(self.reachable_small_feature))) 542 | return list(set(cand_small_features) & set(self.reachable_small_feature)) 543 | --------------------------------------------------------------------------------
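
For reference, two minimal, self-contained sketches of mechanisms the training code above relies on. They are illustrative only: toy data, simplified structures, and hypothetical variable names (`rng`, `priorities`, `online`, `target`, etc.) that do not appear in the repository.

Sketch 1 mirrors the proportional prioritized-replay mechanics of `ReplayMemoryPER` and `SumTree` (priority `p = (|TD error| + e) ** a`, segment-based sampling over the priority mass, and importance-sampling weights), but uses a flat NumPy array instead of a sum tree for brevity:

```python
# Toy prioritized-replay sketch (not repo code): flat array in place of a SumTree.
import numpy as np

rng = np.random.default_rng(0)
a, e, beta = 0.6, 0.01, 0.4                 # priority exponent, small constant, IS exponent

transitions = list(range(20))               # stand-ins for Transition tuples
priorities = np.full(len(transitions), (0.1 + e) ** a)   # new samples get a default priority

def sample(batch_size):
    total = priorities.sum()
    segment = total / batch_size            # split the priority mass into equal segments
    cum = np.cumsum(priorities)
    idxs = [int(np.searchsorted(cum, rng.uniform(segment * i, segment * (i + 1))))
            for i in range(batch_size)]
    probs = priorities[idxs] / total
    is_weight = (len(transitions) * probs) ** (-beta)     # importance-sampling correction
    return idxs, [transitions[i] for i in idxs], is_weight / is_weight.max()

idxs, batch, w = sample(4)
td_errors = rng.normal(size=len(idxs))      # pretend TD errors from a learning step
priorities[idxs] = (np.abs(td_errors) + e) ** a           # refresh priorities after the update
print(batch, w)
```

Sketch 2 shows, on toy tensors, the dueling aggregation `Q(s,a) = V(s) + A(s,a) - mean_a A(s,a)` computed in `DQN.forward` and the soft target-network update performed by `Agent.update_target_model`:

```python
# Toy dueling-Q aggregation and soft target update (not repo code).
import torch

value = torch.randn(2, 1)                   # V(s) for a batch of 2 states
advantage = torch.randn(2, 5)               # A(s,a) for 5 candidate actions each
q = value + advantage - advantage.mean(dim=1, keepdim=True)
print(q.shape)                              # torch.Size([2, 5])

tau = 0.01
online = torch.nn.Linear(4, 4)
target = torch.nn.Linear(4, 4)
with torch.no_grad():
    for tp, p in zip(target.parameters(), online.parameters()):
        tp.copy_(tau * p + (1.0 - tau) * tp)   # target <- tau * online + (1 - tau) * target
```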