├── README.md
├── data
│   ├── 50_keywordsets_eval
│   │   └── word_sets.txt
│   ├── ROC
│   │   └── ROCStories_20_storylines_500_0.txt
│   └── keyword_to_articles
│       ├── test_10.txt
│       ├── test_12.txt
│       ├── test_13.txt
│       ├── test_14.txt
│       ├── test_15.txt
│       ├── test_16.txt
│       ├── test_4.txt
│       ├── test_5.txt
│       ├── test_8.txt
│       └── test_9.txt
├── encode_keywords.py
├── encode_keywords_word2vec.py
├── main.py
├── metrics_degen.py
├── metrics_degen_run.sh
├── perplexity.py
├── requirements.txt
├── run.sh
└── utility_gpt.py

/README.md:
--------------------------------------------------------------------------------
# Keyword2Text

This repository contains the code for the paper "A Plug-and-Play Method for Controlled Text Generation". If you find it useful and use it in your own research, please cite us.

## Setup

1. Download and unzip the repository.
2. Create a new conda environment and install the required libraries from the `requirements.txt` file:
```bash
conda create -n k2t python=3.6
conda activate k2t
pip install -r requirements.txt
```

A GPU is required to run the experiments.
Make sure a `results` folder exists in the repository root; all outputs are written to sub-folders inside it.

## Run Model

### Hyperparameter Study

Uncomment the appropriate lines of `run.sh` to run the hyperparameter experiments from the paper. For example,

```bash
python main.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=10.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
```

runs K2T with ordered guide words (`mode='next'`) on the random-keywords dataset, using lambda = `weight` = 10, nucleus sampling with top-p = 0.9, 90 generated tokens, and weight annealing enabled (`-do_guarantee=True`) so that every guide word is guaranteed to appear.
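The other flags can be varied in the same way to reproduce the rest of the hyperparameter study. As a minimal sketch (the weight grid and the `weight_sweep` sub-folder name below are illustrative choices, not the paper's settings), a loop like the following can be added to `run.sh` to sweep the guidance strength lambda:

```bash
# Illustrative sweep over the guidance strength lambda (-weight); the grid values are examples.
for w in 2.0 5.0 10.0 20.0; do
    python main.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt \
        -results_subfolder=weight_sweep -weight=$w -top_p=0.9 \
        -n_generated_sentences=90 -do_guarantee=True
done
```

Each setting writes to its own output file, since the hyperparameter values are encoded in the result filename.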
Results are written under `results/`, in a sub-folder named after the dataset folder and the `-results_subfolder` argument (for the commands above, `results/50_keywordsets_eval/<results_subfolder>/`).

### ROC Story dataset

Uncomment the appropriate line of `run.sh` to run the model on the ROC Stories dataset:

```bash
python main.py -mode='max' -file_name=/data/ROC/ROCStories_20_storylines_500_0.txt -results_subfolder=final4_ -weight=5.0 -top_p=0.9 -n_generated_sentences=-7 -n_beams=4 -do_guarantee=True -task='ROC'
```

### News Article dataset

Uncomment the appropriate line of `run.sh` to run the model on the News Article dataset:

```bash
python main.py -mode='max' -file_name=/data/keyword_to_articles -results_subfolder=tmp -weight=5.0 -top_p=0.9 -n_generated_sentences=-15 -n_beams=4 -do_guarantee=True -task='key2article'
```

## Contents
```
├── data
│   ├── 50_keywordsets_eval
│   │   └── word_sets.txt
│   ├── keyword_to_articles
│   │   ├── test_10.txt
│   │   ├── test_12.txt
│   │   ├── test_13.txt
│   │   ├── test_14.txt
│   │   ├── test_15.txt
│   │   ├── test_16.txt
│   │   ├── test_4.txt
│   │   ├── test_5.txt
│   │   ├── test_8.txt
│   │   └── test_9.txt
│   └── ROC
│       └── ROCStories_20_storylines_500_0.txt
├── encode_keywords.py
├── encode_keywords_word2vec.py
├── main.py
├── metrics_degen.py
├── metrics_degen_run.sh
├── perplexity.py
├── README.md
├── requirements.txt
├── results
├── run.sh
└── utility_gpt.py
```

--------------------------------------------------------------------------------
/data/50_keywordsets_eval/word_sets.txt:
--------------------------------------------------------------------------------
1 | enemy, speed, meet, colony, mouth
2 | suit, valley, speak, wing, tie
3 | dad, store, glad, wire, decimal
4 | touch, nation, chief, edge, wire
5 | poem, temperature, sky, crop, silent
6 | effect, neighbor, gas, oxygen, term
7 | feed, salt, molecule, lot, represent
8 | lady, dear, lot, except, region
9 | guide, jump, eight, row, settle
10 | rise, favor, soft, tie, tie
11 | excite, quite, capital, trip, lady
12 | oxygen, anger, art, spoke, race
13 | joy, require, chicken, else, cow
14 | search, major, string, cost, village
15 | enemy, dad, thick, crowd, length
16 | engine, experiment, caught, chair, visit
17 | represent, substance, student, tone, million
18 | ring, imagine, human, fruit, shout
19 | matter, dad, flat, control, women
20 | stone, anger, include, milk, scale
21 | slip, visit, material, reply, cow
22 | safe, drop, wife, eight, matter
23 | ball, hurry, lost, smell, plural
24 | push, choose, bright, trouble, human
25 | chance, bed, forward, kill, hair
26 | blow, mass, brother, market, roll
27 | block, quite, dead, subtract, band
28 | difficult, shine, practice, stream, supply
29 | corn, magnet, team, symbol, egg
30 | experiment, cost, third, wrong, melody
31 | deal, heavy, cell, spend, shall
32 | join, circle, describe, led, save
33 | stream, create, sugar, parent, crop
34 | desert, wild, crop, bear, enemy
35 | paragraph, anger, quite, general, led
36 | born, ride, collect, milk, rose
37 | cost, thank, summer, control, nose
38 | increase, village, gather, summer, fit
39 | tone, speech, represent, century, duck
40 | particular, summer, insect, rise, nature
41 | anger, yet, experiment, enter, eight
42 | favor, crowd, consonant, vary, melody
43 | mount, print, particular, range, swim
44 | spring, trip, instrument, subject, choose
45 | fair, corn, tall, cotton, decimal
46 |
age, sheet, solution, evening, view 47 | child, lake, leg, flower, camp 48 | shall, organ, connect, noon, steel 49 | truck, populate, brother, stone, bank 50 | element, control, size, increase, speech 51 | -------------------------------------------------------------------------------- /data/ROC/ROCStories_20_storylines_500_0.txt: -------------------------------------------------------------------------------- 1 | kate, needed, decided, bought, home 2 | today, nervous, worried, luckily, ended 3 | loved, decided, store, bought, home 4 | shopping, decided, store, bought, happy 5 | vegas, decided, drove, found, trip 6 | wanted, decided, applied, finally, happy 7 | wanted, decided, signed, worked, year 8 | she, decided, wanted, store, bought 9 | flowers, bought, thought, loved, happy 10 | walk, started, began, falling, ended 11 | wanted, decided, practiced, hard, finally 12 | loved, decided, hiked, trail, glad 13 | wanted, decided, bought, home, happy 14 | walk, suddenly, rain, decided, home 15 | he, garage, decided, bought, happy 16 | home, house, called, told, decided 17 | party, wanted, decided, cake, friends 18 | wanted, decided, looked, found, bought 19 | jake, wanted, decided, worked, happy 20 | school, team, game, ball, scored 21 | -------------------------------------------------------------------------------- /data/keyword_to_articles/test_10.txt: -------------------------------------------------------------------------------- 1 | art_and_culture-20885615.txt 2 | Oscar nominated actress Hailee Steinfeld is set to play Sleeping Beauty in a re-visioning of the classic fairytale. The 14-year-old star has signed on to play the unfortunate princess in the new movie. It will be a spin of a version of the story which is being written by screen writer Lindsay Devlin. According to Deadline.com the new version will give the star more to do than just sleep as it will follow her as she enters a dream world and has to find her way out. Hailee is also currently being considered to play the lead role in the novel-based film Forgotten, from the upcoming book of the same name by Cat Patrick. After being one of the youngest ever actresses to be nominated for an Academy Award at this years Oscars she is set to have a busy year with several other offers on the table. The young California native had only starred in several television bit parts before landing the lead role in the Cohen Brothers adaptation of the classic John Wayne western True Grit alongside Jeff Bridges. 3 | movie, version, story, star, sleep, dream, role, novel, upcoming, youngest, actress, nominated, television, adaptation, classic 4 | ['Academy Award', 'California', 'fairytale', 'nominated', 'western', 'Oscar nominated actress', 'television', 'revisioning', 'version', 'novelbased', 'written', 'Sleeping Beauty', 'Forgotten', 'movie', 'adaptation', 'upcoming', 'actresses', 'unfortunate princess', 'Deadlinecom', 'considered', 'landing', 'starred', 'unfortunate', 'story', 'role', 'youngest', 'upcoming book', 'Oscars', 'California native', 'Oscar nominated', 'dream', 'True', 'classic', 'Beauty'] 5 | -------------------------------------------------------------------------------- /data/keyword_to_articles/test_12.txt: -------------------------------------------------------------------------------- 1 | science-20842091.txt 2 | At a government laboratory in Alabama, workers in blue coats unload envelopes packed with small filters that trapped air particles in Hawaii, Alaska and elsewhere. 
The discs are placed in lead- lined, barrellike devices for testing to make sure no traces of radioactive materials have wafted across the Pacific Ocean from Japan. So far, the sea breeze in places such as Honolulu is no more dangerous than the pollen-laden air of the Deep South, according to officials. Still, the 60 or so workers in the 72,000-square-foot building will be the first to know if the Japanese disaster spreads harmful amounts of radiation to the U.S. On Wednesday, the Environmental Protection Agency and the Food and Drug Administration said very low levels of radiation had turned up in a sample of milk from Washington state, but federal officials assured consumers not to worry. The FDA said such findings were to be expected in the coming days because of the nuclear crisis in Japan and that the levels were expected to drop relatively quickly. The EPA said it was increasing the level of nationwide monitoring of milk, precipitation and drinking water. 3 | devices, traces, radioactive, materials, ocean, dangerous, workers, disaster, radiation, environmental, administration, federal, consumers, crisis, water 4 | ['radioactive materials', 'Japanese disaster', 'milk', 'pollenladen air', 'government laboratory', 'devices', 'filters', 'radiation', 'precipitation', 'workers', 'consumers', 'drinking water', 'nuclear crisis', 'envelopes', 'Environmental', 'Drug Administration', 'Parker', 'milk precipitation', 'Food', 'traces', 'sea breeze', 'Vail restaurant deal', 'increasing', 'dangerous', 'Pacific Ocean', 'Japan', 'harmful', 'unload', 'Alabama', 'barrellike', 'Honolulu', 'make', 'Deep South', 'federal', 'sample', 'know', 'nationwide', 'nuclear', 'worry'] 5 | -------------------------------------------------------------------------------- /data/keyword_to_articles/test_13.txt: -------------------------------------------------------------------------------- 1 | crime-20905140.txt 2 | A 25-year-old German man has been arrested for allegedly burying a cache of bombs near a German soccer stadium in a blackmail plot, authorities say. The unnamed German national was arrested in Cologne on Tuesday after allegedly placing the explosives in a parking garage near the Westfalenstadion in Dortmund, home of the Borussia Dortmund team, the Federal Office of Criminal Investigation told The Local news agency. The bombs were safely defused, and three more were found at the man's home in Krefeld, officials said. Investigators said they began tracking the man after he e-mailed the German Embassy in Pakistan, offering information about two planned attacks in Germany by a group. The warning appeared to be a blackmail bid and was worded like an unsolved attempted blackmail case last year. "The suspect apparently acted alone with a general criminal motive," a federal spokesman said. "There are absolutely no ties to terrorist or Islamist organizations." Authorities say he admitted placing the bombs. Dortmund police spokesman Michael Stein told the BBC: "We expect no security threat at all for the upcoming Bundesliga match on Saturday. Visitors are invited to come to Dortmund. They will be safe here." 
3 | german, arrested, explosives, garage, stadium, federal, criminal, investigation, bombs, home, embassy, information, blackmail, criminal, terrorist, organizations, safe 4 | ['Dortmund', 'unsolved attempted blackmail case', 'blackmail bid', 'safely defused', 'Cologne', 'Tuesday', 'cache of bombs', 'Westfalenstadion', 'explosives', 'parking garage', 'arrested', 'Germany', 'Pakistan', '25yearold German', 'burying', 'no ties to terrorist or Islamist organizations', 'soccer stadium', 'Federal Office of Criminal Investigation', 'German Embassy', 'began', 'terrorist', 'information', 'national', 'unnamed', 'stadium', 'DORTMUND', 'blackmail', 'Criminal Investigation', 'plot', 'German arrested', 'man', 'organizations Authorities', 'safe', 'bombs', 'cache', 'Bundesliga', 'worded', 'German', 'invited', 'Federal', 'unsolved', 'Saturday', 'apparently', 'April'] 5 | -------------------------------------------------------------------------------- /data/keyword_to_articles/test_14.txt: -------------------------------------------------------------------------------- 1 | business-20903055.txt 2 | An Alaska lawmaker introduced an amendment Friday that would give oil companies a tax break provided they pledged to increase production. Alaska Gov. Sean Parnell, a Republican, unveiled plans for 1 million barrels of oil production per day through the Trans-Alaska pipeline system within the next decade. Parnell said he was proposing a tax break for oil companies to encourage investments and to erase declines in state oil production. "The time to reduce oil taxes is now and I am asking all Alaskans to send a clear message to legislators in Juneau that a "do-nothing" strategy is unacceptable because Alaska's future is at stake," he said in a statement. But state Rep. Bob Miller, D-Fairbanks, introduced an amendment that would put certain restrictions on oil companies. Under Miller's plan, companies would get a tax break until January 2017. After that, they would need to have increased production by 10 percent of current levels to continue getting a break and increase production by another 2 percent each year beyond 2017. "We want to be sure that they are earning the breaks that we are giving," Miller said in a statement. "This amendment says here's the benefit. If you do not make certain metrics for the benefit of Alaska, we withdraw those benefits." 3 | barrels, production, pipeline, companies, investments, oil, tax, message, legislators, strategy, future, amendment, restrictions, companies, tax, increase, percentage, benefit 4 | -------------------------------------------------------------------------------- /data/keyword_to_articles/test_15.txt: -------------------------------------------------------------------------------- 1 | politics_us-20804565.txt 2 | I’d enjoy reading the citation of constitutional authority for this: “The bill then says if the Senate does not act, then H.R. 1 [the House-passed bill that cuts $61 billion] will be the law of the land. In addition to that, it says that if all else fails, and the Senate brings about a shutdown, then members should not get their pay.” That’s House Majority Leader Eric Cantor describing his “Government Shutdown Prevention Act.” The problem is, this would be blatantly unconstitutional: The Senate needs to pass the same piece of legislation the House does, and the president needs to either sign it or have his veto overturned. 
That’s how deem-and-pass worked with the health-care law, for instance: Both the Senate and the House passed the same pieces of legislation, and then the president signed them. But it seems Cantor merely misspoke. I’ve clarified with both his office and Boehner’s office that they believe the Senate and the president would still play their traditional roles. That means deem-and-pass isn't, as Cantor suggests, an alternative to actually striking a compromise. It’s just an effort to message the shutdown that’ll happen if a law isn’t passed. The problem for Cantor is that by misdescribing how the gambit would work, he’s drawing attention to the fact that it can’t. At the end of the day, we need an actual deal here. There’s no other option. 3 | senate, shutdown, pay, unconstitutional, prevention, legislation, president, sign, law, senate, compromise, problem, effort, attention, deal, option 4 | ['president', 'passed', 'piece', 'citation', 'deemandpass', 'unconstitutional', 'shutdown thatll', 'signed', 'brings', 'enjoy', 'effort', 'Prevention', 'constitutional', 'Senate', 'Cantor merely misspoke', 'of legislation', 'drawing', 'shutdown', 'legislation', 'misdescribing', 'Cantor', 'HR', 'overturned', 'Government Shutdown Prevention', 'pay', 'addition', 'blatantly unconstitutional'] 5 | -------------------------------------------------------------------------------- /data/keyword_to_articles/test_16.txt: -------------------------------------------------------------------------------- 1 | crime-20906537.txt 2 | We tend to think of Seattle in stereotypical ways: earthy, mellow, panoramic, rainy. But it's not all Patagonia and lattes out there, as the news yesterday about the Justice Department's investigation of Seattle police would indicate. The DOJ is looking into a possible pattern of the SPD using excessive force and discriminating against minorities, Justice announced here yesterday. At issue in the federal investigation are several high-profile incidents involving police violence. Last April, a detective was videotaped kicking a Latino robbery suspect and stating that he would beat the "Mexican piss" out of the suspect; In June, an officer was videotaped punching a 17-year-old African American girl who protested the arrest of a friend for jaywalking; and, In August, police shot a Native American woodcarver after he faied to drop his carving knife. Seattle Police Chief John Diaz told the Seattle Times that he welcomes the DOJ investigation and considers it like a "free audit." "We have nothing to hide," he said. "We've been open and transparent with the Department of Justice, which makes for a good working relationship." 
3 | investigation, police, force, discriminating, yesterday, robbery, suspect, detective, african, native, knife, chief, investigation, justice 4 | ['dealers', 'kicking', 'discriminating', 'search', 'working', 'carving', 'vacation', 'announced', 'Native', 'indicate', "Ashby's page", 'jaywalking', 'lattes', 'blackandgold', 'faied', 'stereotypical', 'Earlier', 'little', 'Phoenix', 'miscreants', 'Seattle', 'investigation', 'yesterday', 'Investigate', 'brush', 'panoramic', 'highprofile', 'issue', 'videotaped', 'suspect', 'experience', 'transparent', 'Department', 'separate', 'welcomes', 'firsthand experience', 'Native American', 'modernday superhero Phoenix Jones', 'Seattle PostIntelligencer', 'Excessive Use', 'Force', 'Nathan Koppel', 'Latino robbery', 'DOJ', 'Feds Investigate Seattle Police', 'Ashby Jones', 'Mexican piss', "Justice Department's", 'officer', 'Seattle Times', 'Chief John Diaz', 'African American', 'videotaped punching', 'Patagonia', 'patrolled'] 5 | -------------------------------------------------------------------------------- /data/keyword_to_articles/test_4.txt: -------------------------------------------------------------------------------- 1 | politics_world-20795191.txt 2 | On Wednesday, Egypt's Supreme Military Council issued a declaration which will temporarily replace the country's former constitution, dating back to 1971. The declaration is based on amendments to the old constitution which were backed by the majority of people at a referendum on March 19. The amendments mostly concern procedures for presidential and parliamentary elections, and the terms of presidency. The old constitution was abrogated by the country's military on February 13, soon after President Mubarak's resignation. Parliamentary elections in Egypt are planned for this September. 3 | constitution, majority, elections, military, resignation 4 | ['concern', "country's", 'Military Council', 'military', '1971', 'presidential', 'Wednesday', "President Mubarak's", 'declaration', 'majority', 'parliamentary elections', 'referendum', 'backed', 'constitution', 'provisional', 'temporarily', 'February 13', 'March 19', 'planned', 'Egypt', 'amendments', 'abrogated', 'old', 'elections', 'people', 'former', 'provisional constitution', 'dating', 'Supreme Military Council', "President Mubarak's resignation", "country's military"] 5 | -------------------------------------------------------------------------------- /data/keyword_to_articles/test_5.txt: -------------------------------------------------------------------------------- 1 | politics_us-20802905.txt 2 | A group of moderate House Democrats have initiated conversations with Republican lawmakers in a bid to try to reach a deal to tackle burgeoning federal spending. The group is being led by Reps. Jim Cooper (D., Tenn.) and Kurt Schrader (D., Ore.), both of whom are leaders of the Blue Dogs, a coalition of fiscal conservative Democrats. Tuesday, the lawmakers released a series of fiscal discipline goals, many of which are similar to the targets set in place by the deficit commission established ... 
3 | leaders, coalition, conservative, lawmakers, goals 4 | ['initiated conversations', 'Democrats', 'deficit', 'fiscal', 'deal', 'House', 'Mimic', 'place', 'Blue Dogs', 'spending', 'Democrats Tuesday', 'leaders', 'burgeoning', 'try', 'Republican lawmakers', 'Effort', 'released', 'reach', 'discipline', 'targets', 'similar', 'Kurt Schrader', 'moderate', 'established', 'group', 'Republican', 'tackle burgeoning federal', 'series', 'coalition', 'bid', 'discipline goals'] 5 | -------------------------------------------------------------------------------- /data/keyword_to_articles/test_8.txt: -------------------------------------------------------------------------------- 1 | sports-20933921.txt 2 | Pinch-hitter David Murphy delivered a tiebreaking, two-run double in the eighth inning and the Texas Rangers rallied for a 9-5 victory over the Boston Red Sox on Friday after raising their American League championship flag. Murphy's slicing liner off Daniel Bard (0-1) kicked up chalk when it landed on the left-field line. That sent Rangers newcomers Mike Napoli and Yorvit Torrealba home to break a 5-all tie. Murphy scored on a double by Elvis Andrus before another double by AL MVP Josh Hamilton. Napoli, Ian Kinsler and Nelson Cruz all homered off Jon Lester for Texas, which played its season opener exactly five months after a Game 5 loss to San Francisco at home ended its first World Series. Darren Oliver (1-0) got the victory after allowing a tying homer to David Ortiz in the eighth. Adrian Gonzalez had two hits and drove in three runs in his Boston debut. Carl Crawford went 0 for 4 with three strikeouts while leaving a runner in scoring position each at-bat. 3 | championship, kicked, land, newcomers, home, score, played, season, victory, hits, runs, debut, strikeouts 4 | ['tying', 'home', 'drove', 'allowing', 'two hits', 'homered', 'inning', 'strikeouts', 'American League championship', 'raising', 'Boston debut', 'ended', 'Texas Rangers', 'World Series', 'kicked', 'victory', 'Capsules', 'scored', 'season opener', 'Texas', 'played', 'Boston Red Sox', 'Indians', 'Pinchhitter David Murphy delivered', 'leaving', 'tiebreaking', 'landed', 'opener', 'newcomers', 'eighth', 'scoring position', 'Francisco', 'runs', 'debut', 'rallied', 'double', 'eighth inning', 'delivered', 'months', 'runner'] 5 | -------------------------------------------------------------------------------- /data/keyword_to_articles/test_9.txt: -------------------------------------------------------------------------------- 1 | art_and_culture-20890475.txt 2 | Johnny Depp will be asked to make a fifth Pirates Of The Caribbean film if the fourth instalment is a success. Producer Jerry Bruckheimer said he already has a screenplay in the works for a fifth Pirates tale, which would follow the forthcoming Pirates Of The Caribbean: On Stranger Tides. "As long as the audience embraces this one, we'll certainly try to make another one. It's really up to Johnny. He loves the character," he said. The original three Pirates blockbusters ended up as a trilogy continuing the same key characters and story line. But Jerry said On Stranger Tides - the first Pirates flick shot with digital 3D cameras - and future Pirates films will be stand-alone stories continuing the adventures of Johnny's woozy buccaneer Captain Jack Sparrow. "The audience told us (at test screenings of On Stranger Tides) what they loved about it is that it was fresh, it was new, it was a whole new story," Jerry said. 
"So that will carry over into the next one, too, to give it something fresh and different." 3 | fifth, audience, embraces, character, original, trilogy, continue, story, digital, cameras, pirates, fresh, new, carry 4 | ['screenplay', 'carry', 'told', 'Stranger', 'adventures', 'next', 'digital', 'character', 'Pirates', 'works', 'really', 'new', 'long', 'future', 'audience embraces', 'loves', 'stories', 'make', 'continuing', 'fresh', 'cameras', 'film', 'flick', 'instalment', 'success', 'whole new story', 'digital 3D cameras', 'audience embraces this', 'fifth Pirates Of The Caribbean film', 'trilogy', 'fifth', 'Producer Jerry Bruckheimer', "Johnny's woozy buccaneer Captain Jack Sparrow", 'original', 'ended up as', 'Johnny Depp', 'tale', 'On Stranger Tides', 'story line', 'same key characters'] 5 | -------------------------------------------------------------------------------- /encode_keywords.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import json 4 | import os 5 | import numpy as np 6 | import scipy.io as sio 7 | import argparse 8 | 9 | 10 | import gensim.downloader as api 11 | import pickle 12 | import argparse 13 | 14 | 15 | word_embedding = { 16 | 'glove': "glove-wiki-gigaword-300", 17 | 'word2vec': "word2vec-google-news-300" 18 | } 19 | 20 | def create_enc_dict(file_name, embedding, task): 21 | 22 | embedding_file = word_embedding[embedding] 23 | if task == 'key2article': 24 | folder_name = file_name 25 | else: 26 | folder_name = os.path.dirname(file_name) 27 | 28 | print('file_name: ', file_name) 29 | print('folder_name: ', folder_name) 30 | print('word_embedding: ', embedding) 31 | 32 | ######## Load word embedding data 33 | print('{} word embeddings loading...'.format(embedding)) 34 | encoder = api.load(embedding_file) 35 | print('{} word embeddings loaded'.format(embedding)) 36 | glove_dict = {} 37 | 38 | if not task == 'key2article': 39 | file1 = open(file_name, "r+") 40 | lines = file1.readlines() 41 | 42 | i=0 43 | for line in lines: 44 | keywords = list(line.strip().split(", ")) 45 | print(keywords) 46 | for word in keywords: 47 | glove_dict[word] = encoder[word] 48 | 49 | # save_path = folder_name + '/' + str(embedding) + '_set_' +str(i) + '.npy' 50 | # np.save(save_path, glove_words) 51 | i=i+1 52 | else: 53 | keyword_sets = [] 54 | for filename in os.listdir(folder_name): 55 | if filename.endswith('txt'): 56 | file1 = open(folder_name + filename, "r+") 57 | lines = file1.readlines() 58 | keywords = list(lines[2].strip().split(", ")) 59 | in_text = lines[1].split()[:30] 60 | keyword_sets.append((' '.join(in_text), keywords)) 61 | for word in keywords: 62 | glove_dict[word] = encoder[word] 63 | 64 | save_path_dict = folder_name + '/dict_' + str(embedding) + '.pkl' 65 | with open(save_path_dict, 'wb') as f: 66 | pickle.dump(glove_dict, f, pickle.HIGHEST_PROTOCOL) 67 | 68 | 69 | # if encode_articles == True: 70 | 71 | # for n in [4, 5, 8, 9, 10, 12, 13, 14, 15, 16]: 72 | # print(n) 73 | # file1 = open(str(os.path.dirname(os.path.abspath(__file__))) + 74 | # "/data/keyword_to_articles/test_" + str(n) + ".txt", "r+") 75 | 76 | # lines = file1.readlines() 77 | 78 | # keywords = list(lines[2].strip().split(", ")) 79 | # print(keywords) 80 | # glove_words = [] 81 | # for word in keywords: 82 | # glove = encoder[word] 83 | # glove_words.append(glove) 84 | 85 | # save_path = str(os.path.dirname( 86 | # os.path.abspath(__file__))) + '/data/keyword_to_articles/test_' +str(n) + '.npy' 87 | # np.save(save_path, 
glove_words) 88 | 89 | if __name__ == "__main__": 90 | ######## Parse arguments 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument('-file', type=str) 93 | parser.add_argument('-word_embedding', type=str, default='glove', 94 | choices=list(word_embedding.keys()), help='word_embedding') 95 | parser.add_argument('-task', type=str, default=None) #'key2article', 'commongen' 96 | args = parser.parse_args() 97 | file_name = args.file 98 | embedding = args.word_embedding 99 | task = args.task 100 | 101 | create_enc_dict(file_name, embedding, task) 102 | 103 | 104 | -------------------------------------------------------------------------------- /encode_keywords_word2vec.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import json 4 | import os 5 | import numpy as np 6 | import scipy.io as sio 7 | import argparse 8 | 9 | 10 | import gensim.downloader as api 11 | import pickle 12 | import argparse 13 | 14 | 15 | ######## Change to encode keywords for the desired task 16 | encode_50keywords = True 17 | encode_ROC = False 18 | encode_articles = False 19 | 20 | print('word2vec loading...') 21 | word2vec_encoder = api.load("word2vec-google-news-300") 22 | print('word2vec loaded') 23 | word2vec_dict = {} 24 | 25 | ######## Parse arguments 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('-file', type=str) 28 | args = parser.parse_args() 29 | file_name = args.file 30 | folder = os.path.dirname(file_name) 31 | 32 | print('file_name: ', file_name) 33 | 34 | if encode_50keywords == True: 35 | 36 | file1 = open(str(os.path.dirname(os.path.abspath(__file__))) + 37 | file_name, "r+") 38 | lines = file1.readlines() 39 | 40 | i=0 41 | for line in lines: 42 | keywords = list(line.strip().split(", ")) 43 | word2vec_words = [] 44 | print(keywords) 45 | for word in keywords: 46 | word2vec = word2vec_encoder[word] 47 | word2vec_words.append(word2vec) 48 | word2vec_dict[word] = word2vec 49 | save_path = str(os.path.dirname( 50 | os.path.abspath(__file__))) + folder + '/set_' +str(i) + '.npy' 51 | 52 | np.save(save_path, word2vec_words) 53 | i=i+1 54 | 55 | save_path_dict = str(os.path.dirname( 56 | os.path.abspath(__file__))) + folder + '/dict_word2vec.pkl' 57 | with open(save_path_dict, 'wb') as f: 58 | pickle.dump(word2vec_dict, f, pickle.HIGHEST_PROTOCOL) 59 | 60 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import json 4 | import os 5 | import numpy as np 6 | import scipy.io as sio 7 | import argparse 8 | import gc 9 | 10 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 11 | import torch.nn.functional as F 12 | 13 | from sklearn.metrics.pairwise import cosine_similarity 14 | from utility_gpt import * 15 | from perplexity import * 16 | import pickle 17 | import random 18 | from encode_keywords import create_enc_dict 19 | from collections import Counter 20 | 21 | from nltk.stem import PorterStemmer, LancasterStemmer 22 | porter = PorterStemmer() 23 | 24 | 25 | word_embedding = { 26 | 'glove': "glove-wiki-gigaword-300", 27 | 'word2vec': "word2vec-google-news-300" 28 | } 29 | 30 | 31 | if not os.path.exists(str(os.path.dirname(os.path.abspath(__file__))) + '/data/converter_table_glove.npy'): 32 | print("Generating table of cosine distances...") 33 | converter_table_glove() 34 | 35 | if not 
os.path.exists(str(os.path.dirname(os.path.abspath(__file__))) + '/data/converter_table_word2vec.npy'): 36 | print("Generating table of cosine distances...") 37 | converter_table_word2vec() 38 | 39 | 40 | def distinct_n(example, n, n_distinct, n_total, counter): 41 | """ 42 | Gives the number of distinct n-grams as well as the total n-grams 43 | Args: 44 | example: input text 45 | n: n-grams size (i.e., the n) 46 | n_distinct: distinct n-grams in previous iteration 47 | n_total: total n-grams in previous iteration 48 | counter: token counter in previous iteration, i.e., how many times a token appeared 49 | 50 | """ 51 | for token in zip(*(example[i:] for i in range(n))): 52 | if token not in counter: 53 | n_distinct += 1 54 | elif counter[token] == 1: 55 | n_distinct -= 1 56 | counter[token] += 1 57 | n_total += 1 58 | return n_distinct, n_total, counter 59 | 60 | 61 | def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')): 62 | """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 63 | Args: 64 | logits: logits distribution shape (vocabulary size) 65 | top_k >0: keep only top k tokens with highest probability (top-k filtering). 66 | top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 67 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) 68 | """ 69 | assert logits.dim() == 1 # batch size 1 for now - could be updated for more but the code would be less clear 70 | top_k = min(top_k, logits.size(-1)) # Safety check 71 | if top_k > 0: 72 | # Remove all tokens with a probability less than the last token of the top-k 73 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 74 | logits[indices_to_remove] = filter_value 75 | 76 | if top_p > 0.0: 77 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 78 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 79 | 80 | # Remove tokens with cumulative probability above the threshold 81 | sorted_indices_to_remove = cumulative_probs > top_p 82 | # Shift the indices to the right to keep also the first token above the threshold 83 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 84 | sorted_indices_to_remove[..., 0] = 0 85 | 86 | #indices_to_remove = torch.zeros_like(logits, dtype=torch.uint8).scatter_(dim=-1, index=sorted_indices, src=sorted_indices_to_remove ) 87 | 88 | indices_to_remove = sorted_indices[sorted_indices_to_remove] 89 | logits[indices_to_remove] = filter_value 90 | 91 | return logits 92 | 93 | def get_keywords(keywords, enc_dict, tokenizer, mode): 94 | keywords_ = [w for w in keywords] 95 | 96 | # Select the next guide word(s) 97 | if keywords_: 98 | if mode=='next': 99 | keywords_ = [keywords_[0]] 100 | if mode=='random': 101 | keywords_ = [random.choice(keywords_)] 102 | 103 | keywords_enc = [enc_dict[w] for w in keywords_] 104 | keywords_gpt = {tokenizer.encode(w)[0]:w for w in keywords_} 105 | 106 | return keywords_enc, keywords_gpt 107 | 108 | def get_logits(model, tokenizer, text, this_sequence, temperature): 109 | ## GPT2 - generate logits 110 | indexed_tokens = tokenizer.encode(text) 111 | indexed_this_seq = tokenizer.encode(this_sequence) 112 | tokens_tensor = torch.tensor([indexed_tokens]) 113 | tokens_tensor = tokens_tensor.to('cuda') 114 | 115 | # Predict all tokens 116 | outputs = model(tokens_tensor) 117 | 118 | del tokens_tensor 119 | torch.cuda.empty_cache() 120 | 121 | logits = outputs.logits 122 | logits 
= logits[0, -1, :]/ temperature 123 | 124 | return logits, indexed_tokens, indexed_this_seq 125 | 126 | def get_sim(keywords_enc, keywords_gpt, converter_table, guarantee, mode, only_max): 127 | 128 | if len(keywords_enc)>1: 129 | sims = np.array([cosine_similarity(np.reshape(w, (1, -1)), converter_table) for w in keywords_enc]) 130 | if guarantee: 131 | for i, w in enumerate(keywords_gpt): 132 | sims[i][0][w] = 1 133 | if mode=='max': 134 | sim = np.max(sims, axis=0) 135 | elif mode=='all': 136 | sim = np.sum(sims, axis=0) 137 | else: 138 | raise Exception("keywords_enc length is greater than 1 so expect to be in mode 'max' or 'all'") 139 | else: 140 | sim = cosine_similarity(np.reshape(keywords_enc[0], (1, -1)), converter_table) 141 | 142 | # Only the target word, not the neighbour (as measured by cosine similarity) 143 | if only_max == True: 144 | sim_aux = np.zeros_like(sim) 145 | sim_aux[0,sim.argmax()] = sim.max() 146 | sim = np.squeeze(sim_aux) 147 | else: 148 | sim = np.clip(np.squeeze(sim), a_min=0, a_max=None) #tf.square(sim) 149 | 150 | return sim 151 | 152 | def get_weight(weight, guarantee, T_time, time): 153 | 154 | if guarantee: 155 | if T_time == 0: 156 | T_time = 1 157 | rate = (1/T_time)*np.log(100/weight) # 100 is the maximum value the weight will reach 158 | weight = weight*np.exp(rate*time) 159 | 160 | return weight 161 | 162 | def get_prediction(tokenizer, indexed_tokens, indexed_this_seq, keywords_gpt, predicted_index, guarantee, T_time, time): 163 | 164 | if guarantee and time > T_time: 165 | predicted_index = list(keywords_gpt.keys())[0] 166 | if guarantee and predicted_index in keywords_gpt: 167 | predicted_text = tokenizer.decode(indexed_tokens) + ' ' + keywords_gpt[predicted_index] 168 | this_sequence = tokenizer.decode(indexed_this_seq) + ' ' + keywords_gpt[predicted_index] 169 | pred_word = keywords_gpt[predicted_index] 170 | else: 171 | predicted_text = tokenizer.decode(indexed_tokens + [predicted_index]) 172 | this_sequence = tokenizer.decode(indexed_this_seq + [predicted_index]) 173 | pred_word = predicted_text.split()[-1].split('<|endoftext|>')[-1] 174 | 175 | return pred_word, predicted_text, predicted_index, this_sequence 176 | 177 | 178 | 179 | def sample_sentence(text, this_sequence, tokenizer, model, keywords, enc_dict, guide_probs, converter_table, weight, guide=False, prev_proba=1, top_k=0, top_p=0.9, temperature=1., only_max=False, mode='max', guarantee=False, time=0, T_time=1, det_BS=False, ith=0): 180 | """ Samples the next word of the sequence with logit modification (guidance) 181 | Modes: 182 | mode='max': each token is shifted by the cosine similarity to the closest guide word 183 | mode='all': each token is shifted by the cosine similarity to each guide word 184 | mode='next': the order of the guide words is fixed and each token is shifted towards the next guide word in the sequence 185 | mode='random': a random word is selected from the remaining (not yet appeared) guide words and each token is shifted towards this guide word 186 | """ 187 | 188 | # Get word stems, encode keywords and get logits from LM from context 189 | guide_word_stems = [porter.stem(w.lower()) for w in keywords] 190 | keywords_enc, keywords_gpt = get_keywords(keywords, enc_dict, tokenizer, mode) 191 | logits, indexed_tokens, indexed_this_seq = get_logits(model, tokenizer, text, this_sequence, temperature) 192 | 193 | # Get probabilities for ppl calculation and log-softmax of logits for modification 194 | proba = F.softmax(logits, dim=-1) 195 | logits = 
F.log_softmax(logits, dim=-1) 196 | 197 | # Calculate cosine similarity, weight with annealing and modify logits 198 | if keywords_enc and guide: 199 | sim = get_sim(keywords_enc, keywords_gpt, converter_table, guarantee, mode, only_max) 200 | weight = get_weight(weight, guarantee, T_time, time) 201 | logits = logits + torch.tensor(sim*weight).cuda() # 202 | 203 | ## Sample tokens 204 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) ### 205 | logits = F.softmax(logits, dim=-1) 206 | # Deterministic beam search or sampling, if p!=1. it means nucleus sampling 207 | 208 | predicted_index = 50256 209 | while guide and predicted_index == 50256: 210 | if det_BS: 211 | predicted_index = torch.topk(logits, ith+1)[1][ith].item() 212 | else: 213 | predicted_index = torch.multinomial(logits, 1).item() 214 | 215 | # Get predicted word and indices 216 | pred_word, predicted_text, predicted_index, this_sequence = get_prediction(tokenizer, indexed_tokens, indexed_this_seq, keywords_gpt, predicted_index, guarantee, T_time, time) 217 | 218 | # Update counters if word was predicted 219 | pred_word_stem = porter.stem(pred_word.lower()) 220 | guide_next = guide 221 | time_next = time+1 222 | T_time_next = T_time 223 | if pred_word_stem in guide_word_stems: 224 | ind = guide_word_stems.index(pred_word_stem) 225 | keywords = keywords[:ind] + keywords[ind+1:] 226 | guide_probs = guide_probs + [(pred_word_stem, proba[predicted_index].item())] 227 | guide_next = False 228 | time_next = 1 229 | T_time_next = T_time-time+1 230 | 231 | return predicted_text, keywords, guide_next, guide_probs, prev_proba*proba[predicted_index], this_sequence, time_next, T_time_next 232 | 233 | 234 | 235 | def sample_sentence_noguide(text, this_sequence, tokenizer, model, prev_proba=1, top_k=0, top_p=0.9, temperature=1., eos_c=0, det_BS=False, ith=0): 236 | """ Samples the next word of the sequence without logit modification (guidance) 237 | """ 238 | logits, indexed_tokens, indexed_this_seq = get_logits(model, tokenizer, text, this_sequence, temperature) 239 | 240 | proba = F.softmax(logits, dim=-1) 241 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 242 | logits = F.softmax(logits, dim=-1) 243 | 244 | if det_BS: 245 | predicted_index = torch.topk(logits, ith+1)[1][ith].item() 246 | else: 247 | predicted_index = torch.multinomial(logits, 1).item() 248 | 249 | predicted_text = tokenizer.decode(indexed_tokens + [predicted_index]) 250 | this_sequence = tokenizer.decode(indexed_this_seq + [predicted_index]) 251 | pred_word = predicted_text.split()[-1] 252 | 253 | if predicted_index == 50256: 254 | eos_c_next=1 255 | else: 256 | eos_c_next=eos_c 257 | 258 | if eos_c == 1: 259 | next_proba = prev_proba 260 | else: 261 | next_proba = prev_proba*proba[predicted_index].item() 262 | 263 | return predicted_text, next_proba, this_sequence, eos_c_next 264 | 265 | def get_score(k, number_of_words_per_sentence, online_probability, proba): 266 | alpha = 0.6 267 | length = (k+1)*number_of_words_per_sentence 268 | len_norm = ((5+length)**alpha)/(6**alpha) 269 | score_ = np.log(online_probability*proba)/len_norm 270 | 271 | return score_ 272 | 273 | 274 | def get_success_length(success_length, number_of_beams, guide, number_of_generated_sentences): 275 | 276 | for b in range(number_of_beams): 277 | if guide[b]: 278 | success_length[b] = 0 279 | 280 | success_length = success_length[0]/number_of_generated_sentences 281 | return success_length 282 | 283 | def get_success_rate(number_keywords, count_word_stem, keywords, 
full_text): 284 | # Success rate 285 | target_words = number_keywords 286 | target_count = 0 287 | for i in range(number_keywords): 288 | if count_word_stem(keywords[i], full_text[0]) > 0: 289 | target_count += 1 290 | 291 | success_rate = word_c[0]/number_keywords #target_count/target_words 292 | 293 | def conditional_language_generation( 294 | model, 295 | tokenizer, 296 | keyword_set, 297 | model_name='gpt2-large', 298 | enc_dict={}, 299 | seed=None, 300 | temperature=1., 301 | top_k=0, 302 | top_p=0.9, 303 | constant=20, 304 | number_of_concurrent_sentences = 10, 305 | number_of_generated_sentences = 20, 306 | number_of_words_per_sentence = 5, 307 | number_of_beams = 3, 308 | save_path='dummy.txt', 309 | only_max = False, 310 | no_do_wc=False, 311 | mode='max', 312 | do_guarantee=False, 313 | embedding='glove', 314 | det_BS=False, 315 | folder_name='', 316 | guide=True 317 | ): 318 | """ 319 | Main function for conditional language generation 320 | :model_name=124M : String, which model to use 321 | :seed=None : Integer seed for random number generators, fix seed to reproduce 322 | results 323 | :temperature=1 : Float value controlling randomness in boltzmann 324 | distribution. Lower temperature results in less random completions. As the 325 | temperature approaches zero, the model will become deterministic and 326 | repetitive. Higher temperature results in more random completions. 327 | :top_k=0 : Integer value controlling diversity. 1 means only 1 word is 328 | considered for each step (token), resulting in deterministic completions, 329 | while 40 means 40 words are considered at each step. 0 (default) is a 330 | special setting meaning no restrictions. 40 generally is a good value. 331 | :top_p=1 Top_p is the cummulative probability used for nucleus sampling. 1 means no nucleus sampling 332 | :constant: How much are anchors weighted 333 | :counter index of wordset which is currently evaluated 334 | :TODO ns..... 
335 | :modes: 336 | mode='max': each token is shifted by the cosine similarity to the closest guide word 337 | mode='all': each token is shifted by the cosine similarity to each guide word 338 | mode='next': the order of the guide words is fixed and each token is shifted towards the next guide word in the sequence 339 | mode='random': a random word is selected from the remaining (not yet appeared) guide words and each token is shifted towards this guide word 340 | mode='best_tour': 341 | mode='worst_tour': 342 | """ 343 | 344 | start_time = time.time() 345 | total_words = number_of_words_per_sentence*number_of_generated_sentences 346 | 347 | ################################### 348 | ## Load words 349 | 350 | # Define task, keyword to article, keyword to story (ROC) or keyword to phrase 351 | in_text, keywords = keyword_set 352 | keywords_enc = [enc_dict[w] for w in keywords] 353 | number_keywords = len(keywords) 354 | print(in_text, keywords) 355 | print("N keywords: ", number_keywords) 356 | 357 | if mode=='best_tour': 358 | best_order = best_tour(keywords_enc) 359 | keywords_enc = [keywords_enc[i] for i in list(best_order)] 360 | print("Keywords: ", keywords, best_order) 361 | keywords = [keywords[i] for i in best_order] 362 | print("Keywords: ", keywords) 363 | mode = 'next' # Switch over to next (ordered) mode with the optimized order 364 | 365 | ################################### 366 | 367 | # File to save results as .txt 368 | text_file = open(save_path + '.txt', 'a+', encoding='utf8') 369 | text_file_sentences = open(save_path + 'SENTENCES.txt', 'a+', encoding='utf8') 370 | 371 | # prepare variables... 372 | #np.random.seed(seed) 373 | weight = constant 374 | converter_table = np.load(str(os.path.dirname( 375 | os.path.abspath(__file__))) + '/data/converter_table_' + str(embedding) + '.npy') 376 | 377 | full_text = [in_text] * number_of_beams 378 | guide_words_s = [keywords]*number_of_beams 379 | guide_probs_s = [[]]*number_of_beams 380 | cum_quality_score = [0]*number_of_beams 381 | word_c = [0]*number_of_beams 382 | success_length = [0]*number_of_beams 383 | online_probability = [1]*number_of_beams 384 | guide = [guide]*number_of_beams 385 | eos_count = [0]*number_of_beams 386 | total_time = [total_words-number_keywords]*number_of_beams 387 | current_time = [1]*number_of_beams 388 | 389 | for k in range(number_of_generated_sentences): 390 | # Define guidance word and index for guidance in model and quality function 391 | result_subsequences = [] 392 | for b in range(number_of_beams): 393 | ####################################### Generation loop ################################################ 394 | for i in range(number_of_concurrent_sentences): 395 | # Reset variables: 396 | context = full_text[b] 397 | guide_words = guide_words_s[b] 398 | guide_probs = guide_probs_s[b] 399 | # print(guide_probs) 400 | proba = 1 401 | this_sequence = "" 402 | w_c = 0 403 | eos_c = eos_count[b] 404 | t_time = total_time[b] 405 | c_time = current_time[b] 406 | if guide[b]: 407 | guide_next = True 408 | for j in range(number_of_words_per_sentence): 409 | context, guide_words, guide_next, guide_probs, proba, this_sequence, c_time, t_time = sample_sentence(context, 410 | this_sequence, tokenizer, model, guide_words, enc_dict, guide_probs, converter_table, 411 | weight, guide_next, proba, top_p=top_p, temperature=temperature, only_max=only_max, mode=mode, 412 | guarantee=do_guarantee, time=c_time, T_time=t_time, det_BS=det_BS, ith=i) 413 | 414 | else: # Dont't guide 415 | for j in 
range(number_of_words_per_sentence): 416 | context, proba, this_sequence, eos_c = sample_sentence_noguide(context, this_sequence, tokenizer, model, top_p=top_p, temperature=temperature, prev_proba=proba, eos_c=eos_c, det_BS=det_BS, ith=i) 417 | 418 | if type(proba) == torch.Tensor: 419 | proba = proba.item() 420 | 421 | score_ = get_score(k, number_of_words_per_sentence, online_probability[b], proba) 422 | w_c = number_keywords - len(guide_words) 423 | if not no_do_wc: 424 | quality_score = evaluate_quality_linear(this_sequence, w_c, score_) 425 | else: 426 | quality_score = evaluate_quality_linear(this_sequence, 0, score_) 427 | 428 | # DEBUG: 429 | # print("Beam, Guidance: ", b, str(guide_words), guide[b]) 430 | # print("txt, quality, wordC, score_: ", this_sequence, quality_score, w_c, score_) 431 | 432 | # Linear Q 433 | result_subsequences.append( 434 | [context, quality_score, w_c, score_, online_probability[b]*proba, guide_words, guide[b], eos_c, guide_probs, t_time, c_time]) 435 | 436 | if not guide[b]: 437 | break # No guiding, no multiple beams! 438 | 439 | if k==0: # First iteration of beam search is different! 440 | break 441 | ######################################################################################################## 442 | # Deterministic K2T 443 | result_subsequences_sorted = sorted(result_subsequences, key=lambda a_entry: a_entry[1], reverse=True) 444 | 445 | # Select Beams 446 | for b in range(number_of_beams): 447 | full_text[b] = result_subsequences_sorted[b][0] 448 | cum_quality_score[b] = result_subsequences_sorted[b][1] 449 | guide_words_s[b] = result_subsequences_sorted[b][5] 450 | guide_probs_s[b] = result_subsequences_sorted[b][8] 451 | guide[b] = result_subsequences_sorted[b][6] 452 | word_c[b] = result_subsequences_sorted[b][2] 453 | eos_count[b] = result_subsequences_sorted[b][7] 454 | total_time[b] = result_subsequences_sorted[b][9] 455 | current_time[b] = result_subsequences_sorted[b][10] 456 | if guide[b] and word_c[b] > number_keywords-1: # Only do this once, and then guide[b] no longer True 457 | guide[b] = False 458 | success_length[b] = k+1 459 | 460 | n_words_counter = (k+1)*number_of_words_per_sentence 461 | online_probability[b] = result_subsequences_sorted[b][4] 462 | online_perplexity = np.power(online_probability[b], (-1/n_words_counter)) 463 | 464 | ### DEBUG: Comment out to remove console output 465 | # print(">>>>>>>>>>>>> BEAM: ", b) 466 | # print("Guidance words: ", keywords) 467 | # print("Current sentence: ", full_text[b]) 468 | # print("Guidance word, word count, probs: ", guide_words_s[b], result_subsequences_sorted[b][2], guide_probs_s[b]) 469 | # print("Current perplexity, cumulative quality, eos: ", online_perplexity, cum_quality_score[b], eos_count[b]) 470 | ### 471 | 472 | if np.sum(eos_count) == number_of_beams: 473 | print("Finishing...") 474 | break 475 | 476 | 477 | ''' Uncomment to write all intermediate steps to .txt 478 | 479 | text_file.write("\nBest 10 next subsequences: \n") 480 | for result_subsequence in result_subsequences_sorted: 481 | text_file.write(result_subsequence[0] + "\n Perplexity:" + 482 | str(result_subsequence[2]) + "\n Quality Score: " + 483 | str(result_subsequence[1]) + "\n\n") 484 | 485 | text_file.write("\n\n\n\n") 486 | ''' 487 | ####################################### 488 | # final evaluation 489 | ####################################### 490 | end_time = time.time() 491 | time_needed = end_time - start_time 492 | 493 | success_length = get_success_length(success_length, number_of_beams, 
guide, number_of_generated_sentences) 494 | success_rate = word_c[0]/number_keywords 495 | distilGPT2_perplexity = distilGPT2_perplexity_score(full_text[0]) 496 | 497 | ### Distinct n-grams 498 | sep = '<|endoftext|>' 499 | stripped = full_text[0].strip(sep).split(sep, 2)[0] 500 | tokenized_text = tokenizer.encode(stripped) 501 | # 2_Distinct 502 | counter_2 = Counter() 503 | total_2 = 0 504 | distinct_2 = 0 505 | distinct_2, total_2, counter_2 = distinct_n(tokenized_text, 2, distinct_2, total_2, counter_2) # Need to set n 506 | 507 | # 3_Distinct 508 | counter_3 = Counter() 509 | total_3 = 0 510 | distinct_3 = 0 511 | distinct_3, total_3, counter_3 = distinct_n(tokenized_text, 3, distinct_3, total_3, counter_3) # Need to set n 512 | 513 | # 4_Distinct 514 | counter_4 = Counter() 515 | total_4 = 0 516 | distinct_4 = 0 517 | distinct_4, total_4, counter_4 = distinct_n(tokenized_text, 4, distinct_4, total_4, counter_4) # Need to set n 518 | 519 | print("------------------------------------------------------------------------------") 520 | print("FINAL TEXT: ") 521 | print(full_text[0]) 522 | print("Success rate, success length, perplexity: ", success_rate, success_length, distilGPT2_perplexity) 523 | 524 | 525 | ####################################### 526 | # Save results, write in file and return 527 | ####################################### 528 | 529 | # Declare evaluation 530 | evaluation = { 531 | "final_sequence: ": full_text[0], 532 | "keywords": keywords, 533 | #"online_perplexity": online_perplexity[0], 534 | "distilGPT2_perplexity": distilGPT2_perplexity, 535 | "success_rate": success_rate, 536 | "2_distinct": distinct_2, 537 | "2_total": total_2, 538 | "3_distinct": distinct_3, 539 | "3_total": total_3, 540 | "4_distinct": distinct_4, 541 | "4_total": total_4, 542 | "number_of_concurent_sentences": number_of_concurrent_sentences, 543 | "number_of_generated_sentences": number_of_generated_sentences, 544 | "number_of_words_per_sentence": number_of_words_per_sentence, 545 | "total_words": total_words, 546 | "top_k": top_k, 547 | "top_p": top_p, 548 | "model_name": model_name, 549 | "constant": constant, 550 | "time_needed": time_needed, 551 | "success_length": success_length, 552 | "guide_probs": guide_probs_s[0] 553 | } 554 | 555 | ### Write to text file 556 | text_file.write("Keywords: \n") 557 | for word in keywords: 558 | text_file.write(word + " ") 559 | text_file.write("\n\n") 560 | text_file.write("Final sequence: \n\n") 561 | text_file.write(full_text[0]) 562 | for b in range(number_of_beams): 563 | text_file_sentences.write(full_text[b]) 564 | text_file_sentences.write("\n\n") 565 | text_file_sentences.write("\n\nSuccess_rate: " + str(word_c[b]/number_keywords)) 566 | text_file_sentences.write("\nPerplexity: " + str(distilGPT2_perplexity_score(full_text[b]))) 567 | text_file_sentences.write("\n###############################\n") 568 | text_file.write("\n\nSuccess_rate: " + str(success_rate)) 569 | text_file.write("\nPerplexity: " + str(distilGPT2_perplexity)) 570 | text_file.write("\nTime_needed: " + str(time_needed)) 571 | text_file.write("\nSuccess_length: " + str(success_length)) 572 | text_file.write("\n2_distint_rate: " + '{0:.4f}'.format(distinct_2/total_2)) 573 | text_file.write("\n3_distint_rate: " + '{0:.4f}'.format(distinct_3/total_3)) 574 | text_file.write("\n4_distint_rate: " + '{0:.4f}'.format(distinct_4/total_4)) 575 | text_file.write("\n\n") 576 | text_file.close() 577 | text_file_sentences.close() 578 | 579 | del model 580 | torch.cuda.empty_cache() 581 | 582 
| print("END: ", keywords) 583 | 584 | return evaluation 585 | 586 | def get_folderfile_name(task, file_name): 587 | 588 | if task == 'key2article': 589 | folder_name = file_name + '/' 590 | else: 591 | folder_name = os.path.dirname(file_name) 592 | 593 | 594 | abs_path = str(os.path.dirname(os.path.abspath(__file__))) 595 | file_name = str(os.path.abspath(os.path.join(abs_path, file_name))) 596 | folder_name = str(os.path.abspath(os.path.join(abs_path, folder_name))) 597 | if task == 'key2article': 598 | folder_name = folder_name + '/' 599 | file_name = file_name + '/' #Multiple files! 600 | print('file_name: ', file_name) 601 | print('folder_name: ', folder_name) 602 | 603 | return folder_name, file_name 604 | 605 | def get_savefile(args): 606 | 607 | save_file = 'Result_w_'+str(args.weight)+'_nBeams_'+str(args.n_beams)+'_nGenSent_'+str(args.n_generated_sentences)+'_nWordsPerSent_'+str(args.n_words_per_sentence)+'_topP_'+str(args.top_p) 608 | 609 | if args.det_BS: 610 | save_file = save_file + '_detBS' 611 | if not args.no_do_wc: 612 | save_file = save_file + '_WC' 613 | if args.do_guarantee: 614 | save_file = save_file + '_Guar_' + str(args.do_guarantee) 615 | if not args.guide: 616 | save_file = save_file + '_no_guide' 617 | if args.only_max == True: 618 | save_file = 'ONLYMAX_' + save_file 619 | save_file = save_file + '_' + str(args.embedding) 620 | save_file = save_file + '_' + str(args.mode) 621 | 622 | return save_file 623 | 624 | def get_savepath(task, results_subfolder, save_file, folder_name): 625 | 626 | if task == 'key2article': 627 | sub_folder = 'keyword_to_articles/' + str(results_subfolder) + '/' 628 | save_folder = 'results/' + sub_folder 629 | save_path = save_folder + save_file 630 | elif task == 'ROC': 631 | sub_folder = 'ROC/' 632 | save_folder = 'results/' + sub_folder 633 | save_path = save_folder + save_file 634 | elif task == 'commongen': 635 | sub_folder = 'commongen/' 636 | save_folder = 'results/' + sub_folder 637 | save_path = 'results/' + sub_folder + save_file 638 | else: 639 | sub_folder = os.path.basename(os.path.normpath(folder_name)) + '/' + str(results_subfolder) + '/' 640 | save_folder = 'results/' + sub_folder 641 | save_path = 'results/' + sub_folder + save_file 642 | 643 | try: 644 | os.mkdir(save_folder) 645 | print('made directory: ', save_folder) 646 | except OSError as error: 647 | print(error) 648 | 649 | return save_path 650 | 651 | def get_keywordsets(task, folder_name, file_name): 652 | 653 | if task == 'key2article': 654 | keyword_sets = [] 655 | for filename in os.listdir(folder_name): 656 | if filename.endswith('txt'): 657 | file1 = open(os.path.join(folder_name, filename), "r+") 658 | lines = file1.readlines() 659 | keywords = list(lines[2].strip().split(", ")) 660 | in_text = lines[1].split()[:30] 661 | keyword_sets.append((' '.join(in_text), keywords)) 662 | else: 663 | #File containing the keywords as text 664 | in_text = '<|endoftext|>' # 'Start with EOS 665 | #in_texts = ['I', 'It', 'A'] #Other possible start tokens 666 | file1 = open(file_name, "r+") 667 | lines = file1.readlines() 668 | if task == 'commongen': 669 | keyword_sets = [(in_text, list(line.strip().split())) for line in lines] 670 | else: 671 | keyword_sets = [(in_text, list(line.strip().split(", "))) for line in lines] 672 | # keyword_sets = [(random.choice(in_texts), list(line.strip().split(", "))) for line in lines] 673 | 674 | return keyword_sets 675 | 676 | 677 | def get_args(parser): 678 | 679 | # Get constant defined in run_gpt2.sh 680 | # Default is GPT-3 Beam 
681 | 
682 |     parser.add_argument('-top_p', type=float, default=0.9)
683 |     parser.add_argument('-weight', type=float, default=5.0) #20.0
684 |     parser.add_argument('-n_generated_sentences', type=int, default=90)
685 |     parser.add_argument('-n_words_per_sentence', type=int, default=1)
686 |     parser.add_argument('-n_beams', type=int, default=1)
687 |     parser.add_argument('-n_repetitions', type=int, default=1)
688 |     parser.add_argument('-temperature', type=float, default=1.)
689 |     parser.add_argument('-only_max', type=bool, default=False)
690 |     parser.add_argument('-no_do_wc', type=bool, default=False)
691 |     parser.add_argument('-mode', type=str, default='max',
692 |                         choices=['max', 'next', 'all', 'random', 'best_tour'], help='modes: max, next, all, random, best_tour')
693 |     parser.add_argument('-do_guarantee', type=bool, default=False)
694 |     parser.add_argument('-embedding', type=str, default='glove',
695 |                         choices=list(word_embedding.keys()), help='word_embedding')
696 |     parser.add_argument('-file_name', type=str, default='data/50_keywordsets_eval/word_sets.txt') #data/50_keywordsets_eval/word_sets data/commongen_small/commongen.dev.src_alpha_small.txt
697 |     parser.add_argument('-det_BS', type=bool, default=False)
698 |     parser.add_argument('-guide', type=bool, default=True)
699 |     parser.add_argument('-results_subfolder', type=str, default='tmp')
700 |     parser.add_argument('-task', type=str, default='50keywords',
701 |                         choices=['50keywords', 'ROC', 'key2article', 'commongen'], help='tasks: 50keywords, ROC, key2article, commongen')
702 |     args = parser.parse_args()
703 | 
704 |     return args
705 | 
706 | 
707 | if __name__ == '__main__':
708 | 
709 |     parser = argparse.ArgumentParser()
710 |     args = get_args(parser)
711 | 
712 |     file_name = args.file_name
713 |     file_name = file_name.strip('/')
714 |     if not file_name:
715 |         raise Exception("file_name missing. Please give the relative path to the word_sets file (or to the word_sets folder if the task is 'key2article').")
716 | 
717 |     ### Create model
718 |     model = GPT2LMHeadModel.from_pretrained('gpt2-large')
719 |     tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
720 |     model.eval()
721 |     model.to('cuda')
722 | 
723 |     # Get keywords and save path
724 |     folder_name, file_name = get_folderfile_name(args.task, file_name)
725 |     save_file = get_savefile(args)
726 |     save_path = get_savepath(args.task, args.results_subfolder, save_file, folder_name)
727 |     keyword_sets = get_keywordsets(args.task, folder_name, file_name)
728 | 
729 |     print('Mode:', args.mode)
730 |     print('Save path: ', save_path)
731 | 
732 |     # Create file containing the keyword embeddings
733 |     save_path_dict = os.path.join(folder_name, 'dict_' + str(args.embedding) + '.pkl')
734 |     if not os.path.isfile(save_path_dict):
735 |         create_enc_dict(file_name, args.embedding, task=args.task)
736 |     with open(save_path_dict, 'rb') as f:
737 |         enc_dict = pickle.load(f)
738 | 
739 |     ################ RUN ################
740 |     all_results = np.zeros([len(keyword_sets), args.n_repetitions, 11], dtype=object)  # Initialize results structure
741 | 
742 |     for j, keyword_set in enumerate(keyword_sets):
743 | 
744 |         if args.n_generated_sentences < 0:
745 |             in_text, keywords = keyword_set
746 |             n_generated_sentences = math.ceil((len(keywords)+1) * abs(args.n_generated_sentences) / args.n_words_per_sentence)
747 |         else:
748 |             n_generated_sentences = args.n_generated_sentences
749 | 
750 |         for i in range(args.n_repetitions):
751 |             results = conditional_language_generation(model, tokenizer,
752 |                                                        keyword_set=keyword_set,
753 |                                                        top_p=args.top_p,
754 |                                                        constant=args.weight,
755 |                                                        number_of_concurrent_sentences=args.n_beams,
756 |                                                        number_of_generated_sentences=n_generated_sentences,
757 |                                                        number_of_words_per_sentence=args.n_words_per_sentence,
758 |                                                        number_of_beams=args.n_beams,
759 |                                                        enc_dict=enc_dict,
760 |                                                        save_path=save_path,
761 |                                                        temperature=args.temperature,
762 |                                                        only_max=args.only_max,
763 |                                                        no_do_wc=args.no_do_wc,
764 |                                                        mode=args.mode,
765 |                                                        do_guarantee=args.do_guarantee,
766 |                                                        embedding=args.embedding,
767 |                                                        folder_name=folder_name,
768 |                                                        det_BS=args.det_BS,
769 |                                                        guide=args.guide,
770 |                                                        )
771 |             all_results[j][i][0] = results["distilGPT2_perplexity"]
772 |             all_results[j][i][1] = results["time_needed"]
773 |             all_results[j][i][2] = results["success_rate"]
774 |             all_results[j][i][3] = results["success_length"]
775 |             all_results[j][i][4] = results["2_distinct"]
776 |             all_results[j][i][5] = results["2_total"]
777 |             all_results[j][i][6] = results["3_distinct"]
778 |             all_results[j][i][7] = results["3_total"]
779 |             all_results[j][i][8] = results["4_distinct"]
780 |             all_results[j][i][9] = results["4_total"]
781 |             all_results[j][i][10] = results["guide_probs"]
782 | 
783 |     np.save(save_path, all_results)
784 | 
785 | 
786 | 
787 | 
--------------------------------------------------------------------------------
/metrics_degen.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | import json
4 | import os
5 | import numpy as np
6 | import scipy.io as sio
7 | import argparse
8 | os.environ['TRANSFORMERS_CACHE']='./transformer_models' # Work-around to avoid memory problems in server, comment out depending on memory availability
9 | from utility_gpt import *
10 | from perplexity import *
11 | from collections import Counter
12 | from scipy import stats
13 | import operator
14 | 
15 | from transformers import GPT2Tokenizer
16 | 
17 | #Self Bleu
18 | import random
19 | from functools import partial
20 | from multiprocessing.pool import Pool
21 | 
22 | import spacy
23 | from tqdm import tqdm
24 | from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
25 | 
26 | 
27 | def parse_args() -> argparse.Namespace:
28 |     parser = argparse.ArgumentParser()
29 | 
30 |     parser.add_argument("-file", type=str)
31 |     parser.add_argument("-zipf_N", type=int, default=5000)
32 |     parser.add_argument("-dist_N", type=int, default=2, help="N in distinct-N metric")
33 |     parser.add_argument("--n_sample", type=int, default=50,
34 |                         help="how many sentences to sample to calculate bleu")
35 | 
36 |     return parser.parse_args()
37 | 
38 | def distinct_n(example, n, n_distinct, n_total, counter):
39 |     #counter = Counter()
40 |     #n_total = 0
41 |     #n_distinct = 0
42 |     #for example in examples:
43 |     for token in zip(*(example[i:] for i in range(n))):
44 |         if token not in counter:
45 |             n_distinct += 1
46 |         elif counter[token] == 1:
47 |             n_distinct -= 1
48 |         counter[token] += 1
49 |         n_total += 1
50 |     return n_distinct, n_total, counter
51 | 
52 | def bleu_i(weights, all_sentences, smoothing_function, i):
53 |     # noinspection PyTypeChecker
54 |     return sentence_bleu(
55 |         references=all_sentences[:i] + all_sentences[i + 1:],
56 |         hypothesis=all_sentences[i],
57 |         weights=weights,
58 |         smoothing_function=smoothing_function)
59 | 
60 | def main():
61 | 
62 |     args = parse_args()
63 | 
64 |     tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
65 | 
66 |     # Self-BLEU
67 |     random.seed(0)
68 |     nlp = spacy.load('en_core_web_sm', disable=['parser', 'tagger', 'ner'])
69 |     nlp.add_pipe(nlp.create_pipe('sentencizer'))
70 | 
71 |     all_sentences = []
72 |     #save_path = str(os.path.dirname(os.path.abspath(__file__))) + "/results/50_keywordsets_eval/progressive/"
73 |     save_path = os.path.splitext(args.file)[0] +"_metrics"
74 | 
75 | 
76 |     with open(args.file, 'r') as file1:
77 |         lines = file1.readlines()
78 |     ppl = np.zeros((50))
79 |     i = 0
80 |     text = ''
81 |     next_yes = False
82 |     next_no = False
83 |     sep = '<|endoftext|>'
84 |     counter_2 = Counter()
85 |     n_total_2 = 0
86 |     n_distinct_2 = 0
87 |     counter_3 = Counter()
88 |     n_total_3 = 0
89 |     n_distinct_3 = 0
90 |     counter_4 = Counter()
91 |     n_total_4 = 0
92 |     n_distinct_4 = 0
93 |     cnt_zipf = Counter()
94 | 
95 |     for line in tqdm(lines):
96 |         if "Final sequence:" in line:
97 |             next_no = True
98 |         elif next_no:
99 |             next_yes = True
100 |             next_no = False
101 | 
102 |         if "Success_rate:" in line:
103 |             next_yes = False
104 |             stripped = text.split(sep, 2)[1]
105 |             #### Put metrics here!:
106 |             tokenized_text = tokenizer.encode(stripped)
107 |             # Distinct n:
108 |             n_distinct_2, n_total_2, counter_2 = distinct_n(tokenized_text, 2, n_distinct_2, n_total_2, counter_2)
109 |             n_distinct_3, n_total_3, counter_3 = distinct_n(tokenized_text, 3, n_distinct_3, n_total_3, counter_3)
110 |             n_distinct_4, n_total_4, counter_4 = distinct_n(tokenized_text, 4, n_distinct_4, n_total_4, counter_4)
111 |             # Zipf coeff:
112 |             cnt_zipf.update(tokenized_text)
113 |             # Self-BLEU:
114 |             all_sentences.append(tokenized_text)
115 |             # PPL:
116 |             ppl[i] = distilGPT2_perplexity_score(stripped)
117 |             i = i+1
118 |             text = ''
119 |         elif next_yes:
120 |             text = text + line
121 | 
122 |     text_file = open(save_path + '.txt', 'a+', encoding='utf8')
123 |     print(f"{os.path.basename(args.file)}")
124 |     text_file.write(f"{os.path.basename(args.file)}\n")
125 |     # Perplexity
126 |     mean_ppl = np.mean(ppl)
127 |     print(f"perplexity\t {mean_ppl}")
128 |     text_file.write(f"perplexity\t {mean_ppl}\n\n")
129 | 
130 |     # Distinct n 2
131 |     print(f"distinct 2-grams\ttotal 2-grams\tdistinct proportion")
132 |     text_file.write(f"distinct 2-grams\ttotal 2-grams\tdistinct proportion\n")
133 |     print(f"{n_distinct_2}\t{n_total_2}\t{n_distinct_2/n_total_2}")
134 |     text_file.write(f"{n_distinct_2}\t{n_total_2}\t{n_distinct_2/n_total_2}\n\n")
135 | 
136 |     # Distinct n 3
137 |     print(f"distinct 3-grams\ttotal 3-grams\tdistinct proportion")
138 |     text_file.write(f"distinct 3-grams\ttotal 3-grams\tdistinct proportion\n")
139 |     print(f"{n_distinct_3}\t{n_total_3}\t{n_distinct_3/n_total_3}")
140 |     text_file.write(f"{n_distinct_3}\t{n_total_3}\t{n_distinct_3/n_total_3}\n\n")
141 | 
142 |     # Distinct n 4
143 |     print(f"distinct 4-grams\ttotal 4-grams\tdistinct proportion")
144 |     text_file.write(f"distinct 4-grams\ttotal 4-grams\tdistinct proportion\n")
145 |     print(f"{n_distinct_4}\t{n_total_4}\t{n_distinct_4/n_total_4}")
146 |     text_file.write(f"{n_distinct_4}\t{n_total_4}\t{n_distinct_4/n_total_4}\n\n")
147 | 
148 |     # Zipf coeff
149 |     xs = np.arange(1, min(len(cnt_zipf), args.zipf_N)+1)
150 |     ys = np.array(sorted(cnt_zipf.values(), key=operator.neg)[:args.zipf_N])
151 |     a, b, r, p, std = stats.linregress(np.log(xs), np.log(ys))
152 |     print("zipf value\tregression r value \tregression p value")
153 |     text_file.write("zipf value\tregression r value \tregression p value\n")
154 |     print(f"{-a}\t{-r}\t{p}")
155 |     text_file.write(f"{-a}\t{-r}\t{p}\n\n")
156 |     """
157 |     #############
158 |     # Self-BLEU:
159 |     smoothing_function = SmoothingFunction().method1
160 |     pool = Pool(processes=os.cpu_count())
161 |     bleu_scores = []
162 |     for n_gram in range(1, 6):
163 | 
164 |         if n_gram == 1:
165 |             weights = (1.0, 0, 0, 0)
166 |         elif n_gram == 2:
167 |             weights = (0.5, 0.5, 0, 0)
168 |         elif n_gram == 3:
169 |             weights = (1.0 / 3, 1.0 / 3, 1.0 / 3, 0)
170 |         elif n_gram == 4:
171 |             weights = (0.25, 0.25, 0.25, 0.25)
172 |         elif n_gram == 5:
173 |             weights = (0.2, 0.2, 0.2, 0.2, 0.2)
174 |         else:
175 |             raise ValueError
176 |         #print("Len all sentences: ", len(all_sentences))
177 |         bleu_scores.append(
178 |             list(tqdm(
179 |                 pool.imap_unordered(
180 |                     partial(bleu_i, weights, all_sentences, smoothing_function),
181 |                     random.sample(range(len(all_sentences)), args.n_sample)),
182 |                 total=args.n_sample,
183 |                 smoothing=0.0,
184 |                 desc=f"bleu-{n_gram}")))
185 |         print(f"\n\nbleu-{n_gram} = {sum(bleu_scores[n_gram - 1]) / args.n_sample}")
186 |         text_file.write(f"\n\nbleu-{n_gram} = {sum(bleu_scores[n_gram - 1]) / args.n_sample}\n")
187 | 
188 |     for n_gram in range(5):
189 |         print(f"bleu-{n_gram + 1} = {sum(bleu_scores[n_gram]) / args.n_sample}")
190 |         text_file.write(f"bleu-{n_gram + 1} = {sum(bleu_scores[n_gram]) / args.n_sample}\n")
191 |     ################
192 |     """
193 |     save_dict = {}
194 |     save_dict['ppl'] = mean_ppl
195 |     save_dict['2_distinct'] = [2, n_distinct_2, n_total_2, n_distinct_2/n_total_2]
196 |     save_dict['3_distinct'] = [3, n_distinct_3, n_total_3, n_distinct_3/n_total_3]
197 |     save_dict['4_distinct'] = [4, n_distinct_4, n_total_4, n_distinct_4/n_total_4]
198 |     save_dict['zipf'] = [a, r, p]
199 |     """
200 |     save_dict['self_bleu'] = [sum(bleu_scores[0]) / args.n_sample, \
201 |                               sum(bleu_scores[1]) / args.n_sample, \
202 |                               sum(bleu_scores[2]) / args.n_sample, \
203 |                               sum(bleu_scores[3]) / args.n_sample, \
204 |                               sum(bleu_scores[4]) / args.n_sample]
205 |     """
206 |     np.save(save_path + ".npy", save_dict)
207 | 
208 | 
209 | if __name__ == '__main__':
210 |     main()
211 | 
212 | 
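As a quick illustration of the incremental `distinct_n` bookkeeping defined above, the following sketch feeds two toy token sequences through it. This is not part of the repository: the integer lists merely stand in for GPT-2 token ids, and importing `metrics_degen` assumes the heavyweight dependencies from `requirements.txt` are installed.

```python
from collections import Counter

from metrics_degen import distinct_n  # the helper defined above

# Toy "generations": plain integers stand in for GPT-2 token ids.
n_distinct, n_total, counter = 0, 0, Counter()
for tokens in ([1, 2, 3, 2, 3], [1, 2, 4]):
    n_distinct, n_total, counter = distinct_n(tokens, 2, n_distinct, n_total, counter)

# n_distinct tracks bigrams seen exactly once across everything processed so far,
# so the reported proportion for this toy input is 2 / 6.
print(n_distinct, n_total, n_distinct / n_total)
```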
--------------------------------------------------------------------------------
/metrics_degen_run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Change path (-file=) to the file for which you want to calculate the metrics
4 | 
5 | #python metrics_degen.py -file=results/50_keywordsets_eval/LinearQLog_noSq_deterministic_result_w_20.0_nBeams_1_nConcSent_1_nGenSent_90_nWordsPerSent_1_temperature_1.0.txt
6 | 
7 | 
--------------------------------------------------------------------------------
/perplexity.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import os
4 | os.environ['TRANSFORMERS_CACHE']='.'
5 | 
6 | from transformers import GPT2TokenizerFast, GPT2LMHeadModel
7 | # Load pre-trained model (weights)
8 | model = GPT2LMHeadModel.from_pretrained('distilgpt2')
9 | model.eval()
10 | # Load pre-trained model tokenizer (vocabulary)
11 | tokenizer = GPT2TokenizerFast.from_pretrained('distilgpt2')
12 | 
13 | 
14 | def distilGPT2_perplexity_score(sentence):
15 |     tokenize_input = tokenizer.tokenize(sentence)
16 |     tensor_input = torch.tensor(
17 |         [tokenizer.convert_tokens_to_ids(tokenize_input)])
18 |     loss, logits = model(tensor_input, labels=tensor_input)[:2]
19 | 
20 |     return math.exp(loss)
21 | 
22 | 
23 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.11.0
2 | aiohttp==3.7.3
3 | argon2-cffi==20.1.0
4 | astor==0.8.1
5 | async-generator==1.10
6 | async-timeout==3.0.1
7 | attrs==20.3.0
8 | backcall==0.2.0
9 | bleach==3.2.1
10 | blis==0.4.1
11 | boto==2.49.0
12 | boto3==1.16.35
13 | botocore==1.19.35
14 | brotlipy==0.7.0
15 | bz2file==0.98
16 | cached-property==1.5.2
17 | cachetools==4.2.0
18 | catalogue==1.0.0
19 | certifi==2020.12.5
20 | cffi==1.14.4
21 | chardet==3.0.4
22 | click==7.1.2
23 | cryptography==3.3.1
24 | cycler==0.10.0
25 | cymem==2.0.4
26 | dataclasses==0.8
27 | decorator==4.4.2
28 | defusedxml==0.6.0
29 | docutils==0.16
30 | entrypoints==0.3
31 | filelock==3.0.12
32 | gast==0.4.0
33 | gensim==3.8.0
34 | grpcio==1.34.0
35 | h5py==3.1.0
36 | idna==2.10
37 | idna-ssl==1.1.0
38 | importlib-metadata==3.3.0
39 | ipykernel==5.3.4
40 | ipython==7.16.1
41 | ipython-genutils==0.2.0
42 | jedi==0.17.0
43 | Jinja2==2.11.2
44 | jmespath==0.10.0
45 | joblib==1.0.0
46 | json5==0.9.5
47 | jsonschema==3.0.2
48 | jupyter-client==6.1.7
49 | jupyter-core==4.7.0
50 | jupyterlab==2.2.6
51 | jupyterlab-pygments==0.1.2
52 | jupyterlab-server==1.2.0
53 | Keras-Applications==1.0.8
54 | Keras-Preprocessing==1.1.2
55 | kiwisolver
56 | llvmlite==0.34.0
57 | Markdown==3.3.3
58 | MarkupSafe==1.1.1
59 | matplotlib
60 | mistune==0.8.4
61 | mkl-fft
62 | mkl-random
63 | multidict==4.7.6
64 | murmurhash==1.0.5
65 | nbclient==0.5.1
66 | nbconvert==6.0.7
67 | nbformat==5.0.8
68 | nest-asyncio==1.4.3
69 | nltk==3.5
70 | notebook==6.1.5
71 | numba==0.51.2
72 | numpy==1.19.0
73 | olefile==0.46
74 | packaging==20.8
75 | pandas==1.1.5
76 | pandocfilters
77 | parso==0.8.1
78 | pexpect==4.8.0
79 | pickleshare==0.7.5
80 | Pillow
81 | plac==0.9.6
82 | preshed==3.0.2
83 | prometheus-client==0.9.0
84 | prompt-toolkit==3.0.8
85 | protobuf==3.14.0
86 | ptyprocess==0.6.0
87 | pyasn1==0.4.8
88 | pyasn1-modules==0.2.8
89 | pycparser==2.20
90 | Pygments==2.7.3
91 | pyOpenSSL==20.0.0
92 | pyparsing
93 | pyrsistent==0.17.3
94 | PySocks==1.7.1
95 | python-dateutil
96 | pytorch-pretrained-bert==0.6.2
97 | pytz==2020.4
98 | pyzmq==20.0.0
99 | regex==2020.11.13
100 | requests==2.25.0
101 | rsa==4.6
102 | s3transfer==0.3.3
103 | sacremoses==0.0.43
104 | scikit-learn==0.23.2
105 | scipy==1.5.4
106 | Send2Trash==1.5.0
107 | six
108 | smart-open==4.0.1
109 | spacy==2.3.2
110 | srsly==1.0.5
111 | tensorboard==1.11.0
112 | tensorflow==1.11.0
113 | termcolor==1.1.0
114 | terminado==0.9.1
115 | testpath==0.4.4
116 | thinc==7.4.1
117 | threadpoolctl==2.1.0
118 | tikzplotlib==0.9.6
119 | tokenizers==0.9.4
120 | torch==1.4.0
121 | tornado
122 | tqdm==4.54.1
123 | traitlets==4.3.3
124 | transformers==4.0.1
125 | typing-extensions==3.7.4.3
126 | urllib3==1.26.2
127 | wasabi==0.8.0
128 | wcwidth==0.2.5
129 | webencodings==0.5.1
130 | Werkzeug==1.0.1
131 | xgboost==1.3.0.post0
132 | yarl==1.5.1
133 | zipp==3.4.0
134 | 
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | ###### Default Values
4 | 
5 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=tmp -weight=5.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
6 | 
7 | ###### END
8 | 
9 | ### no control vs guide words only vs glove guidance
10 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -top_p=0.9 -n_generated_sentences=90 -guide=
11 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=5.0 -top_p=0.9 -n_generated_sentences=90 -only_max=True
12 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=10.0 -top_p=0.9 -n_generated_sentences=90 -only_max=True
13 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=20.0 -top_p=0.9 -n_generated_sentences=90 -only_max=True
14 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=5.0 -top_p=0.9 -n_generated_sentences=90
15 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=10.0 -top_p=0.9 -n_generated_sentences=90
16 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=20.0 -top_p=0.9 -n_generated_sentences=90
17 | 
18 | ### evaluating shift strength
19 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=5.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
20 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=10.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
21 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=15.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
22 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=20.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
23 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=25.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
24 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=guide_vs_no_guide_beams -weight=30.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
25 | 
26 | ### comparing different unordered modes
27 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=comparing_modes -weight=5.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
28 | # python main_DBS.py -mode='max' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=comparing_modes -weight=5.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
29 | # python main_DBS.py -mode='random' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=comparing_modes -weight=5.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
30 | # python main_DBS.py -mode='all' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=comparing_modes -weight=5.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
31 | 
32 | ### comparing different decoding methods
33 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=decoding_methods -weight=5.0 -top_p=0.9 -n_generated_sentences=90 -do_guarantee=True
34 | # python main_DBS.py -mode='next' -file_name=/data/50_keywordsets_eval/word_sets.txt -results_subfolder=decoding_methods -weight=5.0 -top_p=0.0 -n_generated_sentences=90 -do_guarantee=True
35 | 
36 | ### ROC
37 | # python main_DBS.py -mode='max' -file_name=/data/ROC/ROCStories_20_storylines_500_0.txt -results_subfolder=final4_ -weight=5.0 -top_p=0.9 -n_generated_sentences=-7 -n_beams=4 -do_guarantee=True
38 | 
39 | ### Keyword to Article
40 | python main_DBS.py -mode='max' -file_name=/data/keyword_to_articles -results_subfolder=tmp -key2article=True -weight=5.0 -top_p=0.9 -n_generated_sentences=-15 -n_beams=4 -do_guarantee=True
41 | 
42 | 
--------------------------------------------------------------------------------
/utility_gpt.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import pandas as pd
4 | import scipy.io as sio
5 | import math
6 | import torch
7 | import spacy
8 | from transformers import GPT2Tokenizer, GPT2LMHeadModel
9 | os.environ['GENSIM_DATA_DIR']='./gensim-data'
10 | 
11 | from nltk.stem import PorterStemmer, LancasterStemmer
12 | porter = PorterStemmer()
13 | 
14 | from sklearn.metrics.pairwise import cosine_similarity
15 | import itertools
16 | 
17 | from torch.utils.data import Dataset, DataLoader
18 | 
19 | 
20 | def glove_encode(glove_encoder, word):
21 |     return glove_encoder(word)
22 | 
23 | 
24 | 
25 | def checker(string):
26 |     string = string.replace("'ve", '')
27 |     string = string.replace("@", '')
28 |     string = string.replace("'re", '')
29 |     string = string.replace("'d", '')
30 |     string = string.replace("?", '')
31 |     string = string.replace("'s", '')
32 |     string = string.replace(":", '')
33 |     string = string.replace("!", '')
34 |     string = string.replace('"', '')
35 |     string = string.replace(".", '')
36 |     string = string.replace("--", '')
37 |     string = string.replace("'", '')
38 |     string = string.replace(",", '')
39 |     string = string.replace(';', '')
40 |     string = string.replace('‘', '')
41 |     string = string.replace('(', '')
42 |     string = string.replace(')', '')
43 |     string = string.replace('\'', '')
44 |     string = string.replace(' ', '')
45 |     return(string)
46 | 
47 | 
48 | ## Pytorch
49 | def converter_table_glove():
50 |     import gensim.downloader as api
51 |     glove_encoder = api.load("glove-wiki-gigaword-300")
52 | 
53 |     path = str(os.path.dirname(os.path.abspath(__file__))) + \
54 |         '/data/converter_table_glove'
55 | 
56 |     # load gpt-2 model
57 |     #model = GPT2LMHeadModel.from_pretrained('gpt2')
58 |     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
59 |     #model.eval()
60 | 
61 |     holder = np.zeros((50257, 300))
62 | 
63 |     # translate every word from the gpt-2 space into a glove representation
64 |     for i in range(50257):
65 |         try:
66 |             word = tokenizer.decode([i])
67 |             word = checker(word.strip().lower())
68 |             glove = glove_encoder[word]
69 |             holder[i, :] = glove
70 |         except:
71 |             word = tokenizer.decode([i])
72 |             holder[i, :] = np.zeros((300)) #+ 500
73 | 
74 |     # Save all 50'000 glove representations of the gpt-2 words
75 |     np.save(file=path, arr=holder)
76 |     print('Table was generated')
77 | 
78 | def converter_table_word2vec():
79 |     import gensim.downloader as api
80 |     word2vec_encoder = api.load("word2vec-google-news-300")
81 | 
82 |     path = str(os.path.dirname(os.path.abspath(__file__))) + \
83 |         '/data/converter_table_word2vec'
84 | 
85 |     # load gpt-2 model
86 |     #model = GPT2LMHeadModel.from_pretrained('gpt2')
87 |     tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
88 |     #model.eval()
89 | 
90 |     holder = np.zeros((50257, 300))
91 | 
92 |     # translate every word from the gpt-2 space into a word2vec representation
93 |     for i in range(50257):
94 |         try:
95 |             word = tokenizer.decode([i])
96 |             word = checker(word.strip().lower())
97 |             word2vec = word2vec_encoder[word]
98 |             holder[i, :] = word2vec
99 |         except:
100 |             word = tokenizer.decode([i])
101 |             holder[i, :] = np.zeros((300)) #+ 500
102 | 
103 |     # Save all 50'000 word2vec representations of the gpt-2 words
104 |     np.save(file=path, arr=holder)
105 |     print('Table was generated')
106 | 
107 | 
108 | def count_word_stem_one(word, sequence):
109 |     #print ("Sequence", sequence)
110 |     sequence = sequence.split()
111 | 
112 |     word_count = 0
113 |     word_stem = porter.stem(word.lower())
114 |     for s_word in sequence:
115 |         s_word_stem = porter.stem(s_word.lower())
116 |         if(s_word_stem == word_stem):
117 |             word_count = 1
118 |             break
119 | 
120 |     return word_count
121 | 
122 | def count_word_stem(word, sequence):
123 |     #print ("Sequence", sequence)
124 |     sequence = sequence.split()
125 |     word_count = 0
126 | 
127 |     word_stem = porter.stem(word.lower())
128 | 
129 |     for s_word in sequence:
130 |         s_word_stem = porter.stem(s_word.lower())
131 |         #print(s_word_stem)
132 |         if(s_word_stem == word_stem):
133 |             word_count += 1
134 | 
135 |     return word_count
136 | 
137 | # A score function for the quality of the sentence
138 | def evaluate_quality(sequence, word, related_count, perplexity, guide, temp=1.):
139 |     # we aim for one occurrence of the word, and low perplexity
140 |     w_1 = 1
141 |     w_3 = 0.001
142 |     c_star = 2
143 | 
144 |     if(word == ""):
145 |         quality_score = math.exp(-(w_1*(c_star) + w_3*perplexity))
146 |         return quality_score
147 | 
148 |     quality_score = 0
149 |     word_count = count_word_stem(word, sequence)
150 | 
151 | 
152 |     if(word_count != 0) and guide:
153 |         quality_score = math.exp(-(w_1*word_count + w_3*perplexity))
154 |     else:
155 |         quality_score = math.exp(-(w_1*(c_star) + w_3*perplexity))
156 | 
157 |     quality_score = quality_score/temp
158 |     # DEBUG
159 |     #print("txt, quality_score, word_count, rel_count, ppl", sequence, quality_score, word_count, related_count, perplexity)
160 | 
161 |     return quality_score, word_count
162 | 
163 | 
164 | 
165 | # A score function for the quality of the sentence
166 | def evaluate_quality_linear(sequence, word_count, perplexity, temp=1., perp=False):
167 |     # we aim for one occurrence of the word, and low perplexity
168 |     w_1 = 1
169 |     w_3 = 0.01
170 | 
171 |     if perp:
172 |         quality_score = word_count - w_3*perplexity
173 |     else:
174 |         quality_score = word_count + w_3*perplexity
175 | 
176 |     quality_score = quality_score/temp # Temperature for sampling
177 | 
178 |     return quality_score
179 | 
180 | 
181 | # simply the negative cosine similarity for use in calculating the 'best_tour'
182 | def neg_cosine_similarity(v,w):
183 |     return -1 * cosine_similarity(np.reshape(v, (1, -1)), np.reshape(w, (1, -1)))
184 | 
185 | # simply the positive cosine similarity for use in calculating the "worst" tour using 'best_tour' - used only as a sanity check (is worst worse than best?)
186 | def pos_cosine_similarity(v,w):
187 |     return cosine_similarity(np.reshape(v, (1, -1)), np.reshape(w, (1, -1)))
188 | 
189 | # function to calculate the total tour length when visiting the words in the given 'order'
190 | def tour_length(distance_matrix, order):
191 |     length = 0
192 |     for i, j in zip(order, order[1:]):
193 |         length += distance_matrix[i][j]
194 |     return length
195 | 
196 | # find the best tour through the guide words, minimizing the pairwise distance between consecutive guide words
197 | def best_tour(glove_array, distance=neg_cosine_similarity, top_k=1):
198 |     """
199 |     returns the best order to minimize the tour length
200 |     default pairwise distance is the negative cosine similarity
201 |     input should be an nxm array of word embeddings, where n is no. words and m is length of the word embedding
202 |     *NOT IMPLEMENTED: set top_k to the beam length if you want to use a separate order per beam.
203 |     """
204 |     n = len(glove_array)
205 |     distance_matrix = np.zeros((n,n))
206 |     for i, v in enumerate(glove_array):
207 |         for j, w in enumerate(glove_array):
208 |             distance_matrix[i][j] = distance(v,w)
209 |     tours = {}
210 |     for i in itertools.permutations(list(range(n))):
211 |         tours[i] = tour_length(distance_matrix, i)
212 |     best_tour = min(tours, key=tours.get)
213 |     return best_tour
214 | 
215 | 
216 | class KeywordsDataset(Dataset):
217 |     """Keywords dataset."""
218 | 
219 |     def __init__(self, keyword_sets):
220 |         self.keyword_sets = keyword_sets
221 | 
222 |     def __len__(self):
223 |         return len(self.keyword_sets)
224 | 
225 |     def __getitem__(self, idx):
226 |         return self.keyword_sets[idx]
227 | 
--------------------------------------------------------------------------------
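For orientation, here is a minimal sketch (not part of the repository) of how the `best_tour` helper above can be exercised. The 3-dimensional vectors are invented for illustration; in the actual pipeline the function receives 300-dimensional GloVe or word2vec rows for the guide words, and the import assumes `utility_gpt.py` and its dependencies are available in the working directory.

```python
import numpy as np

from utility_gpt import best_tour

# Invented toy embeddings: the first two vectors are nearly parallel, the third is orthogonal.
toy_embeddings = np.array([
    [1.0, 0.0, 0.0],
    [0.9, 0.1, 0.0],
    [0.0, 1.0, 0.0],
])

# best_tour tries every permutation of the rows and returns the visiting order that
# minimizes the summed negative cosine similarity, i.e. it keeps similar words adjacent.
order = best_tour(toy_embeddings)
print(order)  # e.g. (0, 1, 2): the two similar vectors end up next to each other
```

Because the search enumerates all permutations, this ordering step is only practical for the small guide-word sets used here (a handful of keywords per example).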