├── README.md
├── SGD_delex.zip
├── collect_sgd_intent.py
├── combine_sgd.py
├── combine_simulators.py
├── dataset.zip
├── environment.yml
├── img
│   └── framework.png
├── qa_inference.py
└── transition.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # SalesBot: Transitioning from Chit-Chat to Task-Oriented Dialogues
 2 | 
 3 | ## Framework
 4 | 

5 | 6 |

7 | This paper focuses on investigating the conversations starting from open-domain social chatting and then gradually transitioning to task-oriented purposes, and releases a large-scale dataset with detailed annotations for encouraging this research direction. To achieve this goal, this paper proposes a framework to automatically generate many dialogues without human involvement, in which any powerful open-domain dialogue generation model can be easily leveraged. 8 | 9 | ## Dependency 10 | Check the packages needed or simply run the command 11 | ```console 12 | conda env create -f environment.yml 13 | ``` 14 | 15 | ## Data 16 | * selfchat: 17 | ```console 18 | mkdir selfchat 19 | parlai self_chat --model-file zoo:blender/blender_1Bdistill/model --inference nucleus --num-self-chats 20 --task blended_skill_talk --include-personas True --include-initial-utterances True --outfile selfchat/merge_sgd_20.json 20 | parlai self_chat --model-file zoo:blender/blender_1Bdistill/model --inference nucleus --num-self-chats 20 --task blended_skill_talk --include-personas True --include-initial-utterances True --outfile selfchat/simulators_20.json 21 | ``` 22 | * intent detection model: 23 | ```console 24 | python3 qa_inference.py --data_file selfchat/merge_sgd_20.jsonl --output_file merge_sgd_intent.json --device 0 25 | python3 qa_inference.py --data_file selfchat/simulators_20.jsonl --output_file simulators_intent.json --device 0 26 | ``` 27 | * task-oriented simulators: 28 | ```console 29 | python3 combine_simulators.py simulators_intent.json 30 | ``` 31 | * merge SGD: 32 | ```console 33 | # SGD_delex is the version preprocessed by "ACCENTOR: Adding Chit-Chat to Enhance Task-Oriented Dialogues" 34 | unzip SGD_delex 35 | mkdir sgd_intent_dialog 36 | python3 collect_sgd_intent.py SGD_delex 37 | python3 combine_sgd.py merge_sgd_intent.json 38 | 39 | ``` 40 | * transition: 41 | ```console 42 | python3 transition.py combine_sgd.json 43 | python3 transition.py combine_simulators.json 44 | ``` 45 | 46 | ## Citation 47 | 48 | Please cite our paper if you use SalesBot in your work: 49 | 50 | ```bibtex 51 | @inproceedings{chiu2022salesbot, 52 | title={{SalesBot}: Transitioning from Chit-Chat to Task-Oriented Dialogues}, 53 | author={Chiu, Ssu and Li, Maolin and Lin, Yen-Ting and Chen, Yun-Nung}, 54 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (ACL)}, 55 | year={2022} 56 | } 57 | ``` 58 | -------------------------------------------------------------------------------- /SGD_delex.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiuLab/SalesBot/1fdd8713b76dca04a11791519e9a445afe7aa35f/SGD_delex.zip -------------------------------------------------------------------------------- /collect_sgd_intent.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import sys 4 | 5 | intents = { 6 | "LookupSong": [], 7 | "PlaySong": [], 8 | "LookupMusic": [], 9 | "FindMovies": [], 10 | "GetTimesForMovie": [], 11 | "FindAttractions": [], 12 | } 13 | 14 | for t in ["train", "dev", "test"]: 15 | dialogue_paths = glob.glob(f"{sys.argv[1]}/{t}/dialogues_*") 16 | for p in dialogue_paths: 17 | with open(p, "r") as f: 18 | dialogues = json.load(f) 19 | 20 | for d in dialogues: 21 | intent = None 22 | turns = [] 23 | sample = {"intent_pos": 0, "dialogue": []} 24 | for t in range(len(d["turns"])): 25 | if d["turns"][t]["speaker"] == "USER": 26 | for 
f in d["turns"][t]["frames"]: 27 | if ( 28 | f["state"]["active_intent"] in intents.keys() 29 | and intent is None 30 | ): 31 | intent = f["state"]["active_intent"] 32 | sample["intent_pos"] = t 33 | turns.append(d["turns"][t]["utterance"]) 34 | else: 35 | turns.append(d["turns"][t]["delex"]) 36 | if intent is not None: 37 | sample["dialogue"] = turns 38 | intents[intent].append(sample) 39 | 40 | 41 | for (k, v) in intents.items(): 42 | with open(f"sgd_intent_dialog/{k}_delex.json", "w") as f: 43 | json.dump(v, f, ensure_ascii=False, indent=4) 44 | -------------------------------------------------------------------------------- /combine_sgd.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import sys 4 | from typing import Dict 5 | 6 | import torch 7 | from tqdm.auto import tqdm 8 | 9 | persona = json.load(open(sys.argv[1], "r")) 10 | intent_description: Dict[str, str] = { 11 | "LookupSong": "search for a song", 12 | "PlaySong": "play the selected song on the device", 13 | "LookupMusic": "search for a song based on the name and optionally other attributes", 14 | "FindMovies": "find movies by genre and optionally director", 15 | "GetTimesForMovie": "get show times for a movie at a location on a given date", 16 | "FindAttractions": "browse attractions in a given city", 17 | } 18 | output = open("combine_sgd.json", "w") 19 | transition_questions: Dict[str, str] = { 20 | k: f"Do you want to {v}?" for (k, v) in intent_description.items() 21 | } 22 | device = "cuda" if torch.cuda.is_available() else "cpu" 23 | intent = {} 24 | intents = {} 25 | data = [] 26 | random.seed(26) 27 | 28 | for k in intent_description.keys(): 29 | with open(f"sgd_intent_dialog/{k}_delex.json", "r") as f: 30 | intents[k] = json.load(f) 31 | random.shuffle(intents[k]) 32 | 33 | for d in tqdm(persona): 34 | intent_appear = False 35 | context = [] 36 | for i, turn in enumerate(d): 37 | context.append(turn["text"]) 38 | if len(turn["intent"]) != 0: 39 | last_chit_chat = d[i + 1]["text"] if (i + 1) < len(d) else "" 40 | intent_appear = True 41 | intent = {"type": turn["intent"], "position": i} 42 | whole_transition = ( 43 | last_chit_chat + " " + transition_questions[turn["intent"][0]] 44 | ) 45 | context.append(whole_transition) 46 | break 47 | 48 | if intent_appear and len(intents[intent["type"][0]]) != 0: 49 | sample = intents[intent["type"][0]][0] 50 | intents[intent["type"][0]] = intents[intent["type"][0]][1:] 51 | dialog = sample["dialogue"][sample["intent_pos"] :] 52 | context += dialog 53 | data.append( 54 | {"id": f"merge_{len(data):04d}", "dialog": context, "intent": intent} 55 | ) 56 | json.dump(data, output, indent=4) 57 | -------------------------------------------------------------------------------- /combine_simulators.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import sys 4 | from typing import Dict 5 | 6 | import torch 7 | from tqdm.auto import tqdm 8 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 9 | 10 | 11 | def jaccard_similarity(list1, list2): 12 | s1 = set(list1) 13 | s2 = set(list2) 14 | return float(len(s1.intersection(s2)) / len(s1.union(s2))) 15 | 16 | 17 | persona = json.load(open(sys.argv[1], "r")) 18 | intent_description: Dict[str, str] = { 19 | "LookupSong": "search for a song", 20 | "PlaySong": "play the selected song on the device", 21 | "LookupMusic": "search for a song based on the name and optionally other attributes", 22 | 
"FindMovies": "find movies by genre and optionally director", 23 | "GetTimesForMovie": "get show times for a movie at a location on a given date", 24 | "FindAttractions": "browse attractions in a given city", 25 | } 26 | output = open("combine_simulators.json", "w") 27 | transition_questions: Dict[str, str] = { 28 | k: f"Do you want to {v}?" for (k, v) in intent_description.items() 29 | } 30 | device = "cuda" if torch.cuda.is_available() else "cpu" 31 | end_keywords = ["goodbye", "bye"] 32 | end_sentences = [ 33 | "have a great day", 34 | "have a nice day", 35 | "have a good day", 36 | "have a wonderful day", 37 | "enjoy your day", 38 | "have a good one", 39 | "have a good time", 40 | "enjoy the rest of your day", 41 | "have a fantastic day", 42 | "i am glad i could help have a nice day", 43 | ] 44 | intent = {} 45 | data = [] 46 | 47 | for d in tqdm(persona): 48 | intent_appear = False 49 | history = [] 50 | context = [] 51 | for i, turn in enumerate(d): 52 | history.append(turn["text"]) 53 | context.append(turn["text"]) 54 | if len(turn["intent"]) != 0: 55 | last_chit_chat = d[i + 1]["text"] if (i + 1) < len(d) else "" 56 | intent_appear = True 57 | intent = {"type": turn["intent"], "position": i} 58 | whole_transition = ( 59 | last_chit_chat + " " + transition_questions[turn["intent"][0]] 60 | ) 61 | history.append(whole_transition) 62 | context.append(whole_transition) 63 | history = history[-3:] 64 | break 65 | 66 | if intent_appear: 67 | for _ in range(4): 68 | user_checkpoint = "stanleychu2/user_400M" 69 | user_tokenizer = AutoTokenizer.from_pretrained( 70 | user_checkpoint, use_fast=False 71 | ) 72 | user = AutoModelForSeq2SeqLM.from_pretrained(user_checkpoint).to(device) 73 | user.eval() 74 | 75 | prefix = "user: " 76 | inputs = user_tokenizer( 77 | " ".join(history), max_length=128, truncation=True, return_tensors="pt" 78 | ).to(device) 79 | outputs = user.generate( 80 | **inputs, 81 | do_sample=True, 82 | top_k=120, 83 | no_repeat_ngram_size=2, 84 | min_length=1, 85 | max_length=64, 86 | ).squeeze(0) 87 | # 8010 = __END__ 88 | if 8010 in outputs: 89 | print("__END__") 90 | break 91 | utterance = user_tokenizer.decode( 92 | outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True 93 | ).strip() 94 | history.append(utterance) 95 | context.append(utterance) 96 | history = history[-2:] 97 | 98 | system_checkpoint = "stanleychu2/system_400M" 99 | prefix = "sys: " 100 | sys_tokenizer = AutoTokenizer.from_pretrained( 101 | system_checkpoint, use_fast=False 102 | ) 103 | system = AutoModelForSeq2SeqLM.from_pretrained(system_checkpoint).to(device) 104 | system.eval() 105 | 106 | inputs = sys_tokenizer( 107 | " ".join(history), max_length=128, truncation=True, return_tensors="pt" 108 | ).to(device) 109 | outputs = system.generate( 110 | **inputs, 111 | do_sample=True, 112 | num_beams=5, 113 | no_repeat_ngram_size=3, 114 | num_return_sequences=5, 115 | early_stopping=True, 116 | max_length=128, 117 | ).squeeze(0) 118 | utterance = user_tokenizer.decode( 119 | outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True 120 | ).strip() 121 | processed_utterance = re.sub(r"[^\w\s]", "", utterance.lower()) 122 | processed_last_utterance = re.sub(r"[^\w\s]", "", history[-2].lower()) 123 | if ( 124 | jaccard_similarity( 125 | sys_tokenizer.tokenize(processed_last_utterance), 126 | sys_tokenizer.tokenize(processed_utterance), 127 | ) 128 | > 0.4 129 | ): 130 | print("REPEAT:", utterance) 131 | print("REPEAT:", history[-2]) 132 | break 133 | history.append(utterance) 134 | 
context.append(utterance) 135 | history = history[-2:] 136 | if any([(k in utterance) for k in end_keywords]) or any( 137 | [ 138 | jaccard_similarity( 139 | sys_tokenizer.tokenize(processed_utterance), 140 | sys_tokenizer.tokenize(s), 141 | ) 142 | > 0.2 143 | for s in end_sentences 144 | ] 145 | ): 146 | print("RULE:", utterance) 147 | break 148 | 149 | print(context) 150 | data.append( 151 | {"id": f"simulateTOD_{len(data):04d}", "dialog": context, "intent": intent} 152 | ) 153 | 154 | json.dump(data, output, indent=4) 155 | -------------------------------------------------------------------------------- /dataset.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiuLab/SalesBot/1fdd8713b76dca04a11791519e9a445afe7aa35f/dataset.zip -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: salesbot 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=4.5=1_gnu 7 | - ca-certificates=2022.2.1=h06a4308_0 8 | - certifi=2021.10.8=py38h06a4308_2 9 | - ld_impl_linux-64=2.35.1=h7274673_9 10 | - libffi=3.3=he6710b0_2 11 | - libgcc-ng=9.3.0=h5101ec6_17 12 | - libgomp=9.3.0=h5101ec6_17 13 | - libstdcxx-ng=9.3.0=hd4cf53a_17 14 | - ncurses=6.3=h7f8727e_2 15 | - openssl=1.1.1m=h7f8727e_0 16 | - pip=21.2.4=py38h06a4308_0 17 | - python=3.8.12=h12debd9_0 18 | - readline=8.1.2=h7f8727e_1 19 | - setuptools=58.0.4=py38h06a4308_0 20 | - sqlite=3.37.2=hc218d9a_0 21 | - tk=8.6.11=h1ccaba5_0 22 | - wheel=0.37.1=pyhd3eb1b0_0 23 | - xz=5.2.5=h7b6447c_0 24 | - zlib=1.2.11=h7f8727e_4 25 | - pip: 26 | - absl-py==1.0.0 27 | - aiohttp==3.8.1 28 | - aiosignal==1.2.0 29 | - alabaster==0.7.12 30 | - antlr4-python3-runtime==4.8 31 | - asttokens==2.0.5 32 | - async-timeout==4.0.2 33 | - attrs==20.2.0 34 | - babel==2.9.1 35 | - backcall==0.2.0 36 | - boto3==1.21.13 37 | - botocore==1.24.13 38 | - cachetools==5.0.0 39 | - charset-normalizer==2.0.12 40 | - click==8.0.4 41 | - coloredlogs==15.0.1 42 | - datasets==1.18.3 43 | - decorator==5.1.1 44 | - dill==0.3.4 45 | - docformatter==1.4 46 | - docutils==0.15.2 47 | - emoji==1.6.3 48 | - executing==0.8.3 49 | - fairscale==0.4.5 50 | - filelock==3.6.0 51 | - flake8==4.0.1 52 | - flake8-bugbear==22.1.11 53 | - frozenlist==1.3.0 54 | - fsspec==2022.2.0 55 | - gitdb==4.0.9 56 | - gitdb2==4.0.2 57 | - gitpython==3.1.27 58 | - google-auth==2.6.0 59 | - google-auth-oauthlib==0.4.6 60 | - grpcio==1.44.0 61 | - huggingface-hub==0.4.0 62 | - humanfriendly==10.0 63 | - hydra-core==1.1.1 64 | - idna==3.3 65 | - imagesize==1.3.0 66 | - importlib-metadata==4.2.0 67 | - importlib-resources==5.4.0 68 | - iniconfig==1.1.1 69 | - iopath==0.1.9 70 | - ipython==8.1.1 71 | - jedi==0.18.1 72 | - jinja2==3.0.3 73 | - jmespath==0.10.0 74 | - joblib==1.1.0 75 | - jsonlines==3.0.0 76 | - markdown==3.3.4 77 | - markdown-it-py==0.5.8 78 | - markupsafe==2.1.0 79 | - matplotlib-inline==0.1.3 80 | - mccabe==0.6.1 81 | - mock==4.0.3 82 | - multidict==6.0.2 83 | - multiprocess==0.70.12.2 84 | - myst-parser==0.12.10 85 | - nltk==3.7 86 | - numpy==1.22.2 87 | - oauthlib==3.2.0 88 | - omegaconf==2.1.1 89 | - packaging==21.3 90 | - pandas==1.4.1 91 | - parlai==1.5.1 92 | - parso==0.8.3 93 | - pexpect==4.8.0 94 | - pickleshare==0.7.5 95 | - pillow==9.0.1 96 | - pluggy==1.0.0 97 | - portalocker==2.4.0 98 | - prompt-toolkit==3.0.28 99 | - protobuf==3.19.4 100 | - 
ptyprocess==0.7.0 101 | - pure-eval==0.2.2 102 | - py==1.11.0 103 | - py-gfm==1.0.2 104 | - py-rouge==1.1 105 | - pyarrow==7.0.0 106 | - pyasn1==0.4.8 107 | - pyasn1-modules==0.2.8 108 | - pycodestyle==2.8.0 109 | - pyflakes==2.4.0 110 | - pygments==2.11.2 111 | - pyparsing==3.0.7 112 | - pytest==7.0.1 113 | - pytest-datadir==1.3.1 114 | - pytest-regressions==2.3.1 115 | - python-dateutil==2.8.2 116 | - pytz==2021.3 117 | - pyyaml==6.0 118 | - pyzmq==22.3.0 119 | - regex==2022.3.2 120 | - requests==2.27.1 121 | - requests-mock==1.9.3 122 | - requests-oauthlib==1.3.1 123 | - rsa==4.8 124 | - s3transfer==0.5.2 125 | - sacremoses==0.0.47 126 | - scikit-learn==1.0.2 127 | - scipy==1.8.0 128 | - sh==1.14.2 129 | - six==1.16.0 130 | - smmap==5.0.0 131 | - snowballstemmer==2.2.0 132 | - sphinx==2.2.2 133 | - sphinx-autodoc-typehints==1.10.3 134 | - sphinx-rtd-theme==1.0.0 135 | - sphinxcontrib-applehelp==1.0.2 136 | - sphinxcontrib-devhelp==1.0.2 137 | - sphinxcontrib-htmlhelp==2.0.0 138 | - sphinxcontrib-jsmath==1.0.1 139 | - sphinxcontrib-qthelp==1.0.3 140 | - sphinxcontrib-serializinghtml==1.1.5 141 | - stack-data==0.2.0 142 | - subword-nmt==0.3.8 143 | - tensorboard==2.8.0 144 | - tensorboard-data-server==0.6.1 145 | - tensorboard-plugin-wit==1.8.1 146 | - tensorboardx==2.5 147 | - threadpoolctl==3.1.0 148 | - tokenizers==0.11.6 149 | - tomli==2.0.1 150 | - torch==1.10.2+cu113 151 | - torchtext==0.11.2 152 | - tornado==6.1 153 | - tqdm==4.62.3 154 | - traitlets==5.1.1 155 | - transformers==4.17.0 156 | - typing-extensions==4.1.1 157 | - unidecode==1.3.3 158 | - untokenize==0.1.1 159 | - urllib3==1.26.8 160 | - wcwidth==0.2.5 161 | - websocket-client==1.3.1 162 | - websocket-server==0.6.4 163 | - werkzeug==2.0.3 164 | - xxhash==3.0.0 165 | - yarl==1.7.2 166 | - zipp==3.7.0 167 | prefix: /home/stanley/miniconda3/envs/salesbot 168 | -------------------------------------------------------------------------------- /img/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MiuLab/SalesBot/1fdd8713b76dca04a11791519e9a445afe7aa35f/img/framework.png -------------------------------------------------------------------------------- /qa_inference.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from operator import itemgetter 4 | from typing import Dict, List 5 | 6 | from tqdm import tqdm 7 | from transformers import AutoModelForQuestionAnswering, AutoTokenizer, QuestionAnsweringPipeline 8 | 9 | intent_questions: Dict[str, List[str]] = { 10 | "LookupSong": [ 11 | "Is the intent asking about looking up songs ?", 12 | "Is the user asking about looking up songs ?", 13 | "Are there any websites which advertise or advertise songs for free?", 14 | "Is the users question about looking up songs?", 15 | "Is there a way to ask for help with the search of songs?", 16 | "How much time does someone waste searching for songs?", 17 | "Is the user asked about searching up song?", 18 | "Is a user ask about searching up songs?", 19 | "Does the user consider to look up songs?", 20 | ], 21 | "PlaySong": [ 22 | "Is the intent asking about playing songs ?", 23 | "Is the user asking about playing songs ?", 24 | "Is the user asking about playing songs?", 25 | "Is your user asking about playing songs?", 26 | "Is the user asking about playing music?", 27 | "Why does the user ask about playing a song?", 28 | "Is a user asking about playing songs?", 29 | "Does my 
iPhone asks about playing songs?", 30 | "Does the user ask about playing songs?", 31 | "Is the user planning to playing songs ?", 32 | ], 33 | "LookupMusic": [ 34 | "Is the intent asking about looking up music ?", 35 | "Is the user asking about looking up music ?", 36 | "Are you asking people to look up music?", 37 | "Is the user asking about looking up music?", 38 | "Is the user asking about searching for music?", 39 | "Why does it seem that people are obsessed with looking up music?", 40 | "Is the user asking about searching music?", 41 | "How s/he asked about searching up music?", 42 | "Will the user ask about finding other music?", 43 | "Is it helpful when I ask for help about searching for music on a website?", 44 | "Is it the user asking about looking up songs (or saying songs)?", 45 | "Why is the user so interested in looking up music?", 46 | "Does the user want to look up music ?", 47 | ], 48 | "FindMovies": [ 49 | "Is the intent asking about finding movies ?", 50 | "Is the user asking about finding movies ?", 51 | "Does someone want to find a movie?", 52 | "Does the user ask about finding movies?", 53 | "Why does user ask to find movies?", 54 | "Is the user asking about finding movies?", 55 | "Is the user about looking movies and trawl?", 56 | "Is the user asking about finding movies. Is it true that it is the same question of no different people?", 57 | "When did you start a game and you start asking about movies?", 58 | "What are the users complaints about getting movies?", 59 | "Does the user hope to find movies ?", 60 | ], 61 | "GetTimesForMovie": [ 62 | "Is the intent asking about getting the time for movies ?", 63 | "Is the user asking about getting the time for movies ?", 64 | "What's your question about getting the time for movies?", 65 | "Is my mom asking about getting time for movies?", 66 | "How can I get the time for movies?", 67 | "Is the user asking about getting the time for movies?", 68 | "Can you fix my time problem for movies?", 69 | "What is the thing the user is asking about getting a time in movie or TV watching?", 70 | "How do you determine if you have enough time to watch movies?", 71 | "Is the user asking about getting time for movies?", 72 | "If you are a movie watcher, would you like to give you a good amount of time for your filmmaking needs?", 73 | "Is getting the time for movies the purpose of the user?", 74 | ], 75 | "FindAttractions": [ 76 | "Is the intent asking about finding attractions ?", 77 | "Is the user asking about finding attractions ?", 78 | "Is the user asking about finding attractions?", 79 | "Is the user asking about how to find attractions?", 80 | "How can I find an attraction?", 81 | "What are some of the common questions asked by a visitor about how to find an attraction?", 82 | "Is it the user asking about finding attractions?", 83 | "Is the User Asking about Theme parks?", 84 | "Does the user have trouble finding attractions ?", 85 | ], 86 | } 87 | 88 | sgd_intents: Dict[str, str] = { 89 | f"{intent}-{q}": q 90 | for intent, questions in intent_questions.items() 91 | for q in questions 92 | } 93 | 94 | 95 | def classify_intent(example: Dict) -> Dict: 96 | 97 | instances = [ 98 | (idx, intent, f"yes. no. 
{turn}", question) 99 | for idx, turn in enumerate(example) 100 | for intent, question in sgd_intents.items() 101 | ] 102 | results = nlp( 103 | question=list(map(itemgetter(-1), instances)), 104 | context=list(map(itemgetter(-2), instances)), 105 | ) 106 | mappings = {i[:2]: r["answer"] for i, r in zip(instances, results)} 107 | new_dialog = [ 108 | { 109 | "id": idx, 110 | "text": turn, 111 | "intent": list( 112 | set( 113 | [ 114 | intent.split("-")[0] 115 | for intent in sgd_intents 116 | if mappings.get((idx, intent), None) == "yes." 117 | ] 118 | ) 119 | ), 120 | } 121 | for idx, turn in enumerate(example) 122 | ] 123 | 124 | return new_dialog 125 | 126 | 127 | parser = ArgumentParser() 128 | parser.add_argument("--device", type=int, default=-1) 129 | parser.add_argument("--data_file", type=str, default="blender.jsonl") 130 | parser.add_argument("--output_file", type=str, default="intent_sample.json") 131 | args = parser.parse_args() 132 | 133 | MODEL_NAME = "adamlin/distilbert-base-cased-sgd_qa-step5000" 134 | REVISION = "negative_sample-questions" 135 | model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME, revision=REVISION) 136 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, revision=REVISION) 137 | nlp = QuestionAnsweringPipeline(model, tokenizer, device=args.device) 138 | 139 | samples = [json.loads(i) for i in open(args.data_file, "r")] 140 | utterances = [] 141 | for s in samples: 142 | tempt = [] 143 | for d in s["dialog"]: 144 | p1, p2 = d[0]["text"], d[1]["text"] 145 | tempt.append(p1) 146 | tempt.append(p2) 147 | utterances.append(tempt) 148 | intent_samples = [] 149 | for e in tqdm(utterances): 150 | intent_samples.append(classify_intent(e)) 151 | 152 | json.dump(intent_samples, open(args.output_file, "w")) 153 | -------------------------------------------------------------------------------- /transition.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | import torch 5 | from tqdm.auto import tqdm 6 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 7 | 8 | if __name__ == "__main__": 9 | device = "cuda" if torch.cuda.is_available() else "cpu" 10 | 11 | t5_transition = [] 12 | with open(sys.argv[1], "r") as f: 13 | for dialog in tqdm(json.load(f)): 14 | position = dialog["intent"]["position"] 15 | 16 | checkpoint = "stanleychu2/t5-transition" 17 | tokenizer = AutoTokenizer.from_pretrained(checkpoint) 18 | 19 | model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device) 20 | model.eval() 21 | 22 | context = " ".join(dialog["dialog"][: position + 1]) 23 | future = dialog["dialog"][position + 2] 24 | example = ( 25 | f" {context} {future} " 26 | ) 27 | inputs = tokenizer( 28 | example, max_length=512, truncation=True, return_tensors="pt" 29 | ).to(device) 30 | 31 | outputs = model.generate( 32 | **inputs, 33 | do_sample=True, 34 | top_k=80, 35 | top_p=0.95, 36 | max_length=64, 37 | repetition_penalty=0.8, 38 | num_return_sequences=4, 39 | ).squeeze(0) 40 | 41 | transition_sentence = [ 42 | tokenizer.decode(i, skip_special_tokens=True) for i in outputs 43 | ] 44 | dialog["dialog"][position + 1] = transition_sentence[0] 45 | dialog["transition_candidates"] = transition_sentence 46 | t5_transition.append(dialog) 47 | 48 | json.dump( 49 | t5_transition, 50 | open(f"{sys.argv[1].split('.')[0]}_transition.json", "w"), 51 | indent=4, 52 | ensure_ascii=False, 53 | ) 54 | --------------------------------------------------------------------------------