├── README.md ├── util.py ├── list_film_book_en.py ├── tv_book_en.py ├── kg_category.py ├── film_book_en.py ├── join.py ├── Train.py ├── extract_kg.py └── MIFN.py /README.md: -------------------------------------------------------------------------------- 1 | # Readme--MIFN 2 | 3 | This is our implementation for the paper: Exploring Mixed Information Flow for Cross-domain Sequential Recommendations. 4 | 5 | The datasets used in this paper are the Movie, Book, Kitchen, and Food domains from [Amazon data](http://jmcauley.ucsd.edu/data/amazon/). 6 | 7 | 1. `MIFN.py` is the model code. 8 | 2. `Train.py` is the training code. 9 | 3. `extract_kg.py` is the knowledge graph extraction algorithm. 10 | 11 | 12 | # Knowledge construction 13 | 14 | For knowledge construction, we first use the metadata from [Amazon](http://jmcauley.ucsd.edu/data/amazon/). 15 | 16 | > There are four relationships among these knowledge triples: “also-buy”, “also-view”, “buy-after-viewing” and “buy-together”. 17 | 18 | We use these triples directly as part of our TKDD knowledge. 19 | 20 | For the “Movie-Book” domain, we find there may be adaptation relations between books and movies, 21 | 22 | > e.g., if a movie is adapted from a book, an “adaptation” relationship exists between them. 23 | 24 | Therefore, we crawl additional triples from Wikipedia as described below. 25 | 26 | ## A Crawler for Adaptation 27 | 28 | 1. `list_film_book_en.py` for [Category:Lists of films based on books](https://en.wikipedia.org/wiki/Category:Lists_of_films_based_on_books), which saves tabular records. 29 | 30 | 2. `tv_book_en.py` for [Category:Television shows based on books](https://en.wikipedia.org/wiki/Category:Television_shows_based_on_books). 31 | 32 | 3. `film_book_en.py` for [Category:Films based on books](https://en.wikipedia.org/wiki/Category:Films_based_on_books). It is similar to the TV crawler. 33 | 34 | Note that in `film_book_en.py`, if a book name is not found, we use the category name instead. 35 | 36 | > For example, `Arabian Nights (2015 film)` belongs to `Category:Films based on One Thousand and One Nights`, so `One Thousand and One Nights` is taken as the potential source book. 37 | 38 | This may introduce noise; you can modify the code following the pattern of `tv_book_en.py` to avoid it. The corresponding code differs only slightly. 39 | 40 | 41 | > File 1 writes `film_book.txt` in 'w' (write) mode, while files 2 and 3 write in 'a' (append) mode. You can run them in this order (see the run-order sketch at the end of this readme) or change the modes. 42 | 43 | > `/data/film_book.txt` is a previously generated output file. The code has been modified recently, so re-running it may produce a different version. 44 | 45 | ## Categories 46 | 47 | Categories of a film or a book can be extracted by requesting the [wiki metadata API](https://en.wikipedia.org/api/rest_v1/page/metadata/). 48 | 49 | The extraction code is already written in the above-mentioned .py files but commented out; uncomment it to crawl adaptation relations and categories at the same time (a standalone sketch of the API call is also given at the end of this readme). 50 | 51 | `get_category()` in `kg_category.py` provides an alternative that uses the `wikipedia` library. 52 | 53 | 54 | ## Match for the Amazon dataset 55 | 56 | The Movie and Book datasets are published in [Amazon Reviews](http://jmcauley.ucsd.edu/data/amazon/). 57 | 58 | `join.py` matches movies and books that have adaptation relations. 59 | 60 | You can also change the commented-out code in `match()` and `ceiling()` to adjust the matching and filtering rules. 
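## Example snippets

The crawl-then-match workflow above can be driven end to end as sketched below. This is only a sketch, assuming the scripts sit in the repository root and the `data/` directory exists: `list_film_book_en.py` must run first because it truncates `data/film_book.txt`, the other two crawlers append to it, and `join.py` finally matches the crawled pairs against the Amazon title files.

```python
import subprocess
import sys

# Run order matters: the first script opens data/film_book.txt in 'w' mode,
# the next two append to it, and join.py consumes the merged file.
for script in ('list_film_book_en.py', 'tv_book_en.py',
               'film_book_en.py', 'join.py'):
    subprocess.run([sys.executable, script], check=True)
```

The category lookup described above can also be tested on its own. The sketch below uses the same REST metadata endpoint and the same JSON fields (`categories` → `titles` → `display`) that the crawler scripts read; the page title `The_Godfather` is only an illustrative example, and retries/error handling are omitted.

```python
import requests

API_METADATA = 'https://en.wikipedia.org/api/rest_v1/page/metadata/'


def fetch_categories(url_token):
    """Return the display titles of the categories of one Wikipedia page."""
    r = requests.get(API_METADATA + url_token, timeout=30)
    r.raise_for_status()
    return [c['titles']['display'] for c in r.json().get('categories', [])]


if __name__ == '__main__':
    print(fetch_categories('The_Godfather'))
```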
61 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import html 2 | import re 3 | import nltk 4 | from nltk import word_tokenize, pos_tag 5 | 6 | wnl = nltk.stem.WordNetLemmatizer() 7 | stemer = nltk.stem.snowball.EnglishStemmer() 8 | 9 | 10 | def lemmatize_and_stem_all(sentence): 11 | for word, tag in pos_tag(word_tokenize(sentence)): 12 | if tag.startswith('NN'): 13 | yield stemer.stem(wnl.lemmatize(word, pos='n')) 14 | elif tag.startswith('VB'): 15 | yield stemer.stem(wnl.lemmatize(word, pos='v')) 16 | elif tag.startswith('JJ'): 17 | yield stemer.stem(wnl.lemmatize(word, pos='a')) 18 | elif tag.startswith('R'): 19 | yield stemer.stem(wnl.lemmatize(word, pos='r')) 20 | else: 21 | yield stemer.stem(word) 22 | 23 | 24 | 25 | # process HTML chars, $ and ! 26 | def trans_html(word): 27 | # if re.search('&[#\w]+;', word): 28 | # for key in dic.keys(): 29 | # word = re.sub(key, dic[key], word) 30 | word = html.unescape(word) 31 | res = re.search('!+(\w[^!]+\w)!+', word) 32 | if res: 33 | word = re.sub('!+(\w[^!]+\w)!+', r'\1', word) 34 | if not re.search('\${3,}', word): 35 | word = re.sub('\$\$([^A-Z\d])', r'ss\1', word) 36 | word = re.sub('\$([^A-Z\d])', r's\1', word) 37 | word = re.sub('([A-Za-z])\$\$', r'\1ss', word) 38 | word = re.sub('([A-Za-z])\$', r'\1s', word) 39 | return word 40 | 41 | 42 | # process sensitive words and * 43 | def trans_sensitive(title): 44 | if re.search('[^ \W\d][^ \-\'.\w][^ \W\d]', title): 45 | title = re.sub('sh*t', 'shit', title, flags=re.I) 46 | title = re.sub('f[u*][*c]k', 'fuck', title, flags=re.I) 47 | title = re.sub('m*a*s*h', 'mash', title, flags=re.I) 48 | title = re.sub('\*tm|\*r', '', title) 49 | title = re.sub('([^ \W\d])[^ \-\'.\w]([^ \W\d])', r'\1 \2', title) 50 | title = re.sub(' \*(\w)', r'\1', title) 51 | return title 52 | 53 | 54 | # flag=0: special for movie 55 | def process(title, flag): 56 | # title = re.sub('[Tt]he ', '', title) 57 | # title = re.sub('[Aa]n? ', '', title) 58 | title = re.sub('\(.*vol(ume)?[. :/\'\-\d)].*$', '', title, flags=re.I) 59 | title = re.sub(r' *[-/\[\\:,] *vol(ume)?[. :/\-\d)].*$', '', title, flags=re.I) 60 | title = re.sub(' *vol(ume)?[., :/\-\d].*$', '', title, flags=re.I) 61 | if flag == 0: 62 | title = re.sub('\(.*VHS.*\).*$', '', title, flags=re.I) 63 | title = re.sub('\[? *VHS *\]? 
*.*$', '', title, flags=re.I) 64 | title = re.sub(':.*DVD.*$', '', title, flags=re.I) 65 | title = re.sub(' *DVD *', '', title, flags=re.I) 66 | # title = re.sub('\(.+\) *$', '', title) 67 | title = title.strip('&-# ."\'') 68 | # title = ' '.join(lemmatize_and_stem_all(title)) 69 | return title 70 | 71 | 72 | mte = 'movie_title_entity.txt' 73 | bte = 'book_title_entity.txt' 74 | encoding = 'utf-8' 75 | 76 | 77 | def process_entity(flag): 78 | entity = mte if flag == 0 else bte 79 | with open('data/' + entity, 'r', encoding=encoding) as f1: 80 | with open('data/disposed/' + entity, 'w+', encoding=encoding) as f2: 81 | for line in f1: 82 | title = line.strip().split('\t')[1] 83 | title = process(title, 0) 84 | f2.write(title + '\n') 85 | 86 | 87 | def sort_title(flag): 88 | entity = mte if flag == 0 else bte 89 | with open('data/disposed/' + entity, 'r', encoding=encoding) as f: 90 | with open('data/disposed/sorted/' + mte, 'w+', encoding=encoding) as w: 91 | list = [process(line.strip(), flag=flag) for line in f] 92 | list.sort() 93 | for l in list: 94 | w.write(l + '\n') 95 | 96 | 97 | def shorten(title): 98 | title = re.sub(':.+$', '', title) 99 | title = re.sub(' [\d\W]+$', '', title) 100 | # title = re.sub('[Tt]he ', '', title) 101 | # title = re.sub('[Aa]n ', '', title) 102 | title = title.strip(' "\'') 103 | return title 104 | 105 | 106 | 107 | def cut(title): 108 | title = ' '.join(lemmatize_and_stem_all(shorten(title))) 109 | return title 110 | -------------------------------------------------------------------------------- /list_film_book_en.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import bs4 4 | import requests 5 | from bs4 import BeautifulSoup 6 | from opencc import OpenCC 7 | 8 | 9 | encoding = 'utf-8' 10 | domin_name = 'https://en.wikipedia.org' 11 | href1 = '/wiki/Category:Lists_of_films_based_on_books' 12 | api_metadata = 'https://en.wikipedia.org/api/rest_v1/page/metadata/' 13 | api_summary = 'https://en.wikipedia.org/api/rest_v1/page/summary/' 14 | film_list = [] 15 | pair_dic = {} 16 | catogories = [] 17 | 18 | PAD_TOKEN = '' 19 | 20 | s = requests.session() 21 | s.keep_alive = False 22 | 23 | f = open('data/film_book.txt', 'w+', encoding='utf-8') 24 | # ffc = open('data/film_category.txt', 'a', encoding='utf-8') 25 | # fbc = open('data/book_category.txt', 'a', encoding='utf-8') 26 | 27 | stop = ['', '?', '(?)', 'uncredited', '(uncredited)', '?(uncredited)'] 28 | 29 | def dispose(name): 30 | if name in stop: 31 | return '' 32 | name = re.sub('[♠♦*]', '', name) 33 | match = re.match('^"([^"]+)"$', name) 34 | if match: 35 | name = match.group(1) 36 | return name.strip() 37 | 38 | 39 | def get_title(td): 40 | if td.string: 41 | return td.string 42 | i = td.find('i') 43 | if i: 44 | name = i.string if i.string else i.get_text() 45 | return name 46 | a = td.find('a') 47 | if a: 48 | name = a.string if a.string else a.get_text() 49 | if td.get_text().startswith(name): 50 | return name 51 | return '' 52 | 53 | 54 | def get_urltoken(a, title): 55 | if a and 'href' in a.attrs: 56 | return a.attrs['href'][6:] 57 | else: 58 | return re.sub(' ', '_', title) 59 | 60 | 61 | def search(href): 62 | """ 63 | extract data from tables. 
64 | """ 65 | try: 66 | r = s.get(domin_name + href, timeout=30) 67 | r.encoding = encoding 68 | soup = BeautifulSoup(r.text, 'html.parser') 69 | except: 70 | print("Network err!") 71 | 72 | wikitables = soup.find_all('table', {'class':'wikitable'}) 73 | for table in wikitables: 74 | pair = ['', ''] 75 | i, j = 0, 0 76 | for th in table.find_all('th'): 77 | if th.get_text() == 'Film': 78 | break 79 | i += 1 80 | for th in table.find_all('th'): 81 | if th.get_text() == 'Sourcework' or th.get_text() == 'Source work': 82 | break 83 | j += 1 84 | 85 | skip = 0 86 | for tr in table.find_all('tr')[1:]: 87 | 88 | if skip > 0: 89 | tds = tr.find_all('td') 90 | pair[1] = dispose(get_title(tds[0])) 91 | if pair[0] != '' and pair[1] != '': 92 | # pair_dic[pair[0]] = pair[1] 93 | print(pair[0], '\t', pair[1]) 94 | f.write('\t'.join(pair) + '\n') 95 | # get_category(tds[0].find('a'), pair[1], fbc) 96 | skip -= 1 97 | continue 98 | tds = tr.find_all('td') 99 | if tds is None: 100 | continue 101 | if 'rowspan' in tds[0].attrs: 102 | skip = int(tds[0].attrs['rowspan'])-1 103 | 104 | if len(tds) <= i or len(tds) <= j: 105 | continue 106 | 107 | pair[0] = dispose(get_title(tds[i])) 108 | pair[1] = dispose(get_title(tds[j])) 109 | if pair[0] != '' and pair[1] != '': 110 | # pair_dic[pair[0]] = pair[1] 111 | print(pair[0], '\t', pair[1]) 112 | f.write('\t'.join(pair) + '\n') 113 | # get_category(tds[i].find('a'), pair[0], ffc) 114 | # get_category(tds[j].find('a'), pair[1], fbc) 115 | 116 | 117 | def search2(href): 118 | try: 119 | r = s.get(domin_name + href, timeout=30) 120 | r.encoding = encoding 121 | soup = BeautifulSoup(r.text, 'html.parser') 122 | except: 123 | print("Network err!") 124 | 125 | wikitables = soup.find_all('table', {'class':'wikitable'}) 126 | for table in wikitables: 127 | for tr in table.find_all('tr')[1:]: 128 | tds = tr.find_all('td') 129 | book = dispose(get_title(tds[0])) 130 | # get_category(tds[0].find('a'), book, fbc) 131 | if book: 132 | for a in tds[1].find_all('a'): 133 | if a.string: 134 | print(a.string, '\t', book) 135 | f.write(a.string + '\t' + book + '\n') 136 | # get_category(a, a.string, ffc) 137 | 138 | 139 | def get_category(a, title, f): 140 | """ 141 | Get categories from metadata. 142 | """ 143 | url_token = get_urltoken(a, title) 144 | r = s.get(api_metadata + url_token) 145 | try: 146 | json = r.json() 147 | for category in json['categories']: 148 | f.write(title + '\t' + category['titles']['display'] + '\n') 149 | except: 150 | url_token = re.sub(' ', '_', title) 151 | r = s.get(api_metadata + url_token) 152 | try: 153 | json = r.json() 154 | for category in json['categories']: 155 | f.write(title + '\t' + category['titles']['display'] + '\n') 156 | except Exception as e2: 157 | print(str(e2)) 158 | 159 | 160 | 161 | def crawl(href): 162 | """ 163 | Get all category entrances and then use the method `search` for each. 164 | Specially, *List of children's books made into feature films* need the method `search2` for its different structure. 
165 | """ 166 | try: 167 | r = s.get(domin_name + href, timeout=30) 168 | r.encoding = encoding 169 | soup = BeautifulSoup(r.text, 'html.parser') 170 | except: 171 | print("Network err!") 172 | categories_soup = soup.find('div', {'class': 'mw-category'}) 173 | href_list = [] 174 | for category in categories_soup.find_all('a'): 175 | if category.string in ['List of children\'s books made into feature films', 176 | 'Lists of book-based war films']: 177 | continue 178 | href_list.append(category.attrs['href']) 179 | 180 | for href in href_list: 181 | print(href) 182 | search(href) 183 | 184 | search2('/wiki/List_of_children%27s_books_made_into_feature_films') 185 | 186 | 187 | 188 | 189 | def main(): 190 | try: 191 | crawl(href1) 192 | finally: 193 | if f: 194 | f.flush() 195 | f.close() 196 | # if ffc: 197 | # ffc.flush() 198 | # ffc.close() 199 | # if fbc: 200 | # fbc.flush() 201 | # fbc.close() 202 | 203 | 204 | 205 | if __name__ == '__main__': 206 | main() -------------------------------------------------------------------------------- /tv_book_en.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import requests 4 | from bs4 import BeautifulSoup 5 | 6 | encoding = 'utf-8' 7 | domin_name = 'https://en.wikipedia.org' 8 | href1 = '/wiki/Category:Television_shows_based_on_books' 9 | film_list = [] 10 | pair_dic = [] 11 | catogories = [] 12 | 13 | PAD_TOKEN = '' 14 | 15 | s = requests.session() 16 | s.keep_alive = False 17 | 18 | f = open('data/film_book.txt', 'a', encoding='utf-8') 19 | # tf = open('dataset/0a/film_head.txt', 'w+', encoding='utf-8') 20 | 21 | def cut(str): 22 | return re.sub('|\(.*book.*\)' 23 | '|\(.*film.*\)' 24 | '|\(.*series.*\)', '', str, flags=re.I).strip() 25 | 26 | 27 | def search_infobox(infobox): 28 | # infobox = soup.find('table', {'class': 'infobox vevent'}) 29 | pair1 = ['', ''] 30 | if infobox: 31 | th = infobox.find('th') 32 | if th.string: 33 | pair1[0] = th.string 34 | 35 | origin = infobox.find('th', string=re.compile('based', flags=re.I)) 36 | if origin: 37 | td = origin.parent.find('td') 38 | if td: 39 | if td.string: 40 | pair1[1] = td.string 41 | else: 42 | text = td.get_text() 43 | a = td.find('a') 44 | if a and a.string and text.startswith(a.string): 45 | pair1[1] = a.string 46 | else: 47 | i = td.find('i') 48 | if i and i.string and text.startswith(i.string): 49 | pair1[1] = i.string 50 | else: 51 | match = re.match('(.+)(Written )?by .+', text, flags=re.I) 52 | if match: 53 | pair1[1] = match.group(1) 54 | pair1[1] = re.sub(' *\( *novel *\) *', '', pair1[1]).strip() 55 | match = re.search('[Bb]ased( on)? 
*(.+)$', pair1[1]) 56 | if match: 57 | pair1[1] = match.group(2) 58 | if pair1[1] == '': 59 | pair1[1] = text 60 | 61 | return pair1 62 | 63 | 64 | def search(href): 65 | try: 66 | r = s.get(domin_name + href) 67 | r.encoding = encoding 68 | soup = BeautifulSoup(r.text, 'html.parser') 69 | except: 70 | print("Network err!", href) 71 | return 72 | 73 | # film_title file_title_en book_title book_title_en 74 | pair = ['', ''] 75 | pair1 = ['', ''] 76 | pair2 = ['', ''] 77 | 78 | title = soup.find('title').string 79 | title = re.sub(' - Wikipedia', '', title) 80 | # 从词条表获取中英名,如有,获得原著名 81 | # 如果没有电影的括号注释,先查找是否有电影一栏 82 | match = re.match('^(.+)(\(.*series.*\))$', title) 83 | title_name = cut(title) 84 | found = False 85 | if match: 86 | film_columns = soup.find_all( 87 | 'span', {'class':'mw-headline'}, 88 | string=re.compile('series')) 89 | # for fc in film_columns: 90 | # tf.write(fc.string+'\n') 91 | if film_columns: 92 | found = True 93 | film_text = '' 94 | fp = None 95 | if found: 96 | for film_column in film_columns: 97 | infobox = film_column.parent.find_next_sibling('table', {'class': 'infobox vevent'}) 98 | pair = search_infobox(infobox) 99 | fp = film_column.find_next('p') 100 | if fp and title_name in fp.get_text(): 101 | film_text = fp.get_text() 102 | else: 103 | infobox = soup.find('table', {'class': 'infobox vevent'}) 104 | pair1 = search_infobox(infobox) 105 | 106 | # 从第一段简介获取,可信度较低 107 | # introduction = film_text if found else soup.find('p').get_text() 108 | cursor = soup.find('p') 109 | for i in range(10): 110 | if cursor is None or title_name in cursor.get_text(): 111 | break 112 | cursor = cursor.find_next('p') 113 | introduction = '' 114 | if found: 115 | introduction = film_text 116 | cursor = fp 117 | elif cursor: 118 | introduction = cursor.get_text() 119 | first_p = None 120 | first_para = '' 121 | bs = soup.find_all('b', string=title_name) 122 | if bs: 123 | for b in bs: 124 | if b.parent and b.parent.name == 'p': 125 | first_p = b.parent 126 | break 127 | if introduction == '' or introduction == '\n' or introduction is None: 128 | first_para = first_p.get_text() if first_p else '' 129 | introduction = first_para 130 | cursor = first_p 131 | 132 | pair2[0] = title_name 133 | 134 | same_name = False 135 | if re.search('based on.+same name', introduction, flags=re.I): 136 | pair2[1] = pair2[0] 137 | 138 | # 如果有电影一栏,并且词条不包含电影两字,极有可能是一个大词条,主栏可能是书籍 139 | pair3 = ['', ''] 140 | if found and not re.search('series', title): 141 | searchObj = re.search('[Bb]ook|[Nn]ovel|[Ww]ritten', first_para) 142 | if searchObj: 143 | pair3[1] = title_name 144 | 145 | 146 | 147 | for i in range(2): 148 | pair[i] = pair1[i] if pair1[i] else pair2[i] 149 | pair[1] = pair[1] if pair[1] else pair3[1] 150 | 151 | if pair[0] and pair[0] not in film_list and pair[1]: 152 | film_list.append(pair[0]) 153 | print(pair) 154 | f.write("\t".join(pair)+'\n') 155 | 156 | 157 | def crawl(href, get_children=True): 158 | """ 159 | First, crawl subcategories recursively. 160 | Second, crawl pages in category below. 
161 | """ 162 | try: 163 | r = s.get(domin_name + href, timeout=30) 164 | r.encoding = encoding 165 | soup = BeautifulSoup(r.text, 'html.parser') 166 | except: 167 | print("Network err!", href) 168 | h2 = soup.find('h2', string=re.compile('Subcategories')) 169 | if h2: 170 | children = h2.find_next('div', {'class': 'mw-content-ltr'}).find_all('div', {'class': 'CategoryTreeSection'}) 171 | if get_children and children: 172 | for child in children: 173 | a = child.find('a') 174 | print(a.string) 175 | catogory = a.string 176 | if catogory in catogories or re.search('^User:|^Template:', catogory): 177 | continue 178 | catogories.append(catogory) 179 | crawl(a.attrs['href']) 180 | # pages = soup.find('div', {'class':'mw-category'}) 181 | # if not pages: 182 | if h2 and children: 183 | pages = children[-1].parent.find_next('div', {'class': 'mw-content-ltr'}) 184 | else: 185 | h2 = soup.find('h2', string=re.compile('Pages')) 186 | if h2: 187 | pages = h2.find_next('div', {'class': 'mw-content-ltr'}) 188 | else: 189 | pages = None 190 | if pages: 191 | for a in pages.find_all('a'): 192 | if re.search('^User:|^Template:', a.string): 193 | continue 194 | # try: 195 | search(a.attrs['href']) 196 | # except: 197 | # print('产生异常', a.attrs['href']) 198 | 199 | 200 | def main(): 201 | try: 202 | crawl(href1) 203 | finally: 204 | if f: 205 | f.flush() 206 | f.close() 207 | 208 | 209 | if __name__ == '__main__': 210 | main() -------------------------------------------------------------------------------- /kg_category.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import Levenshtein 4 | import pandas as pd 5 | import requests 6 | import wikipedia 7 | from util import * 8 | from join import exist_match, match 9 | import multiprocessing 10 | import numpy as np 11 | 12 | 13 | def transform(s): 14 | return trans_sensitive(shorten(process(trans_html(s), 0))) 15 | 16 | 17 | pattern_note = re.compile(' *\(.+\) *$| *[Tt]he *| *[Aa]n? *') 18 | 19 | 20 | def delete_note(s): 21 | return re.sub(pattern_note, '', s) 22 | 23 | 24 | def clip(type): 25 | print(type) 26 | # entity_set = pd.read_csv('data/output/' + type + '_entity_dict.txt', delimiter='\t', header=None)[0].tolist() 27 | df = pd.read_csv('data/id2title/' + type + '_entity_title.txt', delimiter='\t', header=None) 28 | # df = df[df[0].isin(entity_set)] 29 | df[1] = df[1].apply(transform) 30 | df.to_csv('data/output/' + type + '_entity_title.txt', sep='\t', header=False, index=False) 31 | 32 | 33 | trash = [] 34 | entities = '(movie|film|book|drama|video|television|novel|comic|play|series|show|anime|manga|ova|tv)s?' 
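# (Added note, not part of the original file.) The regexes below drive the category
# cleaning used by valid(), extract() and wash_func():
#   - stop_pattern rejects categories that are nothing but a media-type word (e.g. "Films");
#   - dirty_pattern rejects maintenance/meta categories (cs1 errors, templates, wiki pages, ...);
#   - extract() drops years/decades, keeps the phrase after "about/by/based on/shot at",
#     and strips trailing media-type words, so only the core topic of a category remains.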
35 | entity_pattern = re.compile(entities) 36 | stop_pattern = re.compile(entities + '$') 37 | dirty_pattern = re.compile('cs1|article|set in|shot in|error|wiki|template|use|games|albums|radio' 38 | '|page|using|usage|language|link|adapted|adaption|unfinished|introduced') 39 | 40 | 41 | def valid(category): 42 | if not category: 43 | return False 44 | if re.match('\W+$', category): 45 | return False 46 | s = category.lower() 47 | if re.match(stop_pattern, s): 48 | return False 49 | if category in trash: 50 | return False 51 | if re.search(dirty_pattern, s) is not None: 52 | trash.append(category) 53 | return False 54 | return True 55 | 56 | 57 | def extract(category): 58 | category = re.sub('\d{4}s?|\d{2}th[ \-][Cc]entury', '', category).strip() 59 | category = re.sub(' {2,}', ' ', category) 60 | if category is None: 61 | return category 62 | res = re.search('(about|by|based on|shot at) (.+)', category) 63 | if res: 64 | category = res.group(2) 65 | res = re.search('(.+) (' + entities + ')+', 66 | category, flags=re.I) 67 | if res: 68 | category = res.group(1) 69 | return category 70 | 71 | 72 | delete_word = re.compile('Americans?|Canadian|British|Hong Kong|Chinese|Italian|Australian' 73 | 'French|Japanese|Spanish|African|English|the United States|Indian|' 74 | 'and|[Dd]ebuts?|\(genre\)|Western|(South )?Korean|German|' 75 | + entities) 76 | 77 | 78 | def is_entity(categories): 79 | has_entity = False 80 | for c in categories: 81 | res = re.search('(ers|[Aa]ctors?|ists)', c) 82 | if res: 83 | back = res.span()[1] 84 | if len(c) == back or c[back] == ' ': 85 | break 86 | if re.search(entity_pattern, c): 87 | has_entity = True 88 | break 89 | return has_entity 90 | 91 | 92 | def wash_func(categories): 93 | categories = eval(categories) 94 | if not is_entity(categories): 95 | return [] 96 | categories = [re.sub(delete_word, '', extract(c)).strip().strip('-') for c in categories] 97 | categories = [re.sub(' {2,}', '', c).lower() for c in categories if valid(c)] 98 | categories = list(set(categories)) 99 | return categories 100 | 101 | 102 | 103 | def ceiling(query, title): 104 | if query.count(' ') <= 2: 105 | res = re.search(re.escape(query), title, flags=re.I) 106 | if res: 107 | if re.match(' [^\w,\']|( *$)', title[res.span()[1]:]): 108 | if res.span()[0] == 0: 109 | return True 110 | elif res.span()[0] >= 2 \ 111 | and re.match('[^\w,\']{2}', title[res.span()[0] - 2:res.span()[0]]): 112 | return True 113 | return False 114 | return True 115 | 116 | 117 | api_metadata = 'https://en.wikipedia.org/api/rest_v1/page/metadata/' 118 | 119 | cnt = 0 120 | 121 | 122 | def get_category(title): 123 | global cnt 124 | cnt += 1 125 | print(cnt) 126 | try: 127 | for res in wikipedia.search(title, results=2): 128 | # print(title, '\t', res) 129 | s1 = delete_note(title) 130 | s2 = delete_note(res) 131 | if match(s1, s2, True) or Levenshtein.jaro_winkler(s1, s2) > 0.8: 132 | r = s.get(api_metadata + re.sub(' ', '_', res)) 133 | try: 134 | json = r.json() 135 | categories = [c['titles']['display'] for c in json['categories']] 136 | if not categories and is_entity(categories): continue 137 | print(categories) 138 | return [c for c in categories if valid(c)] 139 | except: 140 | continue 141 | except Exception as e: 142 | print(str(e)) 143 | return [] 144 | return [] 145 | 146 | 147 | def task(df, type, bs=200): 148 | i = 0 149 | while i < len(df): 150 | sub = df[i: i+bs].copy() 151 | sub['category'] = sub['title'].apply(get_category) 152 | sub.to_csv('data/output/' + type + '_df.csv', index=False, 
header=False, mode='a') 153 | i += bs 154 | 155 | 156 | def generate_category(type='movie'): 157 | df = pd.read_csv('data/output/' + type + '_entity_title.txt', delimiter='\t', header=None) 158 | df.columns = ['id', 'title'] 159 | # over = pd.read_csv('data/output/' + type + '_df.csv', header=None) 160 | # ids = over[0].unique() 161 | # empty = over[over[2] == '[]'][0].unique() 162 | # df = df[~df['id'].isin(ids) or df['id'].isin(empty)] 163 | # df = df[~df['id'].isin(ids)] 164 | # del over 165 | # del ids 166 | # del empty 167 | sum = len(df) 168 | print(sum, sum//n_process) 169 | for i in range(n_process): 170 | p = multiprocessing.Process(target=task, 171 | args=(df[i * sum//n_process: 172 | (i+1) * sum//n_process], type, )) 173 | p.start() 174 | 175 | 176 | def wash(type='movie'): 177 | df = pd.read_csv('data/output/' + type + '_df.csv', header=None) 178 | df.columns = ['id', 'title', 'category'] 179 | df = df[df['category'] != '[]'] 180 | df['category'] = df['category'].apply(wash_func) 181 | df = df[df['category'].map(len) > 0] 182 | # df.to_csv('data/output/' + type + '_df1.csv', header=False, index=False) 183 | df2 = pd.DataFrame({'id': df.id.repeat(df.category.str.len()), 184 | 'category': np.concatenate(df.category.values)}) 185 | df_count = df2['category'].value_counts() 186 | hot = df_count[df_count.values > 150].index 187 | tags = set(hot[hot.map(lambda x: len(x.split(' '))) == 1].to_list()) - {'in'} 188 | def split_tags(x): 189 | x = x.strip() 190 | s = set(x.split(' ')) 191 | if 2 <= len(s) <= 3: 192 | common = s & tags 193 | if s != common: 194 | common.add(x) 195 | return list(common) 196 | else: 197 | return [x] 198 | df2['category'] = df2['category'].apply(split_tags) 199 | df2 = pd.DataFrame({'id': df2.id.repeat(df2.category.str.len()), 200 | 'category': np.concatenate(df2.category.values)}) 201 | # df_kg = pd.read_csv('data/output/' + type + '_entity_dict.txt', header=None, delimiter='\t') 202 | df3 = pd.read_csv('data/output/' + type + '_raw_categories.txt', header=None, delimiter='\t') 203 | df3.columns = ['id', 'category'] 204 | # df3 = df3[df3.id.isin(df_kg[0].values)] 205 | df2 = df2.append(df3, ignore_index=True).drop_duplicates().sort_values(by='id') 206 | df2.to_csv('data/output/' + type + '_id_category.csv', sep='\t', header=False, index=False) 207 | 208 | 209 | def output(): 210 | # output relation of kg index and category id 211 | df1_dict = pd.read_csv('data/output/movie_rid2index.txt', header=None, delimiter='\t', names=['id', 'eid']) 212 | df2_dict = pd.read_csv('data/output/book_rid2index.txt', header=None, delimiter='\t', names=['id', 'eid']) 213 | 214 | df1 = pd.read_csv('data/output/movie_id_category.csv', header=None, delimiter='\t') 215 | df1.columns = ['id', 'category'] 216 | df2 = pd.read_csv('data/output/book_id_category.csv', header=None, delimiter='\t') 217 | df2.columns = ['id', 'category'] 218 | 219 | category = df1.append(df2)['category'].value_counts().reset_index() 220 | category = category[['index']] 221 | category.columns = ['category'] 222 | category['cid'] = category.index 223 | category.to_csv('data/output/category_dict.txt', sep='\t', header=False, index=False) 224 | 225 | df1 = df1.merge(category, on='category') 226 | df1 = df1.merge(df1_dict, on='id').sort_values(by='eid') 227 | df1[['eid', 'cid']].to_csv('data/output/movie_id_cid.txt', sep='\t', header=False, index=False) 228 | df2 = df2.merge(category, on='category') 229 | df2 = df2.merge(df2_dict, on='id').sort_values(by='eid') 230 | df2[['eid', 
'cid']].to_csv('data/output/book_id_cid.txt', sep='\t', header=False, index=False) 231 | 232 | 233 | def rid2index(): 234 | for type in ['movie', 'book']: 235 | df = pd.read_csv('data/output/' + type + '_entity_dict.txt', header=None, delimiter='\t', names=['id', 'eid']) 236 | df2 = pd.read_csv('data/output/' + 'entity_id2index_' + type + '.txt', header=None, delimiter='\t', names=['eid', 'eidx']) 237 | if type == 'B': 238 | df2['eid'] -= 158410 239 | df = df.merge(df2).sort_values('eidx') 240 | df[['id', 'eidx']].to_csv('data/output/' + type + '_rid2index.txt', header=False, sep='\t', index=False) 241 | 242 | 243 | s = requests.session() 244 | s.keep_alive = False 245 | 246 | if __name__ == '__main__': 247 | 248 | n_process = 5 249 | for type in ['movie', 'book']: 250 | clip(type) 251 | generate_category(type) 252 | wash('book') 253 | 254 | # output() 255 | 256 | 257 | 258 | 259 | -------------------------------------------------------------------------------- /film_book_en.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import requests 4 | from bs4 import BeautifulSoup 5 | 6 | encoding = 'utf-8' 7 | domin_name = 'https://en.wikipedia.org' 8 | href_films = '/wiki/Category:Films_based_on_books' 9 | href_tv = '/wiki/Category:Television_shows_based_on_books' 10 | href_pass = '/wiki/Category:Lists_of_films_based_on_books' 11 | api_metadata = 'https://en.wikipedia.org/api/rest_v1/page/metadata/' 12 | api_summary = 'https://en.wikipedia.org/api/rest_v1/page/summary/' 13 | film_list = [] 14 | pair_dic = [] 15 | catogories = [] 16 | 17 | PAD_TOKEN = '' 18 | 19 | s = requests.session() 20 | s.keep_alive = False 21 | 22 | f = open('data/film_book.txt', 'a', encoding='utf-8') 23 | # fc = open('data/film_category.txt', 'a', encoding='utf-8') 24 | 25 | 26 | def cut(str): 27 | str = re.sub('|\(.*book.*\)' 28 | '|\(.*film.*\)' 29 | '|\(.*movie.*\)' 30 | '|\(.*series.*\)', '', str, flags=re.I).strip() 31 | return re.sub('(in )?[Ff]ilms?( series)?$', '', str, flags=re.I) 32 | 33 | def search_infobox(infobox): 34 | # infobox = soup.find('table', {'class': 'infobox vevent'}) 35 | pair1 = ['', ''] 36 | if infobox: 37 | th = infobox.find('th') 38 | if th.string: 39 | pair1[0] = th.string 40 | 41 | origin = infobox.find('th', string=re.compile('based', flags=re.I)) 42 | if origin: 43 | td = origin.parent.find('td') 44 | if td: 45 | if td.string: 46 | pair1[1] = td.string 47 | else: 48 | text = td.get_text() 49 | a = td.find('a') 50 | if a and a.string and text.startswith(a.string): 51 | pair1[1] = a.string 52 | else: 53 | i = td.find('i') 54 | if i and i.string and text.startswith(i.string): 55 | pair1[1] = i.string 56 | else: 57 | match = re.match('(.+)(Written )?by .+', text, flags=re.I) 58 | if match: 59 | pair1[1] = match.group(1) 60 | pair1[1] = re.sub(' *\( *novel *\) *', '', pair1[1]).strip() 61 | if re.match('^ *based on *', pair1[1]): 62 | pair1[1] = '' 63 | 64 | return pair1 65 | 66 | 67 | def search(href): 68 | try: 69 | r = s.get(domin_name + href) 70 | r.encoding = encoding 71 | soup = BeautifulSoup(r.text, 'html.parser') 72 | except: 73 | print("Network err!", href) 74 | return 75 | 76 | # film_title file_title_en book_title book_title_en 77 | pair = ['', ''] 78 | pair1 = ['', ''] 79 | pair2 = ['', ''] 80 | 81 | title = soup.find('title').string 82 | title = re.sub(' - Wikipedia', '', title) 83 | # 从词条表获取中英名,如有,获得原著名 84 | # 如果没有电影的括号注释,先查找是否有电影一栏 85 | match = re.match('^(.+)(\(.*film.*\))$', title) 86 | title_name = cut(title) 87 | found 
= False 88 | if match: 89 | film_columns = soup.find_all( 90 | 'span', {'class':'mw-headline'}, 91 | string=re.compile('film')) 92 | # for fc in film_columns: 93 | # tf.write(fc.string+'\n') 94 | if film_columns: 95 | found = True 96 | film_text = '' 97 | fp = None 98 | if found: 99 | for film_column in film_columns: 100 | infobox = film_column.parent.find_next_sibling('table', {'class': 'infobox vevent'}) 101 | pair = search_infobox(infobox) 102 | fp = film_column.find_next('p') 103 | if fp and title_name in fp.get_text(): 104 | film_text = fp.get_text() 105 | else: 106 | infobox = soup.find('table', {'class': 'infobox vevent'}) 107 | pair1 = search_infobox(infobox) 108 | 109 | # 从第一段简介获取,可信度较低 110 | # introduction = film_text if found else soup.find('p').get_text() 111 | cursor = soup.find('p') 112 | for i in range(10): 113 | if cursor is None or title_name in cursor.get_text(): 114 | break 115 | cursor = cursor.find_next('p') 116 | introduction = '' 117 | if found: 118 | introduction = film_text 119 | cursor = fp 120 | elif cursor: 121 | introduction = cursor.get_text() 122 | first_p = None 123 | first_para = '' 124 | bs = soup.find_all('b', string=title_name) 125 | if bs: 126 | for b in bs: 127 | if b.parent and b.parent.name == 'p': 128 | first_p = b.parent 129 | break 130 | if introduction == '' or introduction == '\n' or introduction is None: 131 | first_para = first_p.get_text() if first_p else '' 132 | introduction = first_para 133 | cursor = first_p 134 | 135 | pair2[0] = title_name 136 | 137 | same_name = False 138 | if re.search('based on.+same name', introduction, flags=re.I): 139 | pair2[1] = pair2[0] 140 | 141 | # 如果有电影一栏,并且词条不包含电影两字,极有可能是一个大词条,主栏可能是书籍 142 | pair3 = ['', ''] 143 | if found and not re.search('film', title): 144 | searchObj = re.search('[Bb]ook|[Nn]ovel|[Ww]ritten', first_para) 145 | if searchObj: 146 | pair3[1] = title_name 147 | 148 | 149 | 150 | for i in range(2): 151 | pair[i] = pair1[i] if pair1[i] else pair2[i] 152 | pair[1] = pair[1] if pair[1] else pair3[1] 153 | 154 | if pair[0] and pair[0] not in film_list and pair[1]: 155 | film_list.append(pair[0]) 156 | print(pair) 157 | f.write("\t".join(pair)+'\n') 158 | 159 | 160 | def crawl(href, get_children=True): 161 | try: 162 | r = s.get(domin_name + href, timeout=30) 163 | r.encoding = encoding 164 | soup = BeautifulSoup(r.text, 'html.parser') 165 | except: 166 | print("Network err!", href) 167 | h2 = soup.find('h2', string=re.compile('Subcategories')) 168 | if h2: 169 | children = h2.find_next('div', {'class': 'mw-content-ltr'}).find_all('div', {'class': 'CategoryTreeSection'}) 170 | if get_children and children: 171 | for child in children: 172 | a = child.find('a') 173 | print(a.string) 174 | catogory = a.string 175 | if catogory in catogories or re.search('^User:|^Template:', catogory): 176 | continue 177 | catogories.append(catogory) 178 | href = a.attrs['href'] 179 | if href != href_pass: 180 | crawl(href) 181 | # pages = soup.find('div', {'class':'mw-category'}) 182 | # if not pages: 183 | if h2 and children: 184 | pages = children[-1].parent.find_next('div', {'class': 'mw-content-ltr'}) 185 | else: 186 | h2 = soup.find('h2', string=re.compile('Pages')) 187 | if h2: 188 | pages = h2.find_next('div', {'class': 'mw-content-ltr'}) 189 | else: 190 | pages = None 191 | if pages: 192 | for a in pages.find_all('a'): 193 | if re.search('^User:|^Template:', a.string): 194 | continue 195 | # try: 196 | search(a.attrs['href']) 197 | # except: 198 | # print('产生异常', a.attrs['href']) 199 | 200 | 201 | def 
base_on_category(href, name=''): 202 | if name: 203 | exclude = re.search(' [Bb]y |[Nn]ovels?$|[Bb]ooks?$', name) 204 | if exclude is None: 205 | res = re.search('[Bb]ased on( the)? *([A-Z\d].+)$', name) 206 | if res: 207 | name = res.group(2) 208 | if re.search('[Ff]ilm', name): 209 | name = '' 210 | name = cut(name) 211 | else: 212 | name = '' 213 | else: 214 | name = '' 215 | try: 216 | r = s.get(domin_name + href, timeout=30) 217 | r.encoding = encoding 218 | soup = BeautifulSoup(r.text, 'html.parser') 219 | except: 220 | print("Network err!", href) 221 | h2 = soup.find('h2', string=re.compile('Subcategories')) 222 | if h2: 223 | children = h2.find_next('div', {'class': 'mw-content-ltr'}).find_all('div', {'class': 'CategoryTreeSection'}) 224 | if children: 225 | for child in children: 226 | a = child.find('a') 227 | print(a.string) 228 | catogory = a.string 229 | if catogory in catogories or re.search('^User:|^Template:', catogory): 230 | continue 231 | catogories.append(catogory) 232 | href = a.attrs['href'] 233 | if href != href_pass: 234 | base_on_category(href, a.string) 235 | 236 | pages = soup.find('div', {'id': 'mw-pages'}) 237 | 238 | if name: 239 | if pages: 240 | pages = pages.find('div', {'class' : 'mw-content-ltr'}) 241 | if pages: 242 | pages = pages.find_all('a') 243 | for page in pages: 244 | title = cut(page.string) 245 | if re.search('^User:|^Template:', title): 246 | continue 247 | f.write(title + '\t' + name + '\n') 248 | print([title, name]) 249 | # url_token = page.attrs['href'][6:] # after '/wiki/' 250 | # r = s.get(api_metadata + url_token) 251 | # try: 252 | # json = r.json() 253 | # for category in json['categories']: 254 | # fc.write(title + '\t' + category['titles']['display'] + '\n') 255 | # except Exception as e: 256 | # print(str(e)) 257 | 258 | 259 | def main(): 260 | try: 261 | # crawl(href_films) # Same as tv_book_en.py. 262 | base_on_category(href_films) # Take based category as the potential. It may raise noise but allow for more relations. 
263 | # base_on_category(href_tv) 264 | finally: 265 | if f: 266 | f.flush() 267 | f.close() 268 | # if fc: 269 | # fc.flush() 270 | # fc.close() 271 | 272 | 273 | if __name__ == '__main__': 274 | main() -------------------------------------------------------------------------------- /join.py: -------------------------------------------------------------------------------- 1 | import html 2 | import re 3 | import nltk 4 | from nltk import word_tokenize, pos_tag 5 | from util import * 6 | 7 | wnl = nltk.stem.WordNetLemmatizer() 8 | stemer = nltk.stem.snowball.EnglishStemmer() 9 | 10 | 11 | def lemmatize_and_stem_all(sentence): 12 | for word, tag in pos_tag(word_tokenize(sentence)): 13 | if tag.startswith('NN'): 14 | yield stemer.stem(wnl.lemmatize(word, pos='n')) 15 | elif tag.startswith('VB'): 16 | yield stemer.stem(wnl.lemmatize(word, pos='v')) 17 | elif tag.startswith('JJ'): 18 | yield stemer.stem(wnl.lemmatize(word, pos='a')) 19 | elif tag.startswith('R'): 20 | yield stemer.stem(wnl.lemmatize(word, pos='r')) 21 | else: 22 | yield stemer.stem(word) 23 | 24 | 25 | fname = 'film_book.txt' 26 | mte = 'movie_title_entity.txt' 27 | bte = 'book_title_entity.txt' 28 | 29 | encoding = 'utf-8' 30 | 31 | # def process1(): 32 | # with open('data/' + mte, 'r', encoding=encoding) as f1: 33 | # with open('data/disposed/' + mte, 'w+', encoding=encoding) as f2: 34 | # for line in f1: 35 | # title = line.strip().split('\t')[1] 36 | # title = process(title, 0) 37 | # f2.write(title + '\n') 38 | # 39 | # 40 | # def process2(): 41 | # with open('data/' + bte, 'r', encoding=encoding) as f1: 42 | # with open('data/disposed/' + bte, 'w+', encoding=encoding) as f2: 43 | # for line in f1: 44 | # title = line.strip().split('\t')[1] 45 | # title = process(title, 1) 46 | # f2.write(title + '\n') 47 | # 48 | # 49 | # def sort_title(flag): 50 | # entity = mte if flag == 0 else bte 51 | # with open('data/disposed/' + entity, 'r', encoding=encoding) as f: 52 | # with open('data/disposed/sorted/' + entity, 'w+', encoding=encoding) as w: 53 | # list = [process(line.strip(), flag=flag) for line in f] 54 | # list.sort() 55 | # for l in list: 56 | # w.write(l + '\n') 57 | # 58 | # 59 | # def wash(flag): 60 | # entity = mte if flag == 0 else bte 61 | # with open('data/disposed/' + entity, 'w+', encoding=encoding) as f0: 62 | # with open('data/' + entity, 'r', encoding=encoding) as f1: 63 | # for line in f1: 64 | # pair = line.strip().split('\t') 65 | # f0.write(pair[0] + '\t' + trans(pair[1]) + '\n') 66 | """ 67 | check correctness of methods in util for preprocessing. 
68 | """ 69 | # wash(0) 70 | # wash(1) 71 | # process1() 72 | # process2() 73 | # sort_title(0) 74 | # sort_title(1) 75 | 76 | def process_entity(): 77 | root = 'data/' 78 | i, j = 0, 0 79 | with open(root + 'p_' + mte, 'w+', encoding=encoding) as f0: 80 | with open(root + mte, 'r', encoding=encoding) as f: 81 | for title in f: 82 | f0.write(process(title.strip().split('\t')[1], 0) + '\n') 83 | i += 1 84 | print(i) 85 | with open(root + 'p_' + bte, 'w+', encoding=encoding) as f0: 86 | with open(root + bte, 'r', encoding=encoding) as f: 87 | for title in f: 88 | f0.write(process(title.strip().split('\t')[1], 0) + '\n') 89 | j += 1 90 | print(j) 91 | 92 | 93 | pattern = re.compile('\w') 94 | 95 | 96 | def match(title1, title2, all=False): 97 | if title1 == title2: 98 | return True 99 | len1, len2 = len(title1), len(title2) 100 | if len1 <= len2: 101 | word_num = title1.count(' ') + 1 102 | if word_num > 1 or 1 <= (title2.count(' ') + 1) / word_num <= 2: 103 | result = re.search(re.escape(title1), title2, re.I) 104 | if result: 105 | span = result.span() 106 | left, right = True, True 107 | if span[0] > 0: 108 | if re.match(pattern, title2[span[0] - 1]): 109 | left = False 110 | if span[1] < len2: 111 | if re.match(pattern, title2[span[1]]): 112 | right = False 113 | if left and right: 114 | return True 115 | elif all: 116 | word_num = title2.count(' ')+1 117 | if word_num > 2 or 1 <= (title1.count(' ') + 1) / (title2.count(' ') + 1) <= 2: 118 | result = re.search(re.escape(title2), title1, re.I) 119 | if result: 120 | span = result.span() 121 | left, right = True, True 122 | if span[0] > 0: 123 | if re.match(pattern, title1[span[0]-1]): 124 | left = False 125 | if span[1] < len1: 126 | if re.match(pattern, title1[span[1]]): 127 | right = False 128 | if left and right: 129 | return True 130 | return False 131 | 132 | def ceiling(ptmp, real, tmp, entity): 133 | bcount = ptmp.count(' ') 134 | if ptmp.count(' ') <= 3 and len(tmp) >= 15 + 5 * bcount: 135 | t1 = [] 136 | for t in tmp: 137 | res = re.search(re.escape(ptmp), entity[t], flags=re.I) 138 | if res: 139 | if re.match(' [^A-Za-z,\']| *$', entity[t][res.span()[1]:]): 140 | if res.span()[0] == 0: 141 | t1.append(t) 142 | elif res.span()[0] >= 2 \ 143 | and re.match('[^\w,\']{2}', entity[t][res.span()[0]-2:res.span()[0]]): 144 | t1.append(t) 145 | tmp = t1 146 | # if len(tmp) >= 20 + 5 * bcount: 147 | # t1 = [] 148 | # for t in tmp: 149 | # res = re.search(re.escape(ptmp), entity[t], flags=re.I) 150 | # if res: 151 | # if res.span()[0] == 0: 152 | # if re.match(' *[:\-;&(\[]| *$', entity[t][res.span()[1]:]): 153 | # t1.append(t) 154 | # elif entity[t][res.span()[0]-1] == '(' \ 155 | # and re.match(' ?(,[\w.\-]+ ?\d+)?\)', entity[t][res.span()[1]:]): 156 | # t1.append(t) 157 | # tmp = t1 158 | return tmp 159 | 160 | 161 | def consider_hyphen(title): 162 | result = [title] 163 | if re.search('([^ ])-([^ ])', title): 164 | result.append(re.sub('([^ ])-([^ ])', r'\1\2', title)) 165 | result.append(re.sub('([^ ])-([^ ])', r'\1 \2', title)) 166 | return result 167 | 168 | 169 | def exist_match(ptmp, tmp): 170 | if match(ptmp, tmp): 171 | return True 172 | list1 = consider_hyphen(ptmp) 173 | list2 = consider_hyphen(tmp) 174 | if len(list1) > 1 and len(list2) > 1: 175 | for x in consider_hyphen(ptmp): 176 | for y in consider_hyphen(tmp): 177 | if x == ptmp and y == tmp: 178 | continue 179 | if match(x, y): 180 | return True 181 | return False 182 | 183 | 184 | def join(line, len_m, film_list, film_list_real, len_b, book_list, book_list_real, count, 
num, f_left, f_join): 185 | pair = line.strip().split('\t') 186 | tmp1 = [] 187 | tmp2 = [] 188 | ptmp = trans_sensitive(cut(pair[0])) 189 | for mi in range(len_m): 190 | # if pair[0] in movie or movie in pair[0]: 191 | tmp = film_list[mi] 192 | if tmp == '': 193 | continue 194 | if exist_match(ptmp, tmp): 195 | tmp1.append(mi) 196 | # print([pair[0], tmp]) 197 | tmp1 = ceiling(ptmp, pair[0], tmp1, film_list) 198 | if tmp1: 199 | ptmp = cut(pair[1]) 200 | for bi in range(len_b): 201 | # if pair[1] in book or book in pair[1]: 202 | tmp = book_list[bi] 203 | if tmp == '': 204 | continue 205 | if exist_match(ptmp, tmp): 206 | tmp2.append(bi) 207 | # print(tmp) 208 | # print([pair[1], tmp]) 209 | 210 | tmp2 = ceiling(ptmp, pair[1], tmp2, book_list) 211 | if tmp2: 212 | num += len(tmp1) * len(tmp2) 213 | count += 1 214 | print(pair, len(tmp1), len(tmp2)) 215 | # f0.write('\t'.join(pair)) 216 | # f0.write('\t' + str(len(tmp1)) + '\t' + str(len(tmp2)) + '\n') 217 | 218 | f_left.write(pair[1] + '\t' + pair[0] + '\n') 219 | for t1 in tmp1: 220 | for t2 in tmp2: 221 | f_join.write(book_list_real[t2] + '\t' + film_list_real[t1] + '\n') 222 | return num, count 223 | 224 | 225 | 226 | def full_join(): 227 | root = 'data/' 228 | film_list = [] 229 | film_list_real = [] 230 | book_list = [] 231 | book_list_real = [] 232 | with open(root + 'film_book_left.txt', 'w+', encoding=encoding) as f_left: 233 | with open(root + 'full_join.txt', 'w+', encoding=encoding) as f_join: 234 | with open(root + mte, 'r', encoding=encoding) as f: 235 | for title in f: 236 | film_list_real.append(title.strip()) 237 | # with open(root + 'p_' + mte, 'r', encoding=encoding) as f: 238 | # for title in f: 239 | film_list.append(trans_sensitive(title.strip().split('\t')[1])) 240 | with open(root + bte, 'r', encoding=encoding) as f: 241 | for title in f: 242 | book_list_real.append(title.strip()) 243 | # with open(root + 'p_' + bte, 'r', encoding=encoding) as f: 244 | # for title in f: 245 | book_list.append(trans_sensitive(title.strip().split('\t')[1])) 246 | i = 0 247 | num = 0 248 | count = 0 249 | len_m = len(film_list) 250 | len_b = len(book_list) 251 | with open(root + 'film_book.txt', 'r', encoding=encoding) as f: 252 | for line in f: 253 | num, count = join(line, len_m, film_list, film_list_real, 254 | len_b, book_list, book_list_real, count, num, f_left, f_join) 255 | i += 1 256 | print(i, num) 257 | print('link_num:', num) 258 | print('pair_num:', count) 259 | # with open(root + 'book_set.txt', 'w+', encoding=encoding) as f0: 260 | # with open(root + 'movie_set.txt', 'w+', encoding=encoding) as f1: 261 | # books = [] 262 | # movies = [] 263 | # for l in link: 264 | # if l[0] not in books: 265 | # books.append(l[0]) 266 | # f0.write(l[0]) 267 | # if l[1] not in movies: 268 | # movies.append(l[1]) 269 | # f1.write(l[1]) 270 | 271 | 272 | if __name__ == '__main__': 273 | 274 | # You can preprocess first and output local files to reduce time. 275 | # Codes commented out in full_join() in should be changed. 276 | 277 | # process_entity() 278 | 279 | # You can also change codes commented out in match() and ceiling() for the rule of matching and filtering. 280 | full_join() 281 | 282 | # TODO: use multiprocessing to speed up. 
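# A possible sketch for the TODO above (an assumption, not part of the original script).
# The per-line matching in full_join() is independent across the lines of film_book.txt,
# so the crawled pairs can be split across worker processes. `join_line` is a hypothetical
# top-level wrapper that would run the same logic as join() for a single line and return
# that line's matched (book, film) pairs.
from multiprocessing import Pool


def parallel_join(lines, join_line, n_workers=4):
    # Match every crawled film-book line in a pool of workers, then flatten
    # the per-line match lists in the parent process.
    with Pool(processes=n_workers) as pool:
        per_line_matches = pool.map(join_line, lines)
    return [pair for matches in per_line_matches for pair in matches]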
283 | -------------------------------------------------------------------------------- /Train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import time 4 | import argparse 5 | import numpy as np 6 | import tensorflow as tf 7 | import MIFN as MIFN 8 | import pickle as pk 9 | from scipy import sparse 10 | import gc 11 | 12 | random.seed(331) 13 | np.random.seed(331) 14 | 15 | 16 | def load_batch(block_data, batch_size, pad_int, args): 17 | random.shuffle(block_data) 18 | for batch_i in range(0, len(block_data) // batch_size + 1): 19 | start_i = batch_i * batch_size 20 | batch = block_data[start_i:start_i + batch_size] 21 | yield get_session(batch, pad_int, args) 22 | 23 | 24 | def get_session(batch, pad_int, args): 25 | seq_A, seq_B = [], [] 26 | len_A, len_B = [], [] 27 | len_all = [] 28 | pos_A, pos_B = [], [] 29 | target_A, target_B = [], [] 30 | index1, index2 = [], [] 31 | adj_1, adj_2, adj_3, adj_4, adj_5 = [], [], [], [], [] 32 | neighbors = [] 33 | nei_mask1, nei_mask2 = [], [] 34 | nei_mask_L_A, nei_mask_L_B = [], [] 35 | nei_index1, nei_index2 = [], [] 36 | Isina, Isinb = [], [] 37 | tarinA, tarinB = [], [] 38 | for session in batch: 39 | len_A.append(session[4]) 40 | len_B.append(session[5]) 41 | len_all.append(len(session[8])) 42 | maxlen_A = max(len_A) 43 | maxlen_B = max(len_B) 44 | i = 0 45 | for session in batch: 46 | seq_A.append(session[0] + [pad_int] * (maxlen_A - len_A[i])) 47 | seq_B.append(session[1] + [pad_int] * (maxlen_B - len_B[i])) 48 | pos_A.append(session[2] + [pad_int] * (maxlen_A - len_A[i])) 49 | pos_B.append(session[3] + [pad_int] * (maxlen_B - len_B[i])) 50 | target_A.append(session[6]) 51 | target_B.append(session[7]) 52 | index1.append(session[9] + [pad_int] * (maxlen_A - len_A[i])) 53 | index2.append(session[10] + [pad_int] * (maxlen_B - len_B[i])) 54 | adj_1.append(session[11]) 55 | adj_2.append(session[12]) 56 | adj_3.append(session[13]) 57 | adj_4.append(session[14]) 58 | adj_5.append(session[15]) 59 | 60 | neighbors.append(list(session[16].keys())) 61 | nei_index1.append(session[17]) 62 | nei_index2.append(session[18]) 63 | nei_mask1.append(session[19]) 64 | nei_mask2.append(session[20]) 65 | pad_len = maxlen_A + maxlen_B - len_all[i] 66 | t1 = np.concatenate((session[21], np.zeros((pad_len, args.nei_num))), axis=0) 67 | t2 = np.concatenate((session[22], np.zeros((pad_len, args.nei_num))), axis=0) 68 | nei_mask_L_A.append(t1) 69 | nei_mask_L_B.append(t2) 70 | Isina.append(session[23]) 71 | Isinb.append(session[24]) 72 | tarinA.append(session[25]) 73 | tarinB.append(session[26]) 74 | i += 1 75 | index = np.arange(len(batch)) 76 | index = np.expand_dims(index, axis=-1) 77 | index_p = np.repeat(index, maxlen_A, axis=1) 78 | pos_A = np.stack([index_p, np.array(pos_A)], axis=-1) 79 | index1 = np.stack([index_p, np.array(index1)], axis=-1) 80 | index_p = np.repeat(index, maxlen_B, axis=1) 81 | pos_B = np.stack([index_p, np.array(pos_B)], axis=-1) 82 | index2 = np.stack([index_p, np.array(index2)], axis=-1) 83 | 84 | index_p = np.repeat(index, args.nei_num, axis=1) 85 | nei_index1 = np.stack([index_p, np.array(nei_index1)], axis=-1) 86 | nei_index2 = np.stack([index_p, np.array(nei_index2)], axis=-1) 87 | 88 | return np.array(adj_1, dtype=int), np.array(adj_2, dtype=int), np.array(adj_3, dtype=int), \ 89 | np.array(adj_4, dtype=int), np.array(adj_5, dtype=int), \ 90 | np.array(seq_A), np.array(seq_B), pos_A, pos_B, index1, index2, \ 91 | np.array(len_A), np.array(len_B), 
np.array(target_A), np.array(target_B), \ 92 | np.array(neighbors), np.array(nei_index1), np.array(nei_index2), \ 93 | np.array(Isina), np.array(Isinb), \ 94 | np.array(nei_mask_L_A), np.array(nei_mask_L_B), \ 95 | np.array(nei_mask1), np.array(nei_mask2), np.array(tarinA), np.array(tarinB) 96 | 97 | 98 | def feed_dict(model, data, isTrain): 99 | adj1, adj2, adj3, adj4, adj5 = data[0], data[1], data[2], data[3], data[4], 100 | seq_A, seq_B, pos_A, pos_B, index_A, index_B = data[5], data[6], data[7], data[8], data[9], data[10], 101 | len_A, len_B, target_A, target_B, neighbors = data[11], data[12], data[13], data[14], data[15], 102 | nei_index_A, nei_index_B, = data[16], data[17], 103 | Isina, Isinb = data[18], data[19], 104 | nei_mask_L_a, nei_mask_L_b = data[20], data[21], 105 | nei_maska, nei_maskb = data[22], data[23], 106 | tarinA, tarinB = data[24], data[25], 107 | if isTrain: 108 | feed_dict = {model.seq_A: seq_A, model.seq_B: seq_B, model.pos_A: pos_A, model.pos_B: pos_B, 109 | model.len_A: len_A, model.len_B: len_B, model.target_A: target_A, model.target_B: target_B, 110 | model.tar_in_A: tarinA, model.tar_in_B: tarinB, 111 | model.adj_1: adj1, model.adj_2: adj2, model.adj_3: adj3, model.adj_4: adj4, model.adj_5: adj5, 112 | model.neighbors: neighbors, 113 | model.index_A: index_A, model.index_B: index_B, 114 | model.nei_index_A: nei_index_A, model.nei_index_B: nei_index_B, 115 | model.nei_A_mask: nei_maska, model.nei_B_mask: nei_maskb, 116 | model.nei_L_A_mask: nei_mask_L_a, model.nei_L_T_mask: nei_mask_L_b, 117 | model.IsinnumA: Isina, model.IsinnumB: Isinb, } 118 | else: 119 | feed_dict = {model.seq_A: seq_A, model.seq_B: seq_B, model.pos_A: pos_A, model.pos_B: pos_B, 120 | model.len_A: len_A, model.len_B: len_B, 121 | model.adj_1: adj1, model.adj_2: adj2, model.adj_3: adj3, model.adj_4: adj4, model.adj_5: adj5, 122 | model.neighbors: neighbors, 123 | model.index_A: index_A, model.index_B: index_B, 124 | model.nei_index_A: nei_index_A, model.nei_index_B: nei_index_B, 125 | model.nei_A_mask: nei_maska, model.nei_B_mask: nei_maskb, 126 | model.nei_L_A_mask: nei_mask_L_a, model.nei_L_T_mask: nei_mask_L_b, 127 | model.IsinnumA: Isina, model.IsinnumB: Isinb, } 128 | return feed_dict 129 | 130 | 131 | def train(args): 132 | model = MIFN.AINet_all(num_items_A=29207, num_items_B=34886, 133 | num_entity_A=50273, num_entity_B=82552, num_cate=32, 134 | neighbor_num=args.nei_num, batch_size=args.batch_size, gpu=args.gpu, 135 | hidden_size=args.hidden_size, embedding_size=args.embedding_size, 136 | lr=args.lr, keep_prob=args.keep_prob) 137 | print(time.localtime()) 138 | checkpoint = '' 139 | with tf.Session(graph=model.graph, config=model.config) as sess: 140 | writer = tf.summary.FileWriter('', sess.graph) 141 | saver = tf.train.Saver(max_to_keep=args.epochs) 142 | sess.run(tf.global_variables_initializer()) 143 | for epoch in range(args.epochs): 144 | loss = 0 145 | step = 0 146 | filelist = os.listdir(args.train_path) 147 | for file in filelist: 148 | with open(args.train_path + file, 'rb') as f: 149 | block_data = pk.load(f) 150 | for k, (epoch_data) in enumerate(load_batch(block_data, args.batch_size, args.pad_int, args)): 151 | _, l = sess.run([model.train_op, model.loss], feed_dict(model, epoch_data, isTrain=True)) 152 | loss += l 153 | step += 1 154 | gc.collect() 155 | print('Epoch {}/{} - Training Loss: {:.3f}'.format(epoch + 1, args.epochs, loss / step)) 156 | saver.save(sess, checkpoint, global_step=epoch + 1) 157 | evaluation(model, sess, args.valid_path, epoch, loss, step) 158 
| 159 | 160 | def evaluation(model, sess, validpath, epoch, loss, step): 161 | print(time.localtime()) 162 | validlen = 7650 163 | r5_a, r10_a, r20_a = 0, 0, 0 164 | m5_a, m10_a, m20_a = 0, 0, 0 165 | r5_b, r20_b, r10_b = 0, 0, 0 166 | m5_b, m10_b, m20_b = 0, 0, 0 167 | 168 | filelist = os.listdir(validpath) 169 | for file in filelist: 170 | with open(validpath + file, 'rb') as f: 171 | block_data = pk.load(f) 172 | for _, (epoch_data) in enumerate(load_batch(block_data, args.batch_size, args.pad_int, args)): 173 | pa, pb = sess.run([model.pred_A, model.pred_B], feed_dict(model, epoch_data, isTrain=False)) 174 | target_A, target_B = epoch_data[13], epoch_data[14] 175 | recall, mrr = get_eval(pa, target_A, [5, 10, 20]) 176 | r5_a += recall[0] 177 | m5_a += mrr[0] 178 | r10_a += recall[1] 179 | m10_a += mrr[1] 180 | r20_a += recall[2] 181 | m20_a += mrr[2] 182 | recall, mrr = get_eval(pb, target_B, [5, 10, 20]) 183 | r5_b += recall[0] 184 | m5_b += mrr[0] 185 | r10_b += recall[1] 186 | m10_b += mrr[1] 187 | r20_b += recall[2] 188 | m20_b += mrr[2] 189 | gc.collect() 190 | print('Recall5: {:.5f}; Mrr5: {:.5f}'.format(r5_a / validlen, m5_a / validlen)) 191 | print('Recall10: {:.5f}; Mrr10: {:.5f}'.format(r10_a / validlen, m10_a / validlen)) 192 | print('Recall20: {:.5f}; Mrr20: {:.5f}'.format(r20_a / validlen, m20_a / validlen)) 193 | print('Recall5: {:.5f}; Mrr5: {:.5f}'.format(r5_b / validlen, m5_b / validlen)) 194 | print('Recall10: {:.5f}; Mrr10: {:.5f}'.format(r10_b / validlen, m10_b / validlen)) 195 | print('Recall20: {:.5f}; Mrr20: {:.5f}'.format(r20_b / validlen, m20_b / validlen)) 196 | print(time.localtime()) 197 | 198 | with open('', 'a+') as f: 199 | f.write('epoch: ' + str(epoch + 1) + '\t' + str(loss / step) + '\n') 200 | f.write('recall-A @5|10|20: ' + str(r5_a / validlen) + '\t' + str(r10_a / validlen) + '\t' + str( 201 | r20_a / validlen) + '\t') 202 | f.write('mrr-A @5|10|20: ' + str(m5_a / validlen) + '\t' + str(m10_a / validlen) + '\t' + str( 203 | m20_a / validlen) + '\n') 204 | f.write('recall-B @5|10|20: ' + str(r5_b / validlen) + '\t' + str(r10_b / validlen) + '\t' + str( 205 | r20_b / validlen) + '\t') 206 | f.write('mrr-B @5|10|20: ' + str(m5_b / validlen) + '\t' + str(m10_b / validlen) + '\t' + str( 207 | m20_b / validlen) + '\n') 208 | 209 | 210 | def get_eval(predlist, truelist, klist): # return recall@k and mrr@k 211 | recall = [] 212 | mrr = [] 213 | predlist = predlist.argsort() 214 | for k in klist: 215 | recall.append(0) 216 | mrr.append(0) 217 | templist = predlist[:, -k:] # the result of argsort is in ascending 218 | i = 0 219 | while i < len(truelist): 220 | pos = np.argwhere(templist[i] == truelist[i]) # pos is a list of positions whose values are all truelist[i] 221 | if len(pos) > 0: 222 | recall[-1] += 1 223 | mrr[-1] += 1 / (k - pos[0][0]) 224 | else: 225 | recall[-1] += 0 226 | mrr[-1] += 0 227 | i += 1 228 | return recall, mrr # they are sum instead of mean 229 | 230 | if __name__ == '__main__': 231 | parser = argparse.ArgumentParser() 232 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 233 | parser.add_argument('--batch_size', type=int, default=32, help='batch size') 234 | parser.add_argument('--epochs', type=int, default=100, help='the number of epochs') 235 | parser.add_argument('--keep_prob', type=float, default=0.8, help='keep prob of hidden unit') 236 | parser.add_argument('--pad_int', type=int, default=0, 237 | help='padding on the session') 238 | parser.add_argument('--train_path', type=str, default='', 239 | 
help='train data path') 240 | parser.add_argument('--valid_path', type=str, default='', 241 | help='valid data path') 242 | parser.add_argument('--test_path', type=str, default='', 243 | help='test data path') 244 | parser.add_argument('--n_iter', type=int, default=1, help='the number of h hop') 245 | parser.add_argument('--embedding_size', type=int, default=256, help='embedding size of vector') 246 | parser.add_argument('--hidden_size', type=int, default=256, help='hidden size') 247 | parser.add_argument('--num_layers', type=int, default=1, help='num_layers') 248 | parser.add_argument('--nei_num', type=int, default=200, help='num_neighbours') 249 | parser.add_argument('--gpu', type=str, default='3', help='use of gpu') 250 | args = parser.parse_args() 251 | 252 | train(args) -------------------------------------------------------------------------------- /extract_kg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | import pickle as pk 5 | from scipy import sparse 6 | import collections 7 | import linecache 8 | 9 | def read_all_responding_index(): 10 | 11 | with open(ent2ind_a_p, 'r') as f: 12 | for line in f.readlines(): 13 | line = line.strip().split('\t') 14 | entity_id2index_A[int(line[0])] = int(line[1]) #[oldid, index] 15 | with open(ent2ind_b_p, 'r') as f: 16 | for line in f.readlines(): 17 | line = line.strip().split('\t') 18 | entity_id2index_B[int(line[0])] = int(line[1]) 19 | 20 | with open(item2ind_a_p, 'r') as f: 21 | for line in f.readlines(): 22 | line = line.strip().split('\t') 23 | item_index_old2new_A[int(line[0])] = int(line[1]) 24 | with open(item2ind_b_p, 'r') as f: 25 | for line in f.readlines(): 26 | line = line.strip().split('\t') 27 | item_index_old2new_B[int(line[0])] = int(line[1]) 28 | def getdata(datapath): 29 | # read sequnece data,each line represents a sequence 30 | with open(datapath, 'r') as f: 31 | sessions = [] 32 | for line in f.readlines(): 33 | session = [] 34 | line = line.strip().split('\t') 35 | for item in line[1:]: 36 | item = int(item) 37 | session.append(item) 38 | sessions.append(session) 39 | return sessions 40 | def gen_mask(index_dict,sequence): 41 | # generate nei_mask 42 | nei = list(index_dict.keys()) 43 | ini_A = np.ones(max_Nei) 44 | ini_A[np.where(np.array(nei) > entity_border)] = 0 45 | ini_B = np.ones(max_Nei) 46 | ini_B[np.where(np.array(nei) <= entity_border)] = 0 47 | temp_a, temp_t = [], [] 48 | for item in sequence: 49 | if item <= entity_border: 50 | temp_a.append(ini_A) 51 | temp_t.append(ini_B) 52 | else: 53 | temp_a.append(ini_B) 54 | temp_t.append(ini_A) 55 | return ini_A, ini_B, temp_a,temp_t 56 | def gen_index(index_dict): 57 | itemB_border = 85159 58 | itemAset = set(range(item_border+1)) 59 | itemBset = set(range(entity_border+1, itemB_border)) 60 | IsinnumA,IsinnumB = [],[] 61 | for ii in index_dict.keys(): 62 | if ii in itemAset: 63 | IsinnumA.append(1) 64 | IsinnumB.append(0) 65 | elif ii in itemBset: 66 | IsinnumA.append(0) 67 | IsinnumB.append(1) 68 | else: 69 | IsinnumA.append(0) 70 | IsinnumB.append(0) 71 | nei_index_A = np.zeros(max_Nei) 72 | ind1 = np.where(np.array(list(index_dict.keys())) <= item_border)[0] 73 | ent_id = np.array(list(index_dict.keys()))[ind1] 74 | ite_id = [] 75 | for ii in ent_id: 76 | ite_id.append(ii) 77 | nei_index_A[ind1] = ite_id 78 | 79 | nei_index_B = np.zeros(max_Nei) 80 | ind1 = np.where((np.array(list(index_dict.keys())) > entity_border) & 81 | (np.array(list(index_dict.keys())) <= 
itemB_border))[0] 82 | ent_id = np.array(list(index_dict.keys()))[ind1] 83 | ite_id = [] 84 | for ii in ent_id: 85 | ite_id.append(ii-entity_border) 86 | nei_index_B[ind1] = ite_id 87 | 88 | return nei_index_A, nei_index_B, IsinnumA, IsinnumB 89 | 90 | def processdata(dataset,filepath): 91 | sessions = [] 92 | all_count = 0 93 | for session in dataset: # 1->A,2->B 94 | all_count += 1 95 | temp = [] 96 | seq1,seq2 = [],[] 97 | pos1,pos2 = [],[] 98 | index1,index2 = [],[] 99 | len1,len2 = 0,0 100 | ind = 0 101 | sequence,allseq = [],[] 102 | tar_in_A, tar_in_B = [], [] 103 | for item in session[:-2]: # get all the items except targetA and targetB 104 | if item <= item_border: # domain A 105 | item = item_index_old2new_A[item] 106 | seq1.append(item) 107 | pos1.append(len2) 108 | len1 += 1 109 | index1.append(ind) 110 | else: # domain B 111 | item = item_index_old2new_B[item-item_border-1] + len(entity_id2index_A) 112 | seq2.append(item) 113 | pos2.append(len1) 114 | len2 += 1 115 | index2.append(ind) 116 | sequence.append(item) 117 | allseq.append(item) 118 | ind += 1 119 | tarA = item_index_old2new_A[session[-2]] 120 | tarB = item_index_old2new_B[session[-1]-item_border-1] 121 | 122 | adj_1, adj_2, adj_3, adj_4, adj_5, index_dict = get_subgraph3(allseq) 123 | 124 | if tarA in index_dict.keys(): 125 | tar_in_A.append(1) 126 | else: 127 | tar_in_A.append(0) 128 | if tarB in index_dict.keys(): 129 | tar_in_B.append(1) 130 | else: 131 | tar_in_B.append(0) 132 | nei_mask_A, nei_mask_B,nei_mask_L_A,nei_mask_L_T = gen_mask(index_dict,sequence) 133 | nei_indexA, nei_indexB,IsinnumA, IsinnumB = gen_index(index_dict) 134 | temp.append(seq1) 135 | temp.append(seq2) 136 | temp.append(pos1) 137 | temp.append(pos2) 138 | temp.append(len1) 139 | temp.append(len2) 140 | temp.append(tarA) 141 | temp.append(tarB) 142 | temp.append(seq1 + seq2) 143 | temp.append(index1) 144 | temp.append(index2) 145 | temp.append(adj_1.toarray()) 146 | temp.append(adj_2.toarray()) 147 | temp.append(adj_3.toarray()) 148 | temp.append(adj_4.toarray()) 149 | temp.append(adj_5.toarray()) 150 | temp.append(index_dict) 151 | temp.append(nei_indexA) 152 | temp.append(nei_indexB) 153 | temp.append(nei_mask_A) 154 | temp.append(nei_mask_B) 155 | temp.append(nei_mask_L_A) 156 | temp.append(nei_mask_L_T) 157 | temp.append(IsinnumA) 158 | temp.append(IsinnumB) 159 | temp.append(tar_in_A) 160 | temp.append(tar_in_B) 161 | pk.dump(temp, open(filepath + str(all_count) + '.pkl', 'wb+'), -1) 162 | sessions.append(temp) 163 | 164 | return sessions 165 | 166 | def initial_adj(): 167 | adj_1 = sparse.dok_matrix((max_Nei, max_Nei), dtype=np.float32) 168 | adj_2 = sparse.dok_matrix((max_Nei, max_Nei), dtype=np.float32) 169 | adj_3 = sparse.dok_matrix((max_Nei, max_Nei), dtype=np.float32) 170 | adj_4 = sparse.dok_matrix((max_Nei, max_Nei), dtype=np.float32) 171 | adj_5 = sparse.dok_matrix((max_Nei, max_Nei), dtype=np.float32) 172 | for i in range(max_Nei): 173 | adj_1[i, i] = 1 174 | adj_2[i, i] = 1 175 | adj_3[i, i] = 1 176 | adj_4[i, i] = 1 177 | adj_5[i, i] = 1 178 | return adj_1,adj_2,adj_3,adj_4,adj_5, 179 | 180 | def get_subgraph3(session): 181 | 182 | node = np.unique(session) 183 | node_a = [] 184 | node_b = [] 185 | for i in node: 186 | if i <= entity_border: 187 | node_a.append(i) 188 | else: 189 | node_b.append(i) 190 | 191 | adj_1, adj_2, adj_3, adj_4, adj_5, = initial_adj() 192 | a_set = set(node_a) 193 | b_set = set(node_b) 194 | Nei_a_dic, Nei_b_dic = {},{} 195 | nei_dic_a, nei_dic_b = {},{} 196 | bNei_a_dic, bNei_b_dic = {},{} 197 | 
temp_node_a = node_a 198 | temp_node_b = node_b 199 | for hop in range(max_H): 200 | Nei_a_dic[hop],Nei_b_dic[hop] = {},{} 201 | nei_dic_a[hop],nei_dic_b[hop] = set(), set() 202 | bNei_a_dic[hop], bNei_b_dic[hop] = {}, {} 203 | nei_set_a, nei_set_b = extract_h_graph(temp_node_a,temp_node_b,hop, 204 | nei_dic_a,nei_dic_b, 205 | Nei_a_dic,Nei_b_dic, 206 | bNei_a_dic, bNei_b_dic) # within 1-hop neighbor 207 | temp_node_a = nei_set_a - a_set - b_set 208 | temp_node_b = nei_set_b - b_set - a_set 209 | connected,X,Y = IsConnect3(nei_set_a, nei_set_b, node_a, node_b) 210 | a_set.update(nei_set_a) 211 | b_set.update(nei_set_b) 212 | if connected or hop == max_H-1: 213 | new_node_a, new_node_b, \ 214 | new_dict_a, new_dict_b = pruning_nei(node_a,node_b,a_set,b_set, 215 | nei_dic_a, nei_dic_b,Nei_a_dic, Nei_b_dic, 216 | bNei_a_dic, bNei_b_dic, hop, X,Y) 217 | index_dict = build_adj(session, new_node_a, new_node_b, new_dict_a, new_dict_b, 218 | adj_1, adj_2, adj_3, adj_4, adj_5,) 219 | return adj_1, adj_2, adj_3, adj_4, adj_5, index_dict 220 | 221 | def build_adj(session,a_set,b_set, new_dict_a, new_dict_b, 222 | adj_1, adj_2, adj_3, adj_4, adj_5,): 223 | 224 | a_set = a_set.union(b_set) 225 | if len(a_set) < max_Nei: 226 | pool = set(kg.keys()).difference(a_set) 227 | rest = list(random.sample(list(pool), max_Nei - len(a_set))) 228 | a_set.update(rest) 229 | index_dict = dict(zip(list(a_set)[:max_Nei], range(max_Nei))) 230 | else: 231 | index_dict = dict(zip(list(a_set)[:max_Nei], range(max_Nei))) 232 | all_dic = {} 233 | all_dic.update(new_dict_a) 234 | all_dic.update(new_dict_b) 235 | for k in all_dic.keys(): 236 | if k not in index_dict.keys(): 237 | break 238 | else: 239 | u = index_dict[k] 240 | for it in all_dic[k]: 241 | if it[0] not in index_dict.keys(): 242 | break 243 | else: 244 | v = index_dict[it[0]] 245 | if it[1] == 0: # also-buy 246 | adj_1[u, v] = 1 247 | elif it[1] == 1: # also-view 248 | adj_2[u, v] = 1 249 | elif it[1] == 2: # buy-after-view 250 | adj_3[u, v] = 1 251 | elif it[1] == 3: # buy-together 252 | adj_4[u, v] = 1 253 | elif it[1] == 4: # category 254 | adj_5[u, v] = 1 255 | 256 | for ind in range(len(session)): 257 | item = session[ind] 258 | if item not in index_dict.keys(): 259 | continue 260 | else: 261 | u = index_dict[item] 262 | if ind < len(session)-1: 263 | next_item = session[ind+1] 264 | v = index_dict[next_item] 265 | adj_1[u, v] = 1 266 | adj_2[u, v] = 1 267 | adj_3[u, v] = 1 268 | adj_4[u, v] = 1 269 | adj_5[u, v] = 1 270 | else: 271 | break 272 | 273 | return index_dict 274 | 275 | def extract_h_graph(temp_node_a, temp_node_b, hop, nei_dic_a, nei_dic_b, Nei_a_dic, Nei_b_dic, 276 | bNei_a_dic, bNei_b_dic): 277 | a_set = set(temp_node_a) 278 | b_set = set(temp_node_b) 279 | orig_set = a_set.union(b_set) 280 | all_count = len(orig_set) 281 | 282 | for item in a_set: 283 | if item in kg.keys(): 284 | neighbors = kg[item] 285 | else: 286 | neighbors = [] 287 | if item not in Nei_a_dic[hop].keys() and neighbors: 288 | Nei_a_dic[hop][item] = [] 289 | Nei_a_dic[hop][item] += neighbors 290 | for pair in neighbors: 291 | if pair[0] not in bNei_a_dic[hop].keys(): 292 | bNei_a_dic[hop][pair[0]] = [] 293 | bNei_a_dic[hop][pair[0]].append((item,pair[1])) 294 | nei_set_a = set() 295 | for nn in a_set: 296 | if nn in Nei_a_dic[hop].keys(): 297 | li = set([nei_pair[0] for nei_pair in Nei_a_dic[hop][nn]]) 298 | nei_dic_a[hop].update(li) 299 | nei_set_a.update(li) 300 | temp_a_count = len(nei_set_a) 301 | 302 | for item in b_set: 303 | if item in kg.keys(): 304 | neighbors = 
kg[item] 305 | else: 306 | neighbors = [] 307 | if item not in Nei_b_dic[hop].keys() and neighbors: 308 | Nei_b_dic[hop][item] = [] 309 | Nei_b_dic[hop][item] += neighbors 310 | for pair in neighbors: 311 | if pair[0] not in bNei_b_dic[hop].keys(): 312 | bNei_b_dic[hop][pair[0]] = [] 313 | bNei_b_dic[hop][pair[0]].append((item,pair[1])) 314 | nei_set_b = set() 315 | for nn in b_set: 316 | if nn in Nei_b_dic[hop].keys(): 317 | li = set([nei_pair[0] for nei_pair in Nei_b_dic[hop][nn]]) 318 | nei_dic_b[hop].update(li) 319 | nei_set_b.update(li) 320 | temp_b_count = len(nei_set_b) 321 | 322 | return nei_set_a, nei_set_b 323 | 324 | def pruning_nei(node_a,node_b,a_set,b_set,nei_dic_a, nei_dic_b,Nei_a_dic, Nei_b_dic, 325 | bNei_a_dic, bNei_b_dic, hop,x,y): 326 | 327 | orig_a = set(node_a) 328 | orig_b = set(node_b) 329 | orig_all = len(orig_a) + len(orig_b) 330 | current_set = a_set.union(b_set) 331 | current_num = len(current_set) 332 | nei_set_a = a_set - orig_a 333 | nei_set_b = b_set - orig_b 334 | temp_a_count = len(nei_set_a) 335 | temp_b_count = len(nei_set_b) 336 | if current_num < max_Nei: # kg-extra-sample 337 | rest_count = max_Nei - current_num 338 | if temp_a_count > 0 and temp_b_count > 0: 339 | sam_a_count = int(np.ceil(rest_count * (len(a_set) / current_num))) 340 | sam_b_count = rest_count - sam_a_count 341 | else: 342 | sam_a_count = int(np.ceil(rest_count * (len(orig_a) / orig_all))) 343 | sam_b_count = rest_count - sam_a_count 344 | return sample_kg(sam_a_count,sam_b_count,a_set,b_set,hop,Nei_a_dic, Nei_b_dic) 345 | elif current_num == max_Nei: # self-sample 346 | return sample_self(a_set,b_set,hop,Nei_a_dic, Nei_b_dic) 347 | else: 348 | if x: # a-->b connected 349 | find_path(x, hop, orig_a, node_a, bNei_a_dic,temp_a_count) 350 | if y: # b-->a connected 351 | find_path(y, hop, orig_b, node_b, bNei_b_dic,temp_b_count) 352 | rest_count = max_Nei - len(orig_a.union(orig_b)) 353 | sam_a_count = int(np.ceil(rest_count * (len(a_set) / current_num))) 354 | sam_b_count = rest_count - sam_a_count 355 | if temp_a_count < sam_a_count: # a不够 356 | delta = sam_a_count - temp_a_count 357 | sam_a_count = temp_a_count 358 | sam_b_count = sam_b_count + delta 359 | return sample_frequency(node_a,node_b,sam_a_count,sam_b_count,a_set,b_set,nei_dic_a, nei_dic_b,hop,Nei_a_dic, Nei_b_dic,1,bNei_a_dic, bNei_b_dic) 360 | elif temp_b_count < sam_b_count: # b不够 361 | delta = sam_b_count - temp_b_count 362 | sam_b_count = temp_b_count 363 | sam_a_count = sam_a_count + delta 364 | return sample_frequency(node_a,node_b,sam_a_count,sam_b_count,a_set,b_set,nei_dic_a, nei_dic_b,hop,Nei_a_dic, Nei_b_dic,2,bNei_a_dic, bNei_b_dic) 365 | else: 366 | return sample_frequency(node_a,node_b,sam_a_count,sam_b_count,a_set,b_set,nei_dic_a, nei_dic_b,hop,Nei_a_dic, Nei_b_dic,3,bNei_a_dic, bNei_b_dic) 367 | 368 | def find_path(Z,hop,origset,orignode,bNeidict,tempcount): 369 | pathnode = set() 370 | pathnode.update(Z) 371 | for i in range(hop+1): 372 | j = hop - i 373 | for item in Z: 374 | if item in bNeidict[j].keys(): 375 | backnode = set([pair[0] for pair in bNeidict[j][item]]) 376 | pathnode.update(backnode) 377 | Z = pathnode - Z 378 | mid = pathnode-Z 379 | origset.update(mid) 380 | for i in mid: 381 | orignode.append(i) 382 | tempcount -= 1 383 | 384 | def sample_kg(sam_a_count,sam_b_count,a_set,b_set,hop,Nei_a_dic, Nei_b_dic): 385 | new_dict_a, new_dict_b = {},{} 386 | c = a_set.union(b_set) 387 | pool = set(kg.keys()) - c 388 | extra_a_list = np.random.choice(list(pool), size=sam_a_count, replace=False) 389 | 
new_node_a = a_set.union(set(extra_a_list)) 390 | pool = pool - set(extra_a_list) 391 | extra_dic_a = {} 392 | if not Nei_a_dic[0].keys(): 393 | ind = np.random.choice(list(Nei_b_dic[0].keys()), size=1, replace=False)[0] 394 | else: 395 | ind = np.random.choice(list(Nei_a_dic[0].keys()), size=1, replace=False)[0] 396 | extra_dic_a[ind] = [] 397 | for item in extra_a_list: 398 | extra_dic_a[ind] += [(item, -1)] 399 | for h in range(hop+1): 400 | for i in Nei_a_dic[h].keys(): 401 | if i not in new_dict_a.keys(): 402 | new_dict_a[i] = [] 403 | new_dict_a[i] += Nei_a_dic[h][i] 404 | if ind not in new_dict_a.keys(): 405 | new_dict_a[ind] = [] 406 | new_dict_a[ind] += extra_dic_a[ind] 407 | 408 | extra_b_list = np.random.choice(list(pool), size=sam_b_count, replace=False) 409 | new_node_b = b_set.union(extra_b_list) 410 | extra_dic_b = {} 411 | if not Nei_b_dic[0].keys(): 412 | ind = np.random.choice(list(Nei_a_dic[0].keys()), size=1, replace=False)[0] 413 | else: 414 | ind = np.random.choice(list(Nei_b_dic[0].keys()), size=1, replace=False)[0] 415 | extra_dic_b[ind] = [] 416 | for item in extra_b_list: 417 | extra_dic_b[ind] += [(item, -1)] 418 | for h in range(hop+1): 419 | for i in Nei_b_dic[h].keys(): 420 | if i not in new_dict_b.keys(): 421 | new_dict_b[i] = [] 422 | new_dict_b[i] += Nei_b_dic[h][i] 423 | if ind not in new_dict_b.keys(): 424 | new_dict_b[ind] = [] 425 | new_dict_b[ind] += extra_dic_b[ind] 426 | return new_node_a, new_node_b, new_dict_a, new_dict_b 427 | 428 | def sample_self(a_set,b_set,hop,Nei_a_dic, Nei_b_dic): 429 | new_dict_a, new_dict_b = {}, {} 430 | for h in range(hop+1): 431 | for i in Nei_a_dic[h].keys(): 432 | if i not in new_dict_a.keys(): 433 | new_dict_a[i] = [] 434 | new_dict_a[i] += Nei_a_dic[h][i] 435 | for h in range(hop+1): 436 | for i in Nei_b_dic[h].keys(): 437 | if i not in new_dict_b.keys(): 438 | new_dict_b[i] = [] 439 | new_dict_b[i] += Nei_b_dic[h][i] 440 | return a_set,b_set, new_dict_a, new_dict_b 441 | 442 | def sample_frequency(node_a,node_b,sam_a_count,sam_b_count,a_set,b_set, 443 | nei_dic_a,nei_dic_b,hop,Nei_a_dic, Nei_b_dic,k, 444 | bNei_a_dic, bNei_b_dic): 445 | new_node_a, new_node_b = set(), set() 446 | new_dict_a, new_dict_b = {}, {} 447 | if k == 1: 448 | new_node_a = a_set 449 | new_dict_a = {} 450 | for h in range(hop+1): 451 | for i in Nei_a_dic[h].keys(): 452 | if i not in new_dict_a.keys(): 453 | new_dict_a[i] = [] 454 | new_dict_a[i] += Nei_a_dic[h][i] 455 | new_node_b = set(node_b) 456 | new_dict_b = {} 457 | for h in range(hop+1): 458 | temp_b = nei_dic_b[h]-(new_node_a-set(node_a)) 459 | temp_b_rel = Nei_b_dic[h] 460 | need_num = sam_b_count - len(new_node_b - set(node_b) - set(new_node_a)) 461 | if len(temp_b) <= need_num: 462 | new_node_b.update(temp_b) 463 | for item in temp_b_rel.keys(): 464 | if item not in new_dict_b.keys(): 465 | new_dict_b[item] = [] 466 | new_dict_b[item] += temp_b_rel[item] 467 | else: 468 | c = new_node_b.union(new_node_a) 469 | temp_fre = read_embsim_temp(temp_b, need_num, c, bNei_b_dic, h) 470 | new_node_b.update(temp_fre) 471 | new_dict_b = add_rel(temp_fre, temp_b_rel, new_dict_b) # 增加对应的rel 472 | break 473 | 474 | if k == 2: # a采样,b不变 475 | new_node_b = b_set 476 | new_dict_b = {} 477 | for h in range(hop+1): # 把每一层的关系都加入 478 | for i in Nei_b_dic[h].keys(): 479 | if i not in new_dict_b.keys(): 480 | new_dict_b[i] = [] 481 | new_dict_b[i] += Nei_b_dic[h][i] 482 | new_node_a = set(node_a) #具有原始的node-a 483 | new_dict_a = {} 484 | for h in range(hop+1): 485 | temp_a = 
nei_dic_a[h]-(new_node_b-set(node_b)) # get the neighbors of the current hop 486 | temp_a_rel = Nei_a_dic[h] # get the neighbor relations of the current hop 487 | need_num = sam_a_count - len(new_node_a - set(node_a) - set(new_node_b)) 488 | if len(temp_a) <= need_num: # add all neighbors of this hop 489 | new_node_a.update(temp_a) 490 | for item in temp_a_rel.keys(): 491 | if item not in new_dict_a.keys(): 492 | new_dict_a[item] = [] 493 | new_dict_a[item] += temp_a_rel[item] 494 | else: # stop at this hop and sample from it by frequency 495 | c = new_node_a.union(set(node_b)) 496 | # temp_fre = read_fre_temp(temp_a,need_num,c) 497 | temp_fre = read_embsim_temp(temp_a, need_num, c, bNei_a_dic, h) 498 | # print('top-k-a:', temp_fre[:need_num]) 499 | # x = set([ent[0] for ent in temp_fre[:need_num]]) 500 | new_node_a.update(temp_fre) 501 | new_dict_a = add_rel(temp_fre, temp_a_rel, new_dict_a) # add the corresponding relations 502 | # if len(new_node_a - set(node_a) - new_node_b) < need_num: 503 | # print('having repeat sample in node-a!!!') 504 | # e_num = need_num - len(new_node_a - set(node_a) - new_node_b) 505 | # x = set([ent[0] for ent in temp_fre[need_num:need_num + e_num]]) 506 | # new_node_a.update(x) 507 | # new_dict_a = add_rel(x, temp_a_rel, new_dict_a) # add the corresponding relations 508 | # # need_num += e_num 509 | break 510 | 511 | if k == 3: # sample for both a and b 512 | ########### sample for a ################### 513 | print('k3--sample--a') 514 | print('for A sampling............') 515 | # print(node_a) 516 | new_node_a = set(node_a) # contains the original node-a plus the mid-nodes on the paths 517 | new_dict_a = {} 518 | for h in range(hop+1): 519 | temp_a = nei_dic_a[h] # get the neighbors of the current hop 520 | temp_a_rel = Nei_a_dic[h] # get the neighbor relations of the current hop 521 | need_num = sam_a_count - len(new_node_a-set(node_a)-set(node_b)) 522 | print('need-num:',need_num) 523 | print(len(temp_a)) 524 | if len(temp_a) <= need_num: # add all neighbors of this hop 525 | new_node_a.update(temp_a) 526 | for item in temp_a_rel.keys(): 527 | if item not in new_dict_a.keys(): 528 | new_dict_a[item] = [] 529 | new_dict_a[item] += temp_a_rel[item] 530 | else: # stop at this hop and sample from it by frequency 531 | c = new_node_a.union(set(node_b)) 532 | # temp_fre = read_fre_temp(temp_a,need_num,c) # sample by frequency 533 | temp_fre = read_embsim_temp(temp_a, need_num, c,bNei_a_dic,h) # sample by embedding similarity 534 | print('top-k-a:',len(temp_fre)) 535 | new_node_a.update(temp_fre) 536 | new_dict_a = add_rel(temp_fre, temp_a_rel, new_dict_a) # add the corresponding relations 537 | break 538 | ########### sample for b ################### 539 | # print('k3--sample--b') 540 | print('for B sampling............') 541 | # print('new-node-a:', new_node_a) 542 | # print(node_b) 543 | new_node_b = set(node_b) # contains the original node-b 544 | new_dict_b = {} 545 | for h in range(hop+1): 546 | temp_b = nei_dic_b[h]-(new_node_a-set(node_a)) # get the neighbors of the current hop 547 | temp_b_rel = Nei_b_dic[h] # get the neighbor relations of the current hop 548 | need_num = sam_b_count - len(new_node_b - set(node_b) - set(new_node_a)) 549 | print('need-num:', need_num) 550 | print(len(temp_b)) 551 | if len(temp_b) <= need_num: # add all neighbors of this hop 552 | new_node_b.update(temp_b) 553 | for item in temp_b_rel.keys(): 554 | if item not in new_dict_b.keys(): 555 | new_dict_b[item] = [] 556 | new_dict_b[item] += temp_b_rel[item] 557 | else: # stop at this hop and sample from it by frequency 558 | c = new_node_b.union(new_node_a) 559 | # temp_fre = read_fre_temp(temp_b,need_num,c) 560 | temp_fre = read_embsim_temp(temp_b, need_num, c, bNei_b_dic, h) 561 | print('top-k-b:', len(temp_fre)) 562 | # x = set([ent[0] for ent in temp_fre[:need_num]]) 563 | new_node_b.update(temp_fre) 564 | new_dict_b = add_rel(temp_fre, temp_b_rel, new_dict_b) # add the corresponding relations 565 | # if len(new_node_b - set(node_b) - new_node_a) < need_num: 566 | # 
print('having repeat sample in node-b!!!') 567 | # e_num = need_num - len(new_node_b - set(node_b) - new_node_a) 568 | # x = set([ent[0] for ent in temp_fre[need_num:need_num+e_num]]) 569 | # new_node_b.update(x) 570 | # new_dict_b = add_rel(x, temp_b_rel, new_dict_b) # 增加对应的rel 571 | # # need_num += e_num 572 | # print('new-node-b:',new_node_b) 573 | break 574 | return new_node_a, new_node_b, new_dict_a, new_dict_b 575 | 576 | def read_embsim_temp(nei_set_temp,need_num,nodeset,bNeidict,hop): 577 | score = [] 578 | node_temp = [] 579 | for item in nei_set_temp: 580 | if item in bNeidict[hop].keys(): 581 | backnode = set([pair[0] for pair in bNeidict[hop][item]]) 582 | for s in cal_sim_p(item, backnode): 583 | score.append(s) 584 | for p in backnode: 585 | node_temp.append(item) 586 | ind = np.argsort(score).tolist() 587 | ind.reverse() 588 | a = [node_temp[ii] for ii in ind] 589 | get_ent_fre = set(a) 590 | fre_set = set() 591 | count = 0 592 | for item in get_ent_fre: 593 | if count < need_num: 594 | if item not in nodeset: 595 | fre_set.add(item) 596 | count += 1 597 | else: 598 | break 599 | return fre_set 600 | 601 | def cal_sim_p(item,backnode): 602 | item_vec = get_line_context(pretrain_emb_path,item+1)[1:] 603 | backnode_vec = [get_line_context(pretrain_emb_path,i+1)[1:] for i in backnode] 604 | score = [] 605 | for vec in backnode_vec: 606 | vec_score = cos_sim(item_vec,vec) 607 | score.append(vec_score) 608 | return score 609 | 610 | def get_line_context(file_path, line_number): 611 | line = linecache.getline(file_path, line_number).strip().split(' ') 612 | fltline = list(map(float, line)) 613 | return np.array(fltline) 614 | 615 | def add_rel(node, rel_dic,new_dic): 616 | for ii in rel_dic.keys(): 617 | if not rel_dic[ii]: 618 | continue 619 | nodeli = list(np.array(rel_dic[ii])[:, 0]) 620 | relli = list(np.array(rel_dic[ii])[:, 1]) 621 | for jj in range(len(nodeli)): 622 | if nodeli[jj] in node: 623 | r = relli[jj] 624 | if ii not in new_dic.keys(): 625 | new_dic[ii] = [] 626 | new_dic[ii] += [(nodeli[jj],r)] 627 | return new_dic 628 | 629 | def IsConnect3(nei_set_a, nei_set_b, node_a, node_b): 630 | orig_a = set(node_a) 631 | orig_b = set(node_b) 632 | flag = False 633 | x = nei_set_a.intersection(orig_b) 634 | y = nei_set_b.intersection(orig_a) 635 | if x or y: 636 | flag = True 637 | return flag,x,y 638 | 639 | def load_kg(): 640 | if os.path.exists(kg_file + '.npy'): 641 | kg_np = np.load(kg_file + '.npy') 642 | else: 643 | kg_np = np.loadtxt(kg_file + '.txt', dtype=np.int64, delimiter='\t') 644 | np.save(kg_file + '.npy', kg_np) 645 | kg = collections.defaultdict(list) 646 | for head, relation, tail in kg_np: 647 | kg[head].append((tail, relation)) 648 | return kg 649 | 650 | def cos_sim(vector_a, vector_b): 651 | vector_a = np.mat(vector_a) 652 | vector_b = np.mat(vector_b) 653 | num = float(vector_a * vector_b.T) 654 | denom = np.linalg.norm(vector_a) * np.linalg.norm(vector_b) 655 | cos = num / denom 656 | sim = 0.5 + 0.5 * cos 657 | return sim 658 | 659 | if __name__ == '__main__': 660 | train_path = '' 661 | valid_path = '' 662 | test_path = '' 663 | ent2ind_a_p = '' 664 | ent2ind_b_p = '' 665 | item2ind_a_p = '' 666 | item2ind_b_p = '' 667 | entity_id2index_A = dict() 668 | entity_id2index_B = dict() 669 | item_index_old2new_A = dict() 670 | item_index_old2new_B = dict() 671 | read_all_responding_index() 672 | entity_border = len(entity_id2index_A) - 1 673 | item_border = len(item_index_old2new_A) - 1 674 | print('entity-border:', entity_border) 675 | 
print('item-border:', item_border) 676 | 677 | kg_file = '' 678 | train_file = '' 679 | valid_file = '' 680 | test_file = '' 681 | max_H = 2 682 | max_Nei = 200 683 | 684 | kg = load_kg() 685 | print('kg load done...') 686 | 687 | pretrain_emb_path = '' 688 | 689 | traindata = getdata(train_path) 690 | processdata(traindata, train_file) 691 | 692 | -------------------------------------------------------------------------------- /MIFN.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import math 4 | from keras.utils.np_utils import to_categorical 5 | from scipy import sparse 6 | 7 | 8 | class FilterCell(tf.contrib.rnn.RNNCell): 9 | def __init__(self, num_units, activation=None, reuse=None, kernel_initializer=None, 10 | bias_initializer=None): 11 | super(FilterCell, self).__init__(_reuse=reuse) 12 | self._num_units = num_units 13 | self._activation = activation or tf.tanh 14 | self._kernel_initializer = kernel_initializer 15 | self._bias_initializer = bias_initializer 16 | 17 | @property 18 | def state_size(self): 19 | return self._num_units 20 | 21 | @property 22 | def output_size(self): 23 | return self._num_units 24 | 25 | def call(self, inputs, state): 26 | inputs_A, inputs_T = tf.split(inputs, num_or_size_splits=2, axis=1) 27 | if self._kernel_initializer is None: 28 | self._kernel_initializer = tf.contrib.layers.xavier_initializer(uniform=False) 29 | if self._bias_initializer is None: 30 | self._bias_initializer = tf.constant_initializer(1.0) 31 | with tf.variable_scope('gate'): # sigmoid([i_A|i_T|s_(t-1)]*[W_fA;W_fT;U_f]+b_f) 32 | self.W_f = tf.get_variable(dtype=tf.float32, name='W_f', 33 | shape=[inputs.get_shape()[-1].value + state.get_shape()[-1].value, self._num_units], 34 | initializer=self._kernel_initializer) 35 | self.b_f = tf.get_variable(dtype=tf.float32, name='b_f', shape=[self._num_units, ], 36 | initializer=self._bias_initializer) 37 | f = tf.concat([inputs, state], axis=-1) # f=[batch_size, hidden_size+hidden_size+self._num_units] 38 | f = tf.matmul(f, self.W_f) # f=[batch_size,self._num_units] 39 | f = f + self.b_f # f=[batch_size, self._num_units] 40 | f = tf.sigmoid(f) # f=[batch_size, self._num_units] 41 | 42 | with tf.variable_scope('candidate'): # tanh([i_A|s_(t-1)]*[W_s;U_s]+b_s) 43 | self.W_s = tf.get_variable(dtype=tf.float32, name='W_s', 44 | shape=[inputs_A.get_shape()[-1].value + state.get_shape()[-1].value, 45 | self._num_units], initializer=self._kernel_initializer) 46 | self.b_s = tf.get_variable(dtype=tf.float32, name='b_s', shape=[self._num_units, ], 47 | initializer=self._bias_initializer) 48 | _s = tf.concat([inputs_A, state], axis=-1) # _s=[batch_size, hidden_size+self._num_units] 49 | _s = tf.matmul(_s, self.W_s) # _s=[batch_size,self._num_units] 50 | _s = _s + self.b_s # _s=[batch_size,self._num_units] 51 | _s = self._activation(_s) 52 | 53 | new_s = f * _s + (1 - f) * state # new_s=[batch_size, self._num_units] 54 | return new_s, new_s 55 | 56 | class MIFN(): 57 | 58 | def __init__(self, num_items_A, num_items_B, num_entity_A,num_entity_B,num_cate,batch_size, neighbor_num, gpu, 59 | embedding_size=256, hidden_size=256, num_layers=1, 60 | lr=0.01,keep_prob=0.8,n_iter=1, 61 | training_steps_per_epoch=5,lr_decay_factor=0.9,min_lr=0.00001): 62 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu 63 | self.config = tf.ConfigProto() 64 | self.config.gpu_options.allow_growth = True 65 | self.num_items_A = num_items_A 66 | self.num_items_B = num_items_B 67 | self.n_items = num_entity_A 
+ num_entity_B + num_cate 68 | self.embedding_size = embedding_size 69 | self.hidden_size = hidden_size 70 | self.num_layers = num_layers 71 | self.batch_size = batch_size 72 | self.neighbors_num = neighbor_num 73 | self.n_iter = n_iter 74 | self.keep_prob = keep_prob 75 | self.graph = tf.Graph() 76 | 77 | with self.graph.as_default(): 78 | tf.set_random_seed(1) 79 | with tf.name_scope('inputs'): 80 | self.get_inputs() 81 | with tf.name_scope('all_encoder'): 82 | self.all_encoder() 83 | with tf.name_scope('encoder_A'): 84 | encoder_output_A, encoder_state_A = self.encoder_A() 85 | with tf.name_scope('encoder_B'): 86 | encoder_output_B, encoder_state_B = self.encoder_B() 87 | with tf.name_scope('sequence_transfer_A'): 88 | filter_output_A, filter_state_A = self.filter_A(encoder_output_A, encoder_output_B,) 89 | transfer_output_A, transfer_state_A = self.transfer_A(filter_output_A,) 90 | with tf.name_scope('sequence_transfer_B'): 91 | filter_output_B, filter_state_B = self.filter_B(encoder_output_B, encoder_output_A,) 92 | transfer_output_B, transfer_state_B = self.transfer_B(filter_output_B,) 93 | 94 | with tf.name_scope('graph_transfer'): 95 | entity_emb = self.graph_gnn(encoder_output_A, encoder_output_B, 96 | transfer_output_A, transfer_output_B) 97 | 98 | with tf.name_scope('prediction_A'): 99 | self.PG_A, self.PS_A = self.switch_A(encoder_state_A, transfer_state_B, entity_emb,self.nei_A_mask,) 100 | s_pred_A = self.s_decoder_A(self.num_items_A, encoder_state_A, transfer_state_B, self.keep_prob) 101 | g_pred_A,g_att_A = self.g_decoder_A(encoder_state_A,entity_emb,self.num_items_A,self.nei_A_mask, 102 | self.nei_index_A, self.IsinnumA) 103 | self.pred_A = self.final_pred_A(self.PG_A, self.PS_A, s_pred_A, g_pred_A) 104 | 105 | with tf.name_scope('prediction_B'): 106 | self.PG_B, self.PS_B = self.switch_B(encoder_state_B, transfer_state_A, entity_emb,self.nei_B_mask,) 107 | s_pred_B = self.s_decoder_B(self.num_items_B, encoder_state_B, transfer_state_A, self.keep_prob) 108 | g_pred_B,g_att_B = self.g_decoder_B(encoder_state_B, entity_emb, self.num_items_B,self.nei_B_mask, 109 | self.nei_index_B, self.IsinnumB) 110 | self.pred_B = self.final_pred_B(self.PG_B, self.PS_B, s_pred_B, g_pred_B) 111 | 112 | with tf.name_scope('loss'): 113 | self.loss = self.cal_loss(self.target_A, self.pred_A, self.target_B, self.pred_B,) 114 | 115 | with tf.name_scope('optimizer'): 116 | self.train_op,self.grad = self.optimizer(lr,training_steps_per_epoch,lr_decay_factor,min_lr) 117 | 118 | def get_inputs(self): 119 | self.seq_A = tf.placeholder(dtype=tf.int32, shape=[None, None], name='seq_A') 120 | self.seq_B = tf.placeholder(dtype=tf.int32, shape=[None, None], name='seq_B') 121 | self.len_A = tf.placeholder(dtype=tf.int32, shape=[None, ], name='len_A') 122 | self.len_B = tf.placeholder(dtype=tf.int32, shape=[None, ], name='len_B') 123 | self.pos_A = tf.placeholder(dtype=tf.int32, shape=[None, None, 2], name='pos_A') 124 | self.pos_B = tf.placeholder(dtype=tf.int32, shape=[None, None, 2], name='pos_B') 125 | self.index_A = tf.placeholder(dtype=tf.int32, shape=[None, None, 2], name='index_A') 126 | self.index_B = tf.placeholder(dtype=tf.int32, shape=[None, None, 2], name='index_B') 127 | self.target_A = tf.placeholder(dtype=tf.int32,shape=[None,],name='target_A') 128 | self.target_B = tf.placeholder(dtype=tf.int32,shape=[None,],name='target_B') 129 | self.tar_in_A = tf.placeholder(dtype=tf.float32, shape=[None, 1], name='tar_in_A') 130 | self.tar_in_B = tf.placeholder(dtype=tf.float32, shape=[None, 1], 
name='tar_in_B') 131 | self.adj_1 = tf.placeholder(dtype=tf.float32, shape=[None,self.neighbors_num,self.neighbors_num],name='adj_alb') 132 | self.adj_2 = tf.placeholder(dtype=tf.float32, shape=[None,self.neighbors_num,self.neighbors_num],name='adj_alv') 133 | self.adj_3 = tf.placeholder(dtype=tf.float32, shape=[None,self.neighbors_num,self.neighbors_num],name='adj_bav') 134 | self.adj_4 = tf.placeholder(dtype=tf.float32, shape=[None,self.neighbors_num,self.neighbors_num],name='adj_bt') 135 | self.adj_5 = tf.placeholder(dtype=tf.float32, shape=[None,self.neighbors_num,self.neighbors_num],name='adj_ada') 136 | self.neighbors = tf.placeholder(dtype=tf.int64, shape=[None, self.neighbors_num],name='neighbors') 137 | self.nei_index_A = tf.placeholder(dtype=tf.int64, shape=[None, self.neighbors_num, 2],name='nei_index_A') 138 | self.nei_index_B = tf.placeholder(dtype=tf.int64, shape=[None, self.neighbors_num, 2], name='nei_index_B') 139 | self.nei_A_mask = tf.placeholder(dtype=tf.float32, shape=[None, self.neighbors_num], name='nei_A_mask') 140 | self.nei_B_mask = tf.placeholder(dtype=tf.float32, shape=[None, self.neighbors_num], name='nei_B_mask') 141 | self.IsinnumA = tf.placeholder(dtype=tf.float32, shape=[None, self.neighbors_num], name='IsinnumA') 142 | self.IsinnumB = tf.placeholder(dtype=tf.float32, shape=[None, self.neighbors_num], name='IsinnumB') 143 | 144 | self.nei_L_A_mask = tf.placeholder(dtype=tf.float32, shape=[None, None, self.neighbors_num], 145 | name='nei_L_A_mask') 146 | self.nei_L_T_mask = tf.placeholder(dtype=tf.float32, shape=[None, None, self.neighbors_num], 147 | name='nei_L_T_mask') 148 | 149 | def get_gru_cell(self,hidden_size,keep_prob): 150 | gru_cell = tf.contrib.rnn.GRUCell(hidden_size, kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False)) 151 | gru_cell = tf.contrib.rnn.DropoutWrapper(gru_cell, input_keep_prob=keep_prob, 152 | output_keep_prob=keep_prob, 153 | state_keep_prob=keep_prob) 154 | return gru_cell 155 | 156 | def get_filter_cell(self, hidden_size, keep_prob): 157 | filter_cell = FilterCell(hidden_size) 158 | filter_cell = tf.contrib.rnn.DropoutWrapper(filter_cell, input_keep_prob=keep_prob, output_keep_prob=keep_prob, 159 | state_keep_prob=keep_prob) 160 | return filter_cell 161 | 162 | def all_encoder(self,): 163 | with tf.variable_scope('all_encoder'): 164 | self.all_emb_matrix = tf.get_variable(shape=[self.n_items, self.embedding_size], name='item_emb_matrix', 165 | initializer=tf.contrib.layers.xavier_initializer(uniform=False)) 166 | 167 | def encoder_A(self,): 168 | with tf.variable_scope('encoder_A'): 169 | embedd_seq_A = tf.nn.embedding_lookup(self.all_emb_matrix, self.seq_A) 170 | print(embedd_seq_A) 171 | encoder_cell_A = tf.nn.rnn_cell.MultiRNNCell([self.get_gru_cell(self.hidden_size,self.keep_prob) for _ in range(self.num_layers)]) 172 | encoder_output_A, encoder_state_A = tf.nn.dynamic_rnn(encoder_cell_A, embedd_seq_A, sequence_length=self.len_A, dtype=tf.float32) 173 | return encoder_output_A, encoder_state_A, 174 | 175 | def encoder_B(self,): 176 | with tf.variable_scope('encoder_B'): 177 | embedd_seq_B = tf.nn.embedding_lookup(self.all_emb_matrix, self.seq_B) 178 | print(embedd_seq_B) 179 | encoder_cell_B = tf.nn.rnn_cell.MultiRNNCell([self.get_gru_cell(self.hidden_size,self.keep_prob) for _ in range(self.num_layers)]) 180 | encoder_output_B, encoder_state_B = tf.nn.dynamic_rnn(encoder_cell_B, embedd_seq_B, sequence_length=self.len_B, dtype=tf.float32) 181 | return encoder_output_B, encoder_state_B 182 | 183 | def 
filter_A(self, encoder_output_A, encoder_output_B,): 184 | with tf.variable_scope('filter_A'): 185 | zero_state = tf.zeros(dtype=tf.float32, shape=(tf.shape(encoder_output_A)[0], 1, tf.shape(encoder_output_A)[-1])) 186 | encoder_output = tf.concat([zero_state, encoder_output_B], axis=1) 187 | # print(encoder_output) #encoder_output=[batch_size,timestamp_B+1,hidden_size] 188 | select_output_B = tf.gather_nd(encoder_output,self.pos_A) # 挑出A-output之前的B-item, len还是timestep_A 189 | # print(select_output_B) #select_output_A=[batch_size,timestamp_A,hidden_size] 190 | filter_input_A = tf.concat([encoder_output_A, select_output_B], axis=-1) # filter_input_A=[b,tA,2*h] 191 | # att = tf.layers.dense(tf.concat([encoder_output_A, select_output_B], axis=-1), units=hidden_size,activation=tf.nn.sigmoid) 192 | # combined_output_A = att * tf.nn.tanh(encoder_output_A) + (1 - att) * select_output_B 193 | # print(combined_output_A) #[batch_size,timestamp_A,hidden_size] 194 | filter_cell_A = tf.nn.rnn_cell.MultiRNNCell( 195 | [self.get_filter_cell(self.hidden_size, self.keep_prob) for _ in range(self.num_layers)]) 196 | filter_output_A, filter_state_A = tf.nn.dynamic_rnn(filter_cell_A, filter_input_A, sequence_length=self.len_A, 197 | dtype=tf.float32) 198 | # print(filter_output_A) # filter_output_A=[batch_size,timestamp_A,hidden_size], 199 | # print(filter_state_A) # filter_state_A=[batch_size,hidden_size] 200 | return filter_output_A, filter_state_A 201 | 202 | def transfer_A(self, filter_output_A,): 203 | 204 | with tf.variable_scope('transfer_A'): 205 | transfer_cell_A = tf.nn.rnn_cell.MultiRNNCell( 206 | [self.get_gru_cell(self.hidden_size, self.keep_prob) for _ in range(self.num_layers)]) 207 | transfer_output_A, transfer_state_A = tf.nn.dynamic_rnn(transfer_cell_A, filter_output_A, 208 | sequence_length=self.len_A,dtype=tf.float32) 209 | # print(transfer_output_A) # transfer_output_A=[batch_size,timestamp_A,hidden_size], 210 | # print(transfer_state_A) # transfer_state_A=([batch_size,hidden_size]*num_layers) 211 | return transfer_output_A, transfer_state_A 212 | 213 | def filter_B(self, encoder_output_B, encoder_output_A, ): 214 | with tf.variable_scope('filter_B'): 215 | zero_state = tf.zeros(dtype=tf.float32, 216 | shape=(tf.shape(encoder_output_B)[0], 1, tf.shape(encoder_output_B)[-1])) 217 | # print(zero_state) # zero_state=[batch_size,1,hidden_size] 218 | encoder_output = tf.concat([zero_state, encoder_output_A], axis=1) 219 | # print(encoder_output) # encoder_output=[batch_size,timestamp_B+1,hidden_size] 220 | select_output_A = tf.gather_nd(encoder_output, self.pos_B) # 挑出B-output之前的A-item 221 | # print(select_output_A) # select_output_B=[batch_size,timestamp_B,hidden_size] 222 | filter_input_B = tf.concat([encoder_output_B, select_output_A], axis=-1) 223 | # print(filter_input_B) # filter_input_B=[batch_size,timestamp_B,hidden_size+hidden_size] 224 | 225 | # att = tf.layers.dense(tf.concat([encoder_output_B, select_output_A], axis=-1), units=hidden_size, 226 | # activation=tf.nn.sigmoid) 227 | # combined_output_B = att * tf.nn.tanh(encoder_output_B) + (1 - att) * select_output_A 228 | # print(combined_output_B) # [batch_size,timestamp_B,hidden_size] 229 | filter_cell_B = tf.nn.rnn_cell.MultiRNNCell( 230 | [self.get_filter_cell(self.hidden_size, self.keep_prob) for _ in range(self.num_layers)]) 231 | filter_output_B, filter_state_B = tf.nn.dynamic_rnn(filter_cell_B, filter_input_B, sequence_length=self.len_B, 232 | dtype=tf.float32) 233 | # print(filter_output_B) # 
filter_output_B=[batch_size,timestamp_B,hidden_size], 234 | # print(filter_state_B) # filter_state_B=[batch_size,hidden_size] 235 | return filter_output_B, filter_state_B 236 | 237 | def transfer_B(self, filter_output_B,): 238 | with tf.variable_scope('transfer_B'): 239 | transfer_cell_B = tf.nn.rnn_cell.MultiRNNCell( 240 | [self.get_gru_cell(self.hidden_size, self.keep_prob) for _ in range(self.num_layers)]) 241 | transfer_output_B, transfer_state_B = tf.nn.dynamic_rnn(transfer_cell_B, filter_output_B, 242 | sequence_length=self.len_B, 243 | dtype=tf.float32) 244 | # print(transfer_output_B) # transfer_output_B=[batch_size,timestamp_B,hidden_size], 245 | # print(transfer_state_B) # transfer_state_B=([batch_size,hidden_size]*num_layers) 246 | return transfer_output_B, transfer_state_B 247 | 248 | def graph_gnn(self, encoder_output_A, encoder_output_B, 249 | transfer_output_A, transfer_output_B,): 250 | with tf.variable_scope('graph_gnn'): 251 | self.entity_emb = tf.nn.embedding_lookup(self.all_emb_matrix, self.neighbors) # [b,N,e] 252 | # ----------------- in-domain adj parameter ------------------ 253 | self.W_alb1 = random_weight(self.hidden_size, self.hidden_size, name='W_alb1') 254 | self.b11 = random_bias(self.hidden_size, name='b11') 255 | self.W_alv1 = random_weight(self.hidden_size, self.hidden_size, name='W_alv1') 256 | self.b21 = random_bias(self.hidden_size, name='b21') 257 | self.W_bav1 = random_weight(self.hidden_size, self.hidden_size, name='W_bav1') 258 | self.b31 = random_bias(self.hidden_size, name='b31') 259 | self.W_bt1 = random_weight(self.hidden_size, self.hidden_size, name='W_bt1') 260 | self.b41 = random_bias(self.hidden_size, name='b41') 261 | self.W_ada1 = random_weight(self.hidden_size, self.hidden_size, name='W_ada1') 262 | self.b51 = random_bias(self.hidden_size, name='b51') 263 | 264 | # ------------------- cross-domain adj parameter ------------------ 265 | self.W_alb2 = random_weight(self.hidden_size, self.hidden_size, name='W_alb2') 266 | self.b12 = random_bias(self.hidden_size, name='b12') 267 | self.W_alv2 = random_weight(self.hidden_size, self.hidden_size, name='W_alv2') 268 | self.b22 = random_bias(self.hidden_size, name='b22') 269 | self.W_bav2 = random_weight(self.hidden_size, self.hidden_size, name='W_bav2') 270 | self.b32 = random_bias(self.hidden_size, name='b32') 271 | self.W_bt2 = random_weight(self.hidden_size, self.hidden_size, name='W_bt2') 272 | self.b42 = random_bias(self.hidden_size, name='b42') 273 | self.W_ada2 = random_weight(self.hidden_size, self.hidden_size, name='W_ada2') 274 | self.b52 = random_bias(self.hidden_size, name='b52') 275 | 276 | inputs_A, inputs_A2T,\ 277 | nei_mask_A, nei_mask_T = self.get_ht(encoder_output_A, encoder_output_B, 278 | transfer_output_A, transfer_output_B, 279 | self.len_A, self.len_B, self.index_A, self.index_B, 280 | self.nei_L_A_mask, self.nei_L_T_mask) 281 | inputs_A = tf.tile(tf.expand_dims(inputs_A, axis=1), [1, self.neighbors_num, 1]) 282 | print('input-A:', inputs_A) # [b, N, h] 283 | inputs_A2T = tf.tile(tf.expand_dims(inputs_A2T, axis=1), [1, self.neighbors_num, 1]) 284 | print('input-A2T:', inputs_A2T) # [b, N, h] 285 | 286 | cell = tf.nn.rnn_cell.GRUCell(self.hidden_size) 287 | 288 | # -------------- softmask-A attention weight -------------- 289 | self.W_s_in = random_weight(self.hidden_size, self.hidden_size, name='W_s_in') 290 | self.W_emb_in = random_weight(self.hidden_size, self.hidden_size, name='W_emb_in') 291 | self.W_v_in = random_weight(self.hidden_size, 1, name='W_v_in') 292 | # 
-------------- softmask-A2T attention weight ------------- 293 | self.W_s_cross = random_weight(self.hidden_size, self.hidden_size, name='W_s_cross') 294 | self.W_emb_cross = random_weight(self.hidden_size, self.hidden_size, name='W_emb_cross') 295 | self.W_v_cross = random_weight(self.hidden_size, 1, name='W_v_cross') 296 | 297 | for i in range(self.n_iter): 298 | softmask_A = self.indomain_attention(inputs_A, nei_mask_A) 299 | softmask_A2T = self.crossdomain_attention(inputs_A2T, nei_mask_T) 300 | gcn_emb = self.get_neigh_rep(inputs_A, inputs_A2T,softmask_A, softmask_A2T, nei_mask_A,nei_mask_T) 301 | print(gcn_emb) # [b, N, 10*h] 302 | # self.entity_emb = tf.layers.dense(gcn_emb, self.hidden_size, 303 | # activation=None, 304 | # kernel_initializer=tf.contrib.layers.xavier_initializer( 305 | # uniform=False)) # [b, N, h] 306 | self.entity_emb = tf.reshape(self.entity_emb, [-1, self.hidden_size]) # [b*N, h] 307 | graph_output, self.entity_emb = \ 308 | tf.nn.dynamic_rnn(cell, tf.expand_dims(tf.reshape(gcn_emb, [-1, 12 * self.hidden_size]), axis=1), 309 | initial_state=self.entity_emb) 310 | print(graph_output) # graph_output=[b*N, h], 311 | print(self.entity_emb) # graph_state=[b*N, h]*num_layers 312 | 313 | self.entity_emb = tf.reshape(self.entity_emb, [-1, self.neighbors_num, self.hidden_size]) 314 | print('after gnn:', self.entity_emb) 315 | return self.entity_emb 316 | 317 | def get_neigh_rep(self,inputs_A, inputs_A2T,softmask_A, softmask_A2T, nei_mask_A,nei_mask_T): 318 | with tf.variable_scope('cdgcn', reuse=tf.AUTO_REUSE): 319 | 320 | self.W_c = random_weight(3 * self.hidden_size, self.hidden_size, 321 | name='W_f') 322 | self.b_c = random_bias(self.hidden_size, name='b_f') 323 | inputs = tf.concat([inputs_A, inputs_A2T], axis=-1) # [b, N, 2h] 324 | f = tf.concat([inputs, self.entity_emb], axis=-1) # [b, N, 3h] 325 | f = tf.matmul(f, self.W_c) + self.b_c 326 | f = tf.sigmoid(f) # [b, N, h] 327 | print('cross-gate:', f) # f=[b, N, h] 328 | softmask_A = f * softmask_A 329 | softmask_A2T = (1 - f) * softmask_A2T 330 | 331 | #--------------- in-domain neighbor ------------------ 332 | fin_state_A = tf.reshape(softmask_A, [-1, self.hidden_size]) 333 | self.W1 = random_weight(self.hidden_size, self.hidden_size, name='w1') 334 | self.b1 = random_bias(self.hidden_size, name='b1') 335 | fin_state_A = tf.matmul(fin_state_A, self.W1) + self.b1 336 | print(fin_state_A) # [b*N, s] 337 | 338 | #---------------- cross-domain neighbor --------------- 339 | fin_state_A2T = tf.reshape(softmask_A2T, [-1, self.hidden_size]) 340 | self.W2 = random_weight(self.hidden_size, self.hidden_size, name='w2') 341 | self.b2 = random_bias(self.hidden_size, name='b2') 342 | fin_state_A2T = tf.matmul(fin_state_A2T, self.W2) + self.b2 343 | print(fin_state_A2T) # [b*N, s] 344 | 345 | nei_mask_A = tf.expand_dims(nei_mask_A, -1) # [b,N,1] 346 | nei_mask_T = tf.expand_dims(nei_mask_T, -1) # [b,N,1] 347 | mask_emb_A = nei_mask_A * self.entity_emb # [b,N,h] 348 | mask_emb_T = nei_mask_T * self.entity_emb # [b,N,h] 349 | att_emb_T = tf.reshape(self.mutual_att(mask_emb_T, mask_emb_A), [-1, self.hidden_size]) # [b*N,s] 350 | fin_state_A2T = tf.add(fin_state_A2T, att_emb_T) 351 | print(fin_state_A2T) # [b*N, s] 352 | 353 | # ------------------ in-domain representation -------------- 354 | fin_state_1a = matrix_mutliply(fin_state_A, self.W_alb1, self.b11, self.neighbors_num, self.hidden_size) 355 | fin_state_2a = matrix_mutliply(fin_state_A, self.W_alv1, self.b21, self.neighbors_num, self.hidden_size) 356 | fin_state_3a = 
matrix_mutliply(fin_state_A, self.W_bav1, self.b31, self.neighbors_num, self.hidden_size) 357 | fin_state_4a = matrix_mutliply(fin_state_A, self.W_bt1, self.b41, self.neighbors_num, self.hidden_size) 358 | fin_state_5a = matrix_mutliply(fin_state_A, self.W_ada1, self.b51, self.neighbors_num, self.hidden_size) 359 | all_nei_A = tf.nn.relu(tf.concat([ 360 | tf.matmul(self.adj_1, fin_state_1a), 361 | tf.matmul(self.adj_2, fin_state_2a), 362 | tf.matmul(self.adj_3, fin_state_3a), 363 | tf.matmul(self.adj_4, fin_state_4a), 364 | tf.matmul(self.adj_5, fin_state_5a),], axis=-1)) 365 | print(all_nei_A) # all_nei=[b, N, 5*h] 366 | 367 | ###################### cross-domain representation ################# 368 | fin_state_1t = matrix_mutliply(fin_state_A2T, self.W_alb2, self.b12, self.neighbors_num, self.hidden_size) 369 | fin_state_2t = matrix_mutliply(fin_state_A2T, self.W_alv2, self.b22, self.neighbors_num, self.hidden_size) 370 | fin_state_3t = matrix_mutliply(fin_state_A2T, self.W_bav2, self.b32, self.neighbors_num, self.hidden_size) 371 | fin_state_4t = matrix_mutliply(fin_state_A2T, self.W_bt2, self.b42, self.neighbors_num, self.hidden_size) 372 | fin_state_5t = matrix_mutliply(fin_state_A2T, self.W_ada2, self.b52, self.neighbors_num, self.hidden_size) 373 | all_nei_A2T = tf.nn.relu(tf.concat([ 374 | tf.matmul(self.adj_1, fin_state_1t), 375 | tf.matmul(self.adj_2, fin_state_2t), 376 | tf.matmul(self.adj_3, fin_state_3t), 377 | tf.matmul(self.adj_4, fin_state_4t), 378 | tf.matmul(self.adj_5, fin_state_5t), ], axis=-1)) 379 | print(all_nei_A2T) # all_nei=[b, N, s*5] 380 | all_nei = tf.concat([all_nei_A, all_nei_A2T], axis=-1) 381 | print(all_nei) # [b, N, 10*s] 382 | 383 | return all_nei 384 | 385 | def get_ht(self,encoder_output_A, encoder_output_B,transfer_output_A, transfer_output_B, 386 | len_A, len_B, index_A, index_B, nei_L_A_mask, nei_L_T_mask): 387 | all_len = tf.add(len_A, len_B) 388 | ########## get hAi from encoder ########## 389 | e1 = tf.scatter_nd(index_A, encoder_output_A, [tf.shape(encoder_output_A)[0], 390 | tf.shape(encoder_output_A)[1] + tf.shape(encoder_output_B)[1], 391 | self.hidden_size]) 392 | print(e1) 393 | e2 = tf.scatter_nd(index_B, encoder_output_B, [tf.shape(encoder_output_A)[0], 394 | tf.shape(encoder_output_A)[1] + tf.shape(encoder_output_B)[1], 395 | self.hidden_size]) 396 | print(e2) 397 | seq_L = e1 + e2 # [b, time_A+time_B, h] 398 | hA = tf.gather_nd(seq_L, tf.stack([tf.range(tf.shape(encoder_output_A)[0]), all_len-1], axis=1)) # 拿到了最后一步的输入 399 | print('inputsA:',hA) 400 | ########## get h(A->B)i from transfer ########## 401 | e3 = tf.scatter_nd(index_A, transfer_output_A, [tf.shape(transfer_output_A)[0], 402 | tf.shape(transfer_output_A)[1] + tf.shape(transfer_output_B)[1], 403 | self.hidden_size]) 404 | print(e3) 405 | e4 = tf.scatter_nd(index_B, transfer_output_B, [tf.shape(transfer_output_A)[0], 406 | tf.shape(transfer_output_A)[1] + tf.shape(transfer_output_B)[1], 407 | self.hidden_size]) 408 | print(e4) 409 | trans_L = e3 + e4 # [b, time_A+time_B, h] 410 | hA2T = tf.gather_nd(trans_L, tf.stack([tf.range(tf.shape(encoder_output_A)[0]), all_len-1], axis=1)) # 拿到了最后一步的输入 411 | print('inputsA2T:', hA2T) 412 | ########## get mask ############### 413 | print(tf.shape(nei_L_A_mask)[1]) 414 | nei_A_mask = tf.gather_nd(nei_L_A_mask, tf.stack([tf.range(tf.shape(encoder_output_A)[0]), all_len-1], axis=1)) 415 | print('neiA_mask:', nei_A_mask) # [b, N] 416 | nei_T_mask = tf.gather_nd(nei_L_T_mask, tf.stack([tf.range(tf.shape(encoder_output_A)[0]), all_len-1], 
axis=1)) 417 | print('neiA2T_mask:', nei_T_mask) # [b, N] 418 | 419 | return hA, hA2T,nei_A_mask, nei_T_mask 420 | 421 | def indomain_attention(self, item_state, nei_mask): 422 | 423 | with tf.variable_scope('softmask_att_in', reuse=tf.AUTO_REUSE): 424 | S_it = tf.matmul(item_state, self.W_s_in) 425 | print('S_it:', S_it) # [b, N, h] 426 | S_emb = tf.matmul(self.entity_emb, self.W_emb_in) 427 | print('S_emb:', S_emb) # [b, N, h] 428 | tanh = tf.tanh(S_it + S_emb) 429 | print("tanh:", tanh) # [b, N, h] 430 | s = tf.squeeze(tf.matmul(tanh, self.W_v_in)) 431 | print("s:", s) # [b, N] 432 | s_inf_mask = self.mask_softmax(nei_mask, s) 433 | print(s_inf_mask) # [b, N] 434 | score = self.normalize_softmax(s_inf_mask) # [b, N] 435 | score = tf.expand_dims(score, axis=-1) 436 | print('score:', score) # [b, N, 1] 437 | softmask = score * self.entity_emb # [b, N, e] 438 | return softmask 439 | 440 | def crossdomain_attention(self, item_state, nei_mask): 441 | with tf.variable_scope('softmask_att_cross', reuse=tf.AUTO_REUSE): 442 | S_it = tf.matmul(item_state, self.W_s_cross) 443 | print('S_it:', S_it) # [b, N, h] 444 | S_emb = tf.matmul(self.entity_emb, self.W_emb_cross) 445 | print('S_emb:', S_emb) # [b, N, h] 446 | tanh = tf.tanh(S_it + S_emb) 447 | print("tanh:", tanh) # [b, N, h] 448 | s = tf.squeeze(tf.matmul(tanh, self.W_v_cross)) 449 | print("s:", s) # [b, N] 450 | s_inf_mask = self.mask_softmax(nei_mask, s) 451 | print(s_inf_mask) # [b, N] 452 | score = self.normalize_softmax(s_inf_mask) # [b, N] 453 | score = tf.expand_dims(score, axis=-1) 454 | print('score:', score) # [b, N, 1] 455 | softmask = score * self.entity_emb # [b, N, e] 456 | return softmask 457 | 458 | def mask_softmax(self, seq_mask, scores): 459 | ''' 460 | to do softmax, assign -inf value for the logits of padding tokens 461 | ''' 462 | seq_mask = tf.cast(seq_mask, tf.bool) 463 | score_mask_values = -1e10 * tf.ones_like(scores, dtype=tf.float32) 464 | return tf.where(seq_mask, scores, score_mask_values) 465 | 466 | def normalize_softmax(self,x): 467 | max_value = tf.reshape(tf.reduce_max(x, -1), [-1, 1]) 468 | each_ = tf.exp(x - max_value) 469 | all_ = tf.reshape(tf.reduce_sum(each_, -1), [-1, 1]) 470 | score = each_ / all_ 471 | return score 472 | 473 | def mutual_att(self,hb, hA,): 474 | hb_ext = tf.expand_dims(hb, axis=2) # hb_ext=[b,N1,1,h] 475 | hb_ext = tf.tile(hb_ext, [1, 1, tf.shape(hA)[1], 1]) # hb_ext=[b,N1,N2,h] 476 | hA_ext = tf.expand_dims(hA, axis=1) # hA_ext=[b,1,N2,h] 477 | hA_ext = tf.tile(hA_ext, [1, tf.shape(hb)[1], 1, 1]) # hA_ext=[b,N1,N2,h] 478 | dot = hb_ext * hA_ext 479 | # dot = tf.concat([hb_ext, hA_ext, hb_ext * hA_ext], axis=-1) # dot=[b,N1,N2,h] 480 | dot = tf.layers.dense(dot, 1, activation=None, use_bias=False) # dot=[b,N1,N2,1] 481 | dot = tf.squeeze(dot) # dot=[b,N1,N2] 482 | # sum_row = tf.reduce_sum(dot, axis=-1, keep_dims=True) # sum_row=[b,N1,1] 483 | # att_hb = sum_row * hb 484 | # print(att_hb) # [b, N1, h] 485 | att_hb = tf.matmul(dot, hA) # [b,N1,h] 486 | return att_hb 487 | 488 | def switch_A(self, encoder_state_A, transfer_state_B, graph_state, nei_mask): 489 | with tf.variable_scope('switch_A'): 490 | graph_rep = tf.reshape(graph_state, [-1, self.neighbors_num, self.hidden_size]) 491 | nei_mask = tf.expand_dims(nei_mask, -1) 492 | graph_rep = nei_mask * graph_rep 493 | graph_rep = tf.reduce_sum(graph_rep, axis=1) 494 | concat_output = tf.concat([encoder_state_A[-1], transfer_state_B[-1], graph_rep], axis=-1) 495 | linear_switch = tf.layers.Dense(1, 
kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1)) 496 | switch_matrix = linear_switch(concat_output) # Tensor shape (b, 1) 497 | PG_A = tf.sigmoid(switch_matrix) 498 | PS_A = 1 - PG_A 499 | print('PSA:',PS_A) 500 | print('PGA:',PG_A) 501 | return PG_A, PS_A 502 | 503 | def s_decoder_A(self, num_items_A, encoder_state_A, transfer_state_B, keep_prob): 504 | with tf.variable_scope('s_predict_A'): 505 | concat_output = tf.concat([encoder_state_A[-1],transfer_state_B[-1]],axis=-1) 506 | concat_output = tf.nn.dropout(concat_output, keep_prob) 507 | pred_A = tf.layers.dense(concat_output, num_items_A, 508 | activation=None, 509 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False)) 510 | # pred_A = self.normalize_softmax(pred_A) 511 | pred_A = tf.nn.softmax(pred_A) 512 | # print(pred_A) # pred_A=[b, num_items_A] 513 | return pred_A 514 | 515 | def g_decoder_A(self, ht, graph_state, num_items_A, nei_mask, nei_index_A, IsinnumA): 516 | 517 | with tf.variable_scope('g_predict_A'): 518 | self.W_h_a = random_weight(self.hidden_size, self.hidden_size, name='W_h_a') 519 | self.W_emb_a = random_weight(self.hidden_size, self.hidden_size, name='W_emb_a') 520 | self.W_v_a = random_weight(self.hidden_size, 1, name='W_v_a') 521 | graph_state = tf.reshape(graph_state, [-1, self.neighbors_num, self.hidden_size]) # [b, N, h] 522 | nei_mask = tf.expand_dims(nei_mask, -1) 523 | graph_state = nei_mask * graph_state 524 | att = self.g_decode_attention_A(ht[-1], graph_state, IsinnumA) 525 | print(att) # [b, N] 526 | g_pred_A = tf.scatter_nd(nei_index_A, att, [tf.shape(graph_state)[0], num_items_A]) 527 | print(g_pred_A) # [b, num_item_A] 528 | return g_pred_A, att 529 | def g_decode_attention_A(self, ht, repre, mask): 530 | S_h = tf.matmul(ht, self.W_h_a) # [b, h] 531 | S_h = tf.expand_dims(S_h, 1) 532 | print('S_it:', S_h) # [b, 1, h] 533 | S_emb = tf.reshape(tf.matmul(tf.reshape(repre, [-1, self.hidden_size]), self.W_emb_a), 534 | [-1, self.neighbors_num, self.hidden_size]) # [b, N, h] 535 | print('S_emb:', S_emb) 536 | tanh = tf.tanh(S_h + S_emb) # [b, N, h] 537 | print("tanh:", tanh) 538 | s = tf.reshape(tf.squeeze(tf.matmul(tf.reshape(tanh, [-1, self.hidden_size]), self.W_v_a)), 539 | [-1, self.neighbors_num]) # [b, N] 540 | print("s:", s) # [b, N] 541 | s_inf_mask = self.mask_softmax(mask, s) 542 | print(s_inf_mask) # [b, N] 543 | score = self.normalize_softmax(s_inf_mask) # [b, N] 544 | print('score:', score) 545 | return score 546 | 547 | def final_pred_A(self, PG_A, PS_A, s_pred_A, g_pred_A): 548 | with tf.variable_scope('final_predict_A'): 549 | pred_A = PG_A * g_pred_A + PS_A * s_pred_A 550 | print(pred_A) # [b, num_items_A] 551 | return pred_A 552 | 553 | def switch_B(self, encoder_state_B, transfer_state_A, graph_state, nei_mask): 554 | with tf.variable_scope('switch_B'): 555 | graph_rep = tf.reshape(graph_state, [-1, self.neighbors_num, self.hidden_size]) 556 | nei_mask = tf.expand_dims(nei_mask, -1) 557 | graph_rep = nei_mask * graph_rep 558 | graph_rep = tf.reduce_sum(graph_rep, axis=1) 559 | concat_output = tf.concat([encoder_state_B[-1], transfer_state_A[-1], graph_rep], axis=-1) 560 | # print(concat_output) # [batch_size, 3*hidden_size] 561 | linear_switch = tf.layers.Dense(1, kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1)) 562 | switch_matrix = linear_switch(concat_output) # Tensor shape (b, 1) 563 | PG_B = tf.sigmoid(switch_matrix) 564 | PS_B = 1 - PG_B 565 | # PG_B = tf.expand_dims(PG_B, 1) # [batch,1] 566 | # PS_B = 
tf.expand_dims(PS_B, 1) 567 | return PG_B, PS_B 568 | 569 | def s_decoder_B(self, num_items_B, encoder_state_B, transfer_state_A, keep_prob): 570 | with tf.variable_scope('s_predict_B'): 571 | concat_output = tf.concat([encoder_state_B[-1],transfer_state_A[-1]],axis=-1) 572 | concat_output = tf.nn.dropout(concat_output, keep_prob) 573 | pred_B = tf.layers.dense(concat_output, num_items_B, 574 | activation=None, 575 | kernel_initializer=tf.contrib.layers.xavier_initializer(uniform=False)) 576 | pred_B = tf.nn.softmax(pred_B) 577 | # pred_B = self.normalize_softmax(pred_B) 578 | print(pred_B) # [b, num_B] 579 | return pred_B 580 | 581 | def g_decoder_B(self, ht, graph_state, num_items_B, nei_mask, nei_index_B, IsinnumB): 582 | with tf.variable_scope('g_predict_B'): 583 | self.W_h_b = random_weight(self.hidden_size, self.hidden_size, name='W_h_b') 584 | self.W_emb_b = random_weight(self.hidden_size, self.hidden_size, name='W_emb_b') 585 | self.W_v_b = random_weight(self.hidden_size, 1, name='W_v_b') 586 | graph_state = tf.reshape(graph_state, [-1, self.neighbors_num, self.hidden_size]) # [b, N, h] 587 | nei_mask = tf.expand_dims(nei_mask, -1) 588 | graph_state = nei_mask * graph_state 589 | att = self.g_decode_attention_B(ht[-1], graph_state, IsinnumB) # [b, N] 590 | g_pred_B = tf.scatter_nd(nei_index_B, att, [tf.shape(graph_state)[0], num_items_B]) 591 | print(g_pred_B) # [b, num_item_B] 592 | return g_pred_B, att 593 | def g_decode_attention_B(self, ht, repre, mask): 594 | S_h = tf.matmul(ht, self.W_h_b) # [b, h] 595 | S_h = tf.expand_dims(S_h, 1) 596 | print('S_it:', S_h) # [b, 1, h] 597 | S_emb = tf.reshape(tf.matmul(tf.reshape(repre, [-1, self.hidden_size]), self.W_emb_b), 598 | [-1, self.neighbors_num, self.hidden_size]) # [b, N, h] 599 | print('S_emb:', S_emb) 600 | tanh = tf.tanh(S_h + S_emb) # [b, N, h] 601 | print("tanh:", tanh) 602 | s = tf.reshape(tf.squeeze(tf.matmul(tf.reshape(tanh, [-1, self.hidden_size]), self.W_v_b)), 603 | [-1, self.neighbors_num]) # [b, N] 604 | print("s:", s) # [b, N] 605 | s_inf_mask = self.mask_softmax(mask, s) 606 | print(s_inf_mask) # [b, N] 607 | score = self.normalize_softmax(s_inf_mask) # [b, N] 608 | print('score:', score) 609 | return score 610 | 611 | def final_pred_B(self, PG_B, PS_B, s_pred_B, g_pred_B): 612 | with tf.variable_scope('final_predict_B'): 613 | pred_B = PG_B * g_pred_B + PS_B * s_pred_B 614 | print(pred_B) 615 | return pred_B 616 | 617 | def cal_loss(self, target_A, pred_A, target_B, pred_B,): 618 | 619 | # loss_A = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_A, logits=pred_A) 620 | loss_A = tf.contrib.keras.losses.sparse_categorical_crossentropy(target_A,pred_A) 621 | self.loss_A = tf.reduce_mean(loss_A, name='loss_A') 622 | # loss_B = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=target_B, logits=pred_B) 623 | loss_B = tf.contrib.keras.losses.sparse_categorical_crossentropy(target_B,pred_B) 624 | self.loss_B = tf.reduce_mean(loss_B, name='loss_B') 625 | loss = self.loss_A + self.loss_B 626 | # return loss 627 | loss_m_A = -(1 - tf.sign(self.tar_in_A)) * tf.log(self.PS_A + 0.0001) 628 | self.loss_m_A = tf.reduce_mean(loss_m_A, name='loss_m_A') 629 | loss_m_B = -(1 - tf.sign(self.tar_in_B)) * tf.log(self.PS_B + 0.0001) 630 | self.loss_m_B = tf.reduce_mean(loss_m_B, name='loss_m_B') 631 | loss_m = self.loss_m_A + self.loss_m_B 632 | loss_all = loss + loss_m 633 | 634 | return loss_all 635 | 636 | def optimizer(self, lr, training_steps_per_epoch, lr_decay_factor, min_lr): 637 | optimizer = 
tf.train.AdamOptimizer(lr) 638 | gradients = optimizer.compute_gradients(self.loss) 639 | capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if grad is not None] 640 | train_op = optimizer.apply_gradients(capped_gradients) 641 | return train_op, gradients 642 | 643 | 644 | def random_weight(dim_in, dim_out, name=None): 645 | return tf.get_variable(dtype=tf.float32, name=name, shape=[dim_in, dim_out], initializer=tf.contrib.layers.xavier_initializer(uniform=False)) 646 | 647 | def random_bias(dim, name=None): 648 | return tf.get_variable(dtype=tf.float32, name=name, shape=[dim], initializer=tf.constant_initializer(1.0)) 649 | 650 | def matrix_mutliply(finstate, W,b,N,h): 651 | fin_state_new = tf.reshape(tf.matmul(finstate, W) + b, [-1, N, h]) 652 | return fin_state_new 653 | 654 | # def get_neighbours1(sessions,lenA,lenB,seqA,seqB,tarA,tarB,neinum): 655 | # # this function is used to run as a sample to the model. 656 | # newseqB = [] 657 | # for session in seqB: 658 | # temp = [] 659 | # for i in session: 660 | # if i > 0: 661 | # temp.append(i + 500) 662 | # else: 663 | # temp.append(0) 664 | # newseqB.append(temp) 665 | # 666 | # sa = [] 667 | # for session in seqA: 668 | # seqa = session[np.where(session > 0)] 669 | # sa.append(seqa) 670 | # sb = [] 671 | # for session in seqB: 672 | # seqb = session[np.where(session > 0)] 673 | # sb.append([i+500 for i in seqb]) 674 | # 675 | # nei_all = [] 676 | # index_A,index_B = [],[] 677 | # IsinnumA,IsinnumB = [],[] 678 | # tarinid = [] 679 | # nei_mask_A,nei_mask_B = [],[] 680 | # nei_mask_L_A, nei_mask_L_T = [], [] 681 | # tarinA,tarinB = [],[] 682 | # for i in range(len(sessions)): 683 | # node = set() 684 | # itemA = set(sa[i]) 685 | # node.update(itemA) 686 | # itemB = set(sb[i]) 687 | # node.update(itemB) 688 | # 689 | # ran = random.randint(0,9) 690 | # if ran <= 1: 691 | # node.add(tarA[i]) 692 | # node.add(tarB[i] + 500) 693 | # tarinid.append(i+10) 694 | # tarinA.append(1) 695 | # tarinB.append(1) 696 | # elif ran >=2 and ran <=4: 697 | # node.add(tarA[i]) 698 | # tarinid.append(i + 20) 699 | # tarinA.append(1) 700 | # tarinB.append(0) 701 | # elif ran >=5 and ran <=7: 702 | # node.add(tarB[i]+500) 703 | # tarinid.append(i+30) 704 | # tarinA.append(0) 705 | # tarinB.append(1) 706 | # else: 707 | # tarinid.append(i+40) 708 | # tarinA.append(0) 709 | # tarinB.append(0) 710 | # 711 | # pool = set(range(1000)).difference(node) 712 | # rest = list(random.sample(list(pool),neinum-len(node))) 713 | # # print(set(rest)) 714 | # node.update(set(rest)) 715 | # # print(node) 716 | # indx = dict(zip(list(node), range(len(node)))) 717 | # nei_all.append(list(node)) 718 | # 719 | # tempinaset, tempinbset = [], [] 720 | # itemAset = set(range(100)) 721 | # itemBset = set(range(500,600)) 722 | # for ii in indx.keys(): 723 | # if ii in itemAset: 724 | # tempinaset.append(1) 725 | # tempinbset.append(0) 726 | # elif ii in itemBset: 727 | # tempinaset.append(0) 728 | # tempinbset.append(1) 729 | # else: 730 | # tempinaset.append(0) 731 | # tempinbset.append(0) 732 | # IsinnumA.append(tempinaset) 733 | # IsinnumB.append(tempinbset) 734 | # 735 | # ini_A = np.ones(neinum) 736 | # ini_A[np.where(np.array(list(node)) >= 500)] = 0 737 | # ini_B = np.ones(neinum) 738 | # ini_B[np.where(np.array(list(node)) < 500)] = 0 739 | # nei_mask_A.append(ini_A) 740 | # nei_mask_B.append(ini_B) 741 | # 742 | # temp_a, temp_t = [],[] 743 | # for item in sessions[i]: 744 | # if item < 500: 745 | # temp_a.append(ini_A) 746 | # 
temp_t.append(ini_B) 747 | # else: 748 | # temp_a.append(ini_B) 749 | # temp_t.append(ini_A) 750 | # nei_mask_L_A.append(temp_a) 751 | # nei_mask_L_T.append(temp_t) 752 | # 753 | # 754 | # zero1 = np.zeros(neinum) 755 | # ind1 = np.where(np.array(list(node)) <= 99)[0] 756 | # ent_id = np.array(list(node))[ind1] 757 | # ite_id = [] 758 | # for ii in ent_id: 759 | # ite_id.append(ii) 760 | # zero1[ind1] = ite_id 761 | # index_A.append(zero1) 762 | # 763 | # zero2 = np.zeros(neinum) 764 | # ind2 = [] 765 | # for lll in node: 766 | # if lll >= 500 and lll < 600: 767 | # ind2.append(indx[lll]) 768 | # ent_id = np.array(list(node))[ind2] 769 | # ite_id = [] 770 | # for ii in ent_id: 771 | # ite_id.append(ii-500) 772 | # zero2[ind2] = ite_id 773 | # index_B.append(zero2) 774 | # 775 | # nei_all = np.array(nei_all) 776 | # index = np.arange(len(sessions)) 777 | # index = np.expand_dims(index, axis=-1) 778 | # index_p = np.repeat(index, neinum, axis=1) 779 | # index1 = np.stack([index_p, np.array(index_A)], axis=-1) 780 | # index2 = np.stack([index_p, np.array(index_B)], axis=-1) 781 | # nei_mask_A = np.array(nei_mask_A) 782 | # nei_mask_B = np.array(nei_mask_B) 783 | # nei_mask_L_A = np.array(nei_mask_L_A) 784 | # nei_mask_L_T = np.array(nei_mask_L_T) 785 | # 786 | # return np.array(newseqB), nei_all,index1,index2,IsinnumA,IsinnumB,tarinid,nei_mask_A,nei_mask_B,nei_mask_L_A,nei_mask_L_T,tarinA,tarinB 787 | 788 | # batch_size = 4 789 | # model = MIFN(num_items_A=100, num_items_B=100,num_entity_A=500,num_entity_B=500,num_cate=100,neighbor_num=100, 790 | # batch_size=batch_size, gpu='1') #decay=lr_dc_step * train_data_len / batch_size 791 | # ''' 792 | # The mixed sequences given here are: 793 | # A36,B8,A55,B9,A2,B3,B77,(A89),(B16), 794 | # B19,A1,A45,(A90),(B45), 795 | # B23,A70,B54,B67,(B56),(A31); the items in parentheses are the prediction targets. 796 | # We need to split each mixed sequence into its A items and its B items, and we also need to record the position of every A and B element. 797 | # Take B8, whose position is [0,1], as an example: 0 means it belongs to the first sample of the batch (0 is the sample index), and 1 means that one A element, namely A36, precedes it (since we prepend a zero_state time step to every A sequence, 1 is in fact the index of A36). 798 | # Take A36, whose position is [0,0], as an example: the first 0 means it belongs to the first sample of the batch (0 is the sample index), and the second 0 means that no B element precedes it. 799 | # For padded elements of a B sequence we also record the number of preceding A elements as 0 (because we want them to use the zero_state), and for padded elements of an A sequence we likewise record the number of preceding B elements as 0 (because we want them to use the zero_state); a standalone sketch of this position computation is appended below, after this commented-out demo. 800 | # ''' 801 | # # seq_A = np.array([[36,55,2],[1,45,0],[70,0,0]]) 802 | # # seq_B = np.array([[8,9,3,77],[19,0,0,0],[23,54,67,0]]) 803 | # # len_A = np.array([3,2,1]) 804 | # # len_B = np.array([4,1,3]) 805 | # # # pos_A = np.array([[[0,0],[0,1],[0,2]],[[1,1],[1,1],[1,0]],[[2,1],[2,0],[2,0]]]) 806 | # # # pos_B = np.array([[[0,1],[0,2],[0,3],[0,3]],[[1,0],[1,0],[1,0],[1,0]],[[2,0],[2,1],[2,1],[2,0]]]) 807 | # # index_A = np.array([[[0,0],[0,2],[0,4]],[[1,1],[1,2],[1,0]],[[2,1],[2,0],[2,0]]]) 808 | # # index_B = np.array([[[0,1],[0,3],[0,5],[0,6]],[[1,0],[1,0],[1,0],[1,0]],[[2,0],[2,2],[2,3],[2,0]]]) 809 | # # target_A = np.array([89,90,31]) 810 | # # target_B = np.array([16,45,56]) 811 | # # # sequence = np.array([[36,8,55,9,2,3,77],[19,1,45],[23,70,54,67]]) 812 | # # sequence = np.array([[36,55,2,8,9,3,77],[1,45,19,0,0,0,0],[70,23,54,67,0,0,0]]) 813 | # ''' 814 | # The mixed sequences given here are: 815 | # A88,B16,A99,B2,A67,B45,(B44),(A56), 816 | # B17,B91,A14,A43,(A90),(B73), 817 | # B44,A34,B87,B72,A11,(A8),(B90), 818 | # A21,A35,B56,A78,A79,(B11),(A62); the items in parentheses are the prediction targets. 819 | # ''' 820 | # seq_A = np.array([[88,99,67,0],[14,43,0,0],[34,11,0,0],[21,35,78,79]]) 821 | # seq_B = np.array([[16,2,45],[17,91,0],[44,87,72],[56,0,0]]) 822 | # len_A = np.array([3,2,2,4]) 823 | # len_B = np.array([3,2,3,1]) 824 | # len_all = np.array([6,4,5,5]) 825 | # pos_A = 
np.array([[[0,0],[0,1],[0,2],[0,2]],[[1,2],[1,2],[1,0],[1,0]],[[2,1],[2,3],[2,0],[2,0]],[[3,0],[3,0],[3,1],[3,1]]]) 826 | # pos_B = np.array([[[0,1],[0,2],[0,3]],[[1,0],[1,0],[1,0]],[[2,0],[2,1],[2,1]],[[3,2],[3,0],[3,0]]]) 827 | # index_A = np.array([[[0,0],[0,2],[0,4],[0,0]],[[1,2],[1,3],[1,0],[1,0]],[[2,1],[2,4],[2,0],[2,0]],[[3,0],[3,1],[3,3],[3,4]]]) 828 | # index_B = np.array([[[0,1],[0,3],[0,5]],[[1,0],[1,1],[1,0]],[[2,0],[2,2],[2,3]],[[3,2],[3,0],[3,0]]]) 829 | # target_A = np.array([56,90,8,62]) 830 | # target_B = np.array([44,73,90,11]) 831 | # sequence = np.array([[88,99,67,16,2,45],[14,43,17,91,0,0],[34,11,44,87,72,0],[21,35,78,79,56,0]]) 832 | # adj1_1 = batch_size*[np.random.randint(0,2,(100,100))] 833 | # adj2_1 = batch_size*[np.random.randint(0,2,(100,100))] 834 | # adj3_1 = batch_size*[np.random.randint(0,2,(100,100))] 835 | # adj4_1 = batch_size*[np.random.randint(0,2,(100,100))] 836 | # adj5_1 = batch_size*[np.random.randint(0,2,(100,100))] 837 | # # adj6_1 = batch_size*[np.random.randint(0,2,(100,100))] 838 | # newseqb, neighbors_1,nei_index_A1,nei_index_B1,IsinnumA,IsinnumB,tarinid,\ 839 | # nei_mask_A,nei_mask_B,nei_mask_L_A,nei_mask_L_T,tarinA,tarinB = get_neighbours1(sequence,len_A,len_B,seq_A,seq_B,target_A,target_B,100) 840 | # tarinA = np.expand_dims(tarinA,axis=1) 841 | # tarinB = np.expand_dims(tarinB,axis=1) 842 | # print(tarinA) 843 | # print(tarinB) 844 | # print('************************ start training...******************************') 845 | # 846 | # with tf.Session(graph=model.graph,config=model.config) as sess: 847 | # sess.run(tf.global_variables_initializer()) 848 | # i = 0 849 | # while i < 100: 850 | # _, _, l, pa, pb,psa,psb,pga,pgb = sess.run([model.train_op, model.grad, model.loss, model.pred_A, model.pred_B, 851 | # model.PS_A, model.PS_B, model.PG_A, model.PG_B,], 852 | # {model.seq_A:seq_A, model.seq_B:newseqb, model.len_A:len_A, model.len_B:len_B, 853 | # model.target_A:target_A, model.target_B:target_B, 854 | # model.pos_A:pos_A, model.pos_B:pos_B, model.index_A:index_A,model.index_B:index_B, 855 | # model.adj_1:adj1_1,model.adj_2:adj2_1,model.adj_3:adj3_1,model.adj_4:adj4_1,model.adj_5:adj5_1, 856 | # # model.adj_6:adj6_1, 857 | # model.neighbors:neighbors_1, model.nei_index_A:nei_index_A1, model.nei_index_B:nei_index_B1, 858 | # model.IsinnumA:IsinnumA, model.IsinnumB:IsinnumB, 859 | # model.nei_A_mask: nei_mask_A, model.nei_B_mask: nei_mask_B, 860 | # model.nei_L_A_mask:nei_mask_L_A, model.nei_L_T_mask:nei_mask_L_T, 861 | # model.tar_in_A: tarinA, model.tar_in_B: tarinB,}) 862 | # 863 | # print('loss:', l) 864 | # i += 1 --------------------------------------------------------------------------------
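# NOTE: the following is a minimal, self-contained sketch (not part of MIFN.py) of the
# position computation described in the commented demo above: for every A item we record
# how many B items precede it in the mixed sequence, and vice versa, with 0 for padded
# steps so that they gather the prepended zero_state. The helper name `build_positions`
# and its (domain, item_id) input format are illustrative assumptions, not the authors' code.
import numpy as np

def build_positions(mixed_batch, max_len_A, max_len_B):
    """mixed_batch: one list per sample, each element a ('A'|'B', item_id) pair in order."""
    pos_A = np.zeros((len(mixed_batch), max_len_A, 2), dtype=np.int64)
    pos_B = np.zeros((len(mixed_batch), max_len_B, 2), dtype=np.int64)
    for b, sample in enumerate(mixed_batch):
        pos_A[b, :, 0] = b          # first coordinate is always the sample index
        pos_B[b, :, 0] = b
        seen_A = seen_B = t_A = t_B = 0
        for domain, _ in sample:
            if domain == 'A':
                pos_A[b, t_A, 1] = seen_B   # number of B items seen before this A item
                seen_A += 1
                t_A += 1
            else:
                pos_B[b, t_B, 1] = seen_A   # number of A items seen before this B item
                seen_B += 1
                t_B += 1
        # padded time steps keep [b, 0] so they pick up the zero_state
    return pos_A, pos_B

# The first sample of the first demo, A36,B8,A55,B9,A2,B3,B77, reproduces the commented
# values pos_A[0] = [[0,0],[0,1],[0,2]] and pos_B[0] = [[0,1],[0,2],[0,3],[0,3]].
mixed = [[('A', 36), ('B', 8), ('A', 55), ('B', 9), ('A', 2), ('B', 3), ('B', 77)]]
p_A, p_B = build_positions(mixed, max_len_A=3, max_len_B=4)
print(p_A[0].tolist())  # [[0, 0], [0, 1], [0, 2]]
print(p_B[0].tolist())  # [[0, 1], [0, 2], [0, 3], [0, 3]]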