├── README.md
├── SGD_delex.zip
├── collect_sgd_intent.py
├── combine_sgd.py
├── combine_simulators.py
├── dataset.zip
├── environment.yml
├── img
│   └── framework.png
├── qa_inference.py
└── transition.py
/README.md:
--------------------------------------------------------------------------------
1 | # SalesBot: Transitioning from Chit-Chat to Task-Oriented Dialogues
2 |
3 | ## Framework
4 |
5 | ![SalesBot framework](img/framework.png)
6 |
7 | This work investigates conversations that begin as open-domain social chit-chat and gradually transition toward a task-oriented goal, and it releases a large-scale dataset with detailed annotations to encourage research in this direction. To build the dataset, we propose a framework that automatically generates such dialogues without human involvement and into which any strong open-domain dialogue generation model can easily be plugged.
8 |
9 | ## Dependency
10 | Check the required packages in `environment.yml`, or simply create the conda environment with:
11 | ```console
12 | conda env create -f environment.yml
13 | ```
14 |
15 | ## Data
16 | * selfchat:
17 | ```console
18 | mkdir selfchat
19 | parlai self_chat --model-file zoo:blender/blender_1Bdistill/model --inference nucleus --num-self-chats 20 --task blended_skill_talk --include-personas True --include-initial-utterances True --outfile selfchat/merge_sgd_20.json
20 | parlai self_chat --model-file zoo:blender/blender_1Bdistill/model --inference nucleus --num-self-chats 20 --task blended_skill_talk --include-personas True --include-initial-utterances True --outfile selfchat/simulators_20.json
21 | ```
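
The downstream scripts read these logs as JSON Lines in ParlAI's `conversations` format (one episode per line, with a `dialog` field holding pairs of turns). Depending on your ParlAI version the logs may be written with a `.jsonl` suffix; the commands in the next step read `selfchat/merge_sgd_20.jsonl` and `selfchat/simulators_20.jsonl`, so rename the files if needed. A minimal sketch to sanity-check one episode:

```python
import json

# Print the first self-chat episode. Each line of the log is one episode whose
# "dialog" field holds pairs of turns; each turn is a dict with a "text" entry.
with open("selfchat/merge_sgd_20.jsonl") as f:
    episode = json.loads(f.readline())

for pair in episode["dialog"]:
    for turn in pair:
        print(turn["text"])
```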
22 | * intent detection model:
23 | ```console
24 | python3 qa_inference.py --data_file selfchat/merge_sgd_20.jsonl --output_file merge_sgd_intent.json --device 0
25 | python3 qa_inference.py --data_file selfchat/simulators_20.jsonl --output_file simulators_intent.json --device 0
26 | ```
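
Each dialogue comes back as a list of turns annotated by the QA-based detector, i.e. `{"id": turn_index, "text": utterance, "intent": [zero or more SGD intents]}`. A small sketch to check how many dialogues contain a detected intent:

```python
import json

with open("merge_sgd_intent.json") as f:
    dialogues = json.load(f)

# A dialogue qualifies for merging if at least one turn triggered an intent.
with_intent = sum(any(turn["intent"] for turn in dialog) for dialog in dialogues)
print(f"{with_intent}/{len(dialogues)} dialogues contain a detected intent")
```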
27 | * task-oriented simulators:
28 | ```console
29 | python3 combine_simulators.py simulators_intent.json
30 | ```
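
`combine_simulators.py` lets a user simulator (`stanleychu2/user_400M`) and a system simulator (`stanleychu2/system_400M`) roll out up to four task-oriented exchanges after the transition question, and writes the merged dialogues to `combine_simulators.json` as `{"id", "dialog", "intent": {"type", "position"}}`. A quick sketch for eyeballing one merged dialogue:

```python
import json

with open("combine_simulators.json") as f:
    data = json.load(f)

example = data[0]
print(example["id"], example["intent"]["type"])
for i, utterance in enumerate(example["dialog"]):
    # "position" is the chit-chat turn where the intent was detected;
    # the turn right after it carries the transition question.
    marker = "  <-- transition" if i == example["intent"]["position"] + 1 else ""
    print(f"[{i}] {utterance}{marker}")
```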
31 | * merge SGD:
32 | ```console
33 | # SGD_delex is the delexicalized SGD data preprocessed by "ACCENTOR: Adding Chit-Chat to Enhance Task-Oriented Dialogues"
34 | unzip SGD_delex.zip
35 | mkdir sgd_intent_dialog
36 | python3 collect_sgd_intent.py SGD_delex
37 | python3 combine_sgd.py merge_sgd_intent.json
38 |
39 | ```
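
`collect_sgd_intent.py` buckets the delexicalized SGD dialogues by the six supported intents into `sgd_intent_dialog/<Intent>_delex.json`, and `combine_sgd.py` then appends an SGD dialogue of the detected intent (starting from the turn where that intent becomes active) to each chit-chat dialogue, producing `combine_sgd.json`. A small sketch to check how many SGD dialogues were collected per intent:

```python
import json

INTENTS = ["LookupSong", "PlaySong", "LookupMusic",
           "FindMovies", "GetTimesForMovie", "FindAttractions"]

for intent in INTENTS:
    with open(f"sgd_intent_dialog/{intent}_delex.json") as f:
        dialogues = json.load(f)
    print(f"{intent}: {len(dialogues)} dialogues")
```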
40 | * transition:
41 | ```console
42 | python3 transition.py combine_sgd.json
43 | python3 transition.py combine_simulators.json
44 | ```
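
`transition.py` replaces the templated transition question (`dialog[position + 1]`) with a sentence sampled from `stanleychu2/t5-transition`, keeps all four sampled candidates under `transition_candidates`, and writes `combine_sgd_transition.json` / `combine_simulators_transition.json`. A minimal sketch to inspect the generated transitions:

```python
import json

with open("combine_sgd_transition.json") as f:
    data = json.load(f)

example = data[0]
position = example["intent"]["position"]
print("transition :", example["dialog"][position + 1])
print("candidates :", example["transition_candidates"])
```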
45 |
46 | ## Citation
47 |
48 | Please cite our paper if you use SalesBot in your work:
49 |
50 | ```bibtex
51 | @inproceedings{chiu2022salesbot,
52 | title={{SalesBot}: Transitioning from Chit-Chat to Task-Oriented Dialogues},
53 | author={Chiu, Ssu and Li, Maolin and Lin, Yen-Ting and Chen, Yun-Nung},
54 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (ACL)},
55 | year={2022}
56 | }
57 | ```
58 |
--------------------------------------------------------------------------------
/SGD_delex.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiuLab/SalesBot/1fdd8713b76dca04a11791519e9a445afe7aa35f/SGD_delex.zip
--------------------------------------------------------------------------------
/collect_sgd_intent.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 | import sys
4 |
5 | intents = {
6 | "LookupSong": [],
7 | "PlaySong": [],
8 | "LookupMusic": [],
9 | "FindMovies": [],
10 | "GetTimesForMovie": [],
11 | "FindAttractions": [],
12 | }
13 |
14 | for t in ["train", "dev", "test"]:
15 | dialogue_paths = glob.glob(f"{sys.argv[1]}/{t}/dialogues_*")
16 |     for p in dialogue_paths:
17 |         with open(p, "r") as f:
18 |             dialogues = json.load(f)
19 |
20 |         for d in dialogues:
21 |             intent = None
22 |             turns = []
23 |             sample = {"intent_pos": 0, "dialogue": []}
24 |             for turn_idx, turn in enumerate(d["turns"]):
25 |                 if turn["speaker"] == "USER":
26 |                     # Record the first supported active intent and its position.
27 |                     for frame in turn["frames"]:
28 |                         if frame["state"]["active_intent"] in intents and intent is None:
29 |                             intent = frame["state"]["active_intent"]
30 |                             sample["intent_pos"] = turn_idx
31 |                     turns.append(turn["utterance"])
32 |                 else:
33 |                     # System turns use the delexicalized utterance.
34 |                     turns.append(turn["delex"])
35 |
36 |             if intent is not None:
37 |                 sample["dialogue"] = turns
38 |                 intents[intent].append(sample)
39 |
40 |
41 | for (k, v) in intents.items():
42 | with open(f"sgd_intent_dialog/{k}_delex.json", "w") as f:
43 | json.dump(v, f, ensure_ascii=False, indent=4)
44 |
--------------------------------------------------------------------------------
/combine_sgd.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import sys
4 | from typing import Dict
5 |
6 | import torch
7 | from tqdm.auto import tqdm
8 |
9 | persona = json.load(open(sys.argv[1], "r"))
10 | intent_description: Dict[str, str] = {
11 | "LookupSong": "search for a song",
12 | "PlaySong": "play the selected song on the device",
13 | "LookupMusic": "search for a song based on the name and optionally other attributes",
14 | "FindMovies": "find movies by genre and optionally director",
15 | "GetTimesForMovie": "get show times for a movie at a location on a given date",
16 | "FindAttractions": "browse attractions in a given city",
17 | }
18 | output = open("combine_sgd.json", "w")
19 | transition_questions: Dict[str, str] = {
20 | k: f"Do you want to {v}?" for (k, v) in intent_description.items()
21 | }
22 | device = "cuda" if torch.cuda.is_available() else "cpu"
23 | intent = {}
24 | intents = {}
25 | data = []
26 | random.seed(26)
27 |
28 | for k in intent_description.keys():
29 | with open(f"sgd_intent_dialog/{k}_delex.json", "r") as f:
30 | intents[k] = json.load(f)
31 | random.shuffle(intents[k])
32 |
33 | for d in tqdm(persona):
34 | intent_appear = False
35 | context = []
36 | for i, turn in enumerate(d):
37 | context.append(turn["text"])
38 | if len(turn["intent"]) != 0:
39 | last_chit_chat = d[i + 1]["text"] if (i + 1) < len(d) else ""
40 | intent_appear = True
41 | intent = {"type": turn["intent"], "position": i}
42 | whole_transition = (
43 | last_chit_chat + " " + transition_questions[turn["intent"][0]]
44 | )
45 | context.append(whole_transition)
46 | break
47 |
48 | if intent_appear and len(intents[intent["type"][0]]) != 0:
49 | sample = intents[intent["type"][0]][0]
50 | intents[intent["type"][0]] = intents[intent["type"][0]][1:]
51 | dialog = sample["dialogue"][sample["intent_pos"] :]
52 | context += dialog
53 | data.append(
54 | {"id": f"merge_{len(data):04d}", "dialog": context, "intent": intent}
55 | )
56 | json.dump(data, output, indent=4)
57 |
--------------------------------------------------------------------------------
/combine_simulators.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import sys
4 | from typing import Dict
5 |
6 | import torch
7 | from tqdm.auto import tqdm
8 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9 |
10 |
11 | def jaccard_similarity(list1, list2):
12 | s1 = set(list1)
13 | s2 = set(list2)
14 | return float(len(s1.intersection(s2)) / len(s1.union(s2)))
15 |
16 |
17 | persona = json.load(open(sys.argv[1], "r"))
18 | intent_description: Dict[str, str] = {
19 | "LookupSong": "search for a song",
20 | "PlaySong": "play the selected song on the device",
21 | "LookupMusic": "search for a song based on the name and optionally other attributes",
22 | "FindMovies": "find movies by genre and optionally director",
23 | "GetTimesForMovie": "get show times for a movie at a location on a given date",
24 | "FindAttractions": "browse attractions in a given city",
25 | }
26 | output = open("combine_simulators.json", "w")
27 | transition_questions: Dict[str, str] = {
28 | k: f"Do you want to {v}?" for (k, v) in intent_description.items()
29 | }
30 | device = "cuda" if torch.cuda.is_available() else "cpu"
31 | end_keywords = ["goodbye", "bye"]
32 | end_sentences = [
33 | "have a great day",
34 | "have a nice day",
35 | "have a good day",
36 | "have a wonderful day",
37 | "enjoy your day",
38 | "have a good one",
39 | "have a good time",
40 | "enjoy the rest of your day",
41 | "have a fantastic day",
42 | "i am glad i could help have a nice day",
43 | ]
44 | intent = {}
45 | data = []
46 |
47 | for d in tqdm(persona):
48 | intent_appear = False
49 | history = []
50 | context = []
51 | for i, turn in enumerate(d):
52 | history.append(turn["text"])
53 | context.append(turn["text"])
54 | if len(turn["intent"]) != 0:
55 | last_chit_chat = d[i + 1]["text"] if (i + 1) < len(d) else ""
56 | intent_appear = True
57 | intent = {"type": turn["intent"], "position": i}
58 | whole_transition = (
59 | last_chit_chat + " " + transition_questions[turn["intent"][0]]
60 | )
61 | history.append(whole_transition)
62 | context.append(whole_transition)
63 | history = history[-3:]
64 | break
65 |
66 | if intent_appear:
67 | for _ in range(4):
68 | user_checkpoint = "stanleychu2/user_400M"
69 | user_tokenizer = AutoTokenizer.from_pretrained(
70 | user_checkpoint, use_fast=False
71 | )
72 | user = AutoModelForSeq2SeqLM.from_pretrained(user_checkpoint).to(device)
73 | user.eval()
74 |
75 | prefix = "user: "
76 | inputs = user_tokenizer(
77 | " ".join(history), max_length=128, truncation=True, return_tensors="pt"
78 | ).to(device)
79 | outputs = user.generate(
80 | **inputs,
81 | do_sample=True,
82 | top_k=120,
83 | no_repeat_ngram_size=2,
84 | min_length=1,
85 | max_length=64,
86 | ).squeeze(0)
87 | # 8010 = __END__
88 | if 8010 in outputs:
89 | print("__END__")
90 | break
91 | utterance = user_tokenizer.decode(
92 | outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
93 | ).strip()
94 | history.append(utterance)
95 | context.append(utterance)
96 | history = history[-2:]
97 |
98 | system_checkpoint = "stanleychu2/system_400M"
99 | prefix = "sys: "
100 | sys_tokenizer = AutoTokenizer.from_pretrained(
101 | system_checkpoint, use_fast=False
102 | )
103 | system = AutoModelForSeq2SeqLM.from_pretrained(system_checkpoint).to(device)
104 | system.eval()
105 |
106 | inputs = sys_tokenizer(
107 | " ".join(history), max_length=128, truncation=True, return_tensors="pt"
108 | ).to(device)
109 | outputs = system.generate(
110 | **inputs,
111 | do_sample=True,
112 | num_beams=5,
113 | no_repeat_ngram_size=3,
114 | num_return_sequences=5,
115 | early_stopping=True,
116 | max_length=128,
117 | ).squeeze(0)
118 |             utterance = sys_tokenizer.decode(
119 |                 outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True
120 |             ).strip()
121 | processed_utterance = re.sub(r"[^\w\s]", "", utterance.lower())
122 | processed_last_utterance = re.sub(r"[^\w\s]", "", history[-2].lower())
123 | if (
124 | jaccard_similarity(
125 | sys_tokenizer.tokenize(processed_last_utterance),
126 | sys_tokenizer.tokenize(processed_utterance),
127 | )
128 | > 0.4
129 | ):
130 | print("REPEAT:", utterance)
131 | print("REPEAT:", history[-2])
132 | break
133 | history.append(utterance)
134 | context.append(utterance)
135 | history = history[-2:]
136 | if any([(k in utterance) for k in end_keywords]) or any(
137 | [
138 | jaccard_similarity(
139 | sys_tokenizer.tokenize(processed_utterance),
140 | sys_tokenizer.tokenize(s),
141 | )
142 | > 0.2
143 | for s in end_sentences
144 | ]
145 | ):
146 | print("RULE:", utterance)
147 | break
148 |
149 | print(context)
150 | data.append(
151 | {"id": f"simulateTOD_{len(data):04d}", "dialog": context, "intent": intent}
152 | )
153 |
154 | json.dump(data, output, indent=4)
155 |
--------------------------------------------------------------------------------
/dataset.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiuLab/SalesBot/1fdd8713b76dca04a11791519e9a445afe7aa35f/dataset.zip
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: salesbot
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - _openmp_mutex=4.5=1_gnu
7 | - ca-certificates=2022.2.1=h06a4308_0
8 | - certifi=2021.10.8=py38h06a4308_2
9 | - ld_impl_linux-64=2.35.1=h7274673_9
10 | - libffi=3.3=he6710b0_2
11 | - libgcc-ng=9.3.0=h5101ec6_17
12 | - libgomp=9.3.0=h5101ec6_17
13 | - libstdcxx-ng=9.3.0=hd4cf53a_17
14 | - ncurses=6.3=h7f8727e_2
15 | - openssl=1.1.1m=h7f8727e_0
16 | - pip=21.2.4=py38h06a4308_0
17 | - python=3.8.12=h12debd9_0
18 | - readline=8.1.2=h7f8727e_1
19 | - setuptools=58.0.4=py38h06a4308_0
20 | - sqlite=3.37.2=hc218d9a_0
21 | - tk=8.6.11=h1ccaba5_0
22 | - wheel=0.37.1=pyhd3eb1b0_0
23 | - xz=5.2.5=h7b6447c_0
24 | - zlib=1.2.11=h7f8727e_4
25 | - pip:
26 | - absl-py==1.0.0
27 | - aiohttp==3.8.1
28 | - aiosignal==1.2.0
29 | - alabaster==0.7.12
30 | - antlr4-python3-runtime==4.8
31 | - asttokens==2.0.5
32 | - async-timeout==4.0.2
33 | - attrs==20.2.0
34 | - babel==2.9.1
35 | - backcall==0.2.0
36 | - boto3==1.21.13
37 | - botocore==1.24.13
38 | - cachetools==5.0.0
39 | - charset-normalizer==2.0.12
40 | - click==8.0.4
41 | - coloredlogs==15.0.1
42 | - datasets==1.18.3
43 | - decorator==5.1.1
44 | - dill==0.3.4
45 | - docformatter==1.4
46 | - docutils==0.15.2
47 | - emoji==1.6.3
48 | - executing==0.8.3
49 | - fairscale==0.4.5
50 | - filelock==3.6.0
51 | - flake8==4.0.1
52 | - flake8-bugbear==22.1.11
53 | - frozenlist==1.3.0
54 | - fsspec==2022.2.0
55 | - gitdb==4.0.9
56 | - gitdb2==4.0.2
57 | - gitpython==3.1.27
58 | - google-auth==2.6.0
59 | - google-auth-oauthlib==0.4.6
60 | - grpcio==1.44.0
61 | - huggingface-hub==0.4.0
62 | - humanfriendly==10.0
63 | - hydra-core==1.1.1
64 | - idna==3.3
65 | - imagesize==1.3.0
66 | - importlib-metadata==4.2.0
67 | - importlib-resources==5.4.0
68 | - iniconfig==1.1.1
69 | - iopath==0.1.9
70 | - ipython==8.1.1
71 | - jedi==0.18.1
72 | - jinja2==3.0.3
73 | - jmespath==0.10.0
74 | - joblib==1.1.0
75 | - jsonlines==3.0.0
76 | - markdown==3.3.4
77 | - markdown-it-py==0.5.8
78 | - markupsafe==2.1.0
79 | - matplotlib-inline==0.1.3
80 | - mccabe==0.6.1
81 | - mock==4.0.3
82 | - multidict==6.0.2
83 | - multiprocess==0.70.12.2
84 | - myst-parser==0.12.10
85 | - nltk==3.7
86 | - numpy==1.22.2
87 | - oauthlib==3.2.0
88 | - omegaconf==2.1.1
89 | - packaging==21.3
90 | - pandas==1.4.1
91 | - parlai==1.5.1
92 | - parso==0.8.3
93 | - pexpect==4.8.0
94 | - pickleshare==0.7.5
95 | - pillow==9.0.1
96 | - pluggy==1.0.0
97 | - portalocker==2.4.0
98 | - prompt-toolkit==3.0.28
99 | - protobuf==3.19.4
100 | - ptyprocess==0.7.0
101 | - pure-eval==0.2.2
102 | - py==1.11.0
103 | - py-gfm==1.0.2
104 | - py-rouge==1.1
105 | - pyarrow==7.0.0
106 | - pyasn1==0.4.8
107 | - pyasn1-modules==0.2.8
108 | - pycodestyle==2.8.0
109 | - pyflakes==2.4.0
110 | - pygments==2.11.2
111 | - pyparsing==3.0.7
112 | - pytest==7.0.1
113 | - pytest-datadir==1.3.1
114 | - pytest-regressions==2.3.1
115 | - python-dateutil==2.8.2
116 | - pytz==2021.3
117 | - pyyaml==6.0
118 | - pyzmq==22.3.0
119 | - regex==2022.3.2
120 | - requests==2.27.1
121 | - requests-mock==1.9.3
122 | - requests-oauthlib==1.3.1
123 | - rsa==4.8
124 | - s3transfer==0.5.2
125 | - sacremoses==0.0.47
126 | - scikit-learn==1.0.2
127 | - scipy==1.8.0
128 | - sh==1.14.2
129 | - six==1.16.0
130 | - smmap==5.0.0
131 | - snowballstemmer==2.2.0
132 | - sphinx==2.2.2
133 | - sphinx-autodoc-typehints==1.10.3
134 | - sphinx-rtd-theme==1.0.0
135 | - sphinxcontrib-applehelp==1.0.2
136 | - sphinxcontrib-devhelp==1.0.2
137 | - sphinxcontrib-htmlhelp==2.0.0
138 | - sphinxcontrib-jsmath==1.0.1
139 | - sphinxcontrib-qthelp==1.0.3
140 | - sphinxcontrib-serializinghtml==1.1.5
141 | - stack-data==0.2.0
142 | - subword-nmt==0.3.8
143 | - tensorboard==2.8.0
144 | - tensorboard-data-server==0.6.1
145 | - tensorboard-plugin-wit==1.8.1
146 | - tensorboardx==2.5
147 | - threadpoolctl==3.1.0
148 | - tokenizers==0.11.6
149 | - tomli==2.0.1
150 | - torch==1.10.2+cu113
151 | - torchtext==0.11.2
152 | - tornado==6.1
153 | - tqdm==4.62.3
154 | - traitlets==5.1.1
155 | - transformers==4.17.0
156 | - typing-extensions==4.1.1
157 | - unidecode==1.3.3
158 | - untokenize==0.1.1
159 | - urllib3==1.26.8
160 | - wcwidth==0.2.5
161 | - websocket-client==1.3.1
162 | - websocket-server==0.6.4
163 | - werkzeug==2.0.3
164 | - xxhash==3.0.0
165 | - yarl==1.7.2
166 | - zipp==3.7.0
167 | prefix: /home/stanley/miniconda3/envs/salesbot
168 |
--------------------------------------------------------------------------------
/img/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MiuLab/SalesBot/1fdd8713b76dca04a11791519e9a445afe7aa35f/img/framework.png
--------------------------------------------------------------------------------
/qa_inference.py:
--------------------------------------------------------------------------------
1 | import json
2 | from argparse import ArgumentParser
3 | from operator import itemgetter
4 | from typing import Dict, List
5 |
6 | from tqdm import tqdm
7 | from transformers import AutoModelForQuestionAnswering, AutoTokenizer, QuestionAnsweringPipeline
8 |
9 | intent_questions: Dict[str, List[str]] = {
10 | "LookupSong": [
11 | "Is the intent asking about looking up songs ?",
12 | "Is the user asking about looking up songs ?",
13 | "Are there any websites which advertise or advertise songs for free?",
14 | "Is the users question about looking up songs?",
15 | "Is there a way to ask for help with the search of songs?",
16 | "How much time does someone waste searching for songs?",
17 | "Is the user asked about searching up song?",
18 | "Is a user ask about searching up songs?",
19 | "Does the user consider to look up songs?",
20 | ],
21 | "PlaySong": [
22 | "Is the intent asking about playing songs ?",
23 | "Is the user asking about playing songs ?",
24 | "Is the user asking about playing songs?",
25 | "Is your user asking about playing songs?",
26 | "Is the user asking about playing music?",
27 | "Why does the user ask about playing a song?",
28 | "Is a user asking about playing songs?",
29 | "Does my iPhone asks about playing songs?",
30 | "Does the user ask about playing songs?",
31 | "Is the user planning to playing songs ?",
32 | ],
33 | "LookupMusic": [
34 | "Is the intent asking about looking up music ?",
35 | "Is the user asking about looking up music ?",
36 | "Are you asking people to look up music?",
37 | "Is the user asking about looking up music?",
38 | "Is the user asking about searching for music?",
39 | "Why does it seem that people are obsessed with looking up music?",
40 | "Is the user asking about searching music?",
41 | "How s/he asked about searching up music?",
42 | "Will the user ask about finding other music?",
43 | "Is it helpful when I ask for help about searching for music on a website?",
44 | "Is it the user asking about looking up songs (or saying songs)?",
45 | "Why is the user so interested in looking up music?",
46 | "Does the user want to look up music ?",
47 | ],
48 | "FindMovies": [
49 | "Is the intent asking about finding movies ?",
50 | "Is the user asking about finding movies ?",
51 | "Does someone want to find a movie?",
52 | "Does the user ask about finding movies?",
53 | "Why does user ask to find movies?",
54 | "Is the user asking about finding movies?",
55 | "Is the user about looking movies and trawl?",
56 | "Is the user asking about finding movies. Is it true that it is the same question of no different people?",
57 | "When did you start a game and you start asking about movies?",
58 | "What are the users complaints about getting movies?",
59 | "Does the user hope to find movies ?",
60 | ],
61 | "GetTimesForMovie": [
62 | "Is the intent asking about getting the time for movies ?",
63 | "Is the user asking about getting the time for movies ?",
64 | "What's your question about getting the time for movies?",
65 | "Is my mom asking about getting time for movies?",
66 | "How can I get the time for movies?",
67 | "Is the user asking about getting the time for movies?",
68 | "Can you fix my time problem for movies?",
69 | "What is the thing the user is asking about getting a time in movie or TV watching?",
70 | "How do you determine if you have enough time to watch movies?",
71 | "Is the user asking about getting time for movies?",
72 | "If you are a movie watcher, would you like to give you a good amount of time for your filmmaking needs?",
73 | "Is getting the time for movies the purpose of the user?",
74 | ],
75 | "FindAttractions": [
76 | "Is the intent asking about finding attractions ?",
77 | "Is the user asking about finding attractions ?",
78 | "Is the user asking about finding attractions?",
79 | "Is the user asking about how to find attractions?",
80 | "How can I find an attraction?",
81 | "What are some of the common questions asked by a visitor about how to find an attraction?",
82 | "Is it the user asking about finding attractions?",
83 | "Is the User Asking about Theme parks?",
84 | "Does the user have trouble finding attractions ?",
85 | ],
86 | }
87 |
88 | sgd_intents: Dict[str, str] = {
89 | f"{intent}-{q}": q
90 | for intent, questions in intent_questions.items()
91 | for q in questions
92 | }
93 |
94 |
95 | def classify_intent(example: List[str]) -> List[Dict]:
96 |     """Annotate each utterance of one dialogue with the SGD intents detected by the QA model."""
97 | instances = [
98 | (idx, intent, f"yes. no. {turn}", question)
99 | for idx, turn in enumerate(example)
100 | for intent, question in sgd_intents.items()
101 | ]
102 | results = nlp(
103 | question=list(map(itemgetter(-1), instances)),
104 | context=list(map(itemgetter(-2), instances)),
105 | )
106 | mappings = {i[:2]: r["answer"] for i, r in zip(instances, results)}
107 | new_dialog = [
108 | {
109 | "id": idx,
110 | "text": turn,
111 | "intent": list(
112 | set(
113 | [
114 | intent.split("-")[0]
115 | for intent in sgd_intents
116 | if mappings.get((idx, intent), None) == "yes."
117 | ]
118 | )
119 | ),
120 | }
121 | for idx, turn in enumerate(example)
122 | ]
123 |
124 | return new_dialog
125 |
126 |
127 | parser = ArgumentParser()
128 | parser.add_argument("--device", type=int, default=-1)
129 | parser.add_argument("--data_file", type=str, default="blender.jsonl")
130 | parser.add_argument("--output_file", type=str, default="intent_sample.json")
131 | args = parser.parse_args()
132 |
133 | MODEL_NAME = "adamlin/distilbert-base-cased-sgd_qa-step5000"
134 | REVISION = "negative_sample-questions"
135 | model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME, revision=REVISION)
136 | tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, revision=REVISION)
137 | nlp = QuestionAnsweringPipeline(model, tokenizer, device=args.device)
138 |
139 | samples = [json.loads(i) for i in open(args.data_file, "r")]
140 | utterances = []
141 | for s in samples:
142 | tempt = []
143 | for d in s["dialog"]:
144 | p1, p2 = d[0]["text"], d[1]["text"]
145 | tempt.append(p1)
146 | tempt.append(p2)
147 | utterances.append(tempt)
148 | intent_samples = []
149 | for e in tqdm(utterances):
150 | intent_samples.append(classify_intent(e))
151 |
152 | json.dump(intent_samples, open(args.output_file, "w"))
153 |
--------------------------------------------------------------------------------
/transition.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 |
4 | import torch
5 | from tqdm.auto import tqdm
6 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7 |
8 | if __name__ == "__main__":
9 | device = "cuda" if torch.cuda.is_available() else "cpu"
10 |
11 |     # Load the transition generator once, rather than re-loading it for every dialogue.
12 |     checkpoint = "stanleychu2/t5-transition"
13 |     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
14 |     model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)
15 |     model.eval()
16 |
17 |     t5_transition = []
18 |     with open(sys.argv[1], "r") as f:
19 |         for dialog in tqdm(json.load(f)):
20 |             position = dialog["intent"]["position"]
21 |
22 | context = " ".join(dialog["dialog"][: position + 1])
23 | future = dialog["dialog"][position + 2]
24 | example = (
25 | f" {context} {future} "
26 | )
27 | inputs = tokenizer(
28 | example, max_length=512, truncation=True, return_tensors="pt"
29 | ).to(device)
30 |
31 | outputs = model.generate(
32 | **inputs,
33 | do_sample=True,
34 | top_k=80,
35 | top_p=0.95,
36 | max_length=64,
37 | repetition_penalty=0.8,
38 | num_return_sequences=4,
39 | ).squeeze(0)
40 |
41 | transition_sentence = [
42 | tokenizer.decode(i, skip_special_tokens=True) for i in outputs
43 | ]
44 | dialog["dialog"][position + 1] = transition_sentence[0]
45 | dialog["transition_candidates"] = transition_sentence
46 | t5_transition.append(dialog)
47 |
48 | json.dump(
49 | t5_transition,
50 | open(f"{sys.argv[1].split('.')[0]}_transition.json", "w"),
51 | indent=4,
52 | ensure_ascii=False,
53 | )
54 |
--------------------------------------------------------------------------------