├── dialog_simulator
│   ├── memories
│   │   └── mini_set
│   │       ├── memories_metadata_all.json
│   │       └── mini_set_0_memory_graph.json
│   ├── SimulatorBase.py
│   ├── DummyMemoryDialogModel.py
│   ├── get_user_utterances.py
│   ├── constants.py
│   ├── InteractiveDialogHandler.py
│   ├── merge_data_json.py
│   ├── merge_synth_and_appen.py
│   ├── MemoryDialogSimulator.py
│   ├── GoalGenerator.py
│   ├── MemoryDialogModel.py
│   ├── main.py
│   ├── utils.py
│   └── MemoryServiceAPI.py
├── models
│   ├── gpt2_mm
│   │   ├── requirements.txt
│   │   ├── LICENSE
│   │   ├── run_me.sh
│   │   ├── README.md
│   │   ├── utils
│   │   │   ├── create_result_jsons.py
│   │   │   ├── preprocess_memory_dataset.py
│   │   │   └── extract_memory_features.py
│   │   ├── dataset_memory.py
│   │   └── dataset.py
│   └── gpt2_text
│       ├── run_evaluate_gpt2.sh
│       ├── run_generate_gpt2.sh
│       ├── run_train_gpt2.sh
│       ├── gpt2_dst
│       │   └── scripts
│       │       ├── evaluate.py
│       │       ├── get_best_model.py
│       │       ├── preprocess_input.py
│       │       ├── evaluate_response.py
│       │       ├── reformat_dst_response_outputs.py
│       │       └── run_generation.py
│       ├── run_preprocess_gpt2.sh
│       └── utils
│           └── response_evaluation.py
├── teaser_memory_dialog.png
├── data
│   ├── mem_dials_test.json
│   ├── mem_dials_val.json
│   ├── mem_dials_merged.json
│   ├── mem_dials_train.json
│   ├── memory_may21_v1_100graphs.json
│   └── mscoco_memory_graphs_1k.json
├── .gitattributes
├── CONTRIBUTING.md
├── .gitignore
├── CODE_OF_CONDUCT.md
└── README.md
/dialog_simulator/memories/mini_set/memories_metadata_all.json: -------------------------------------------------------------------------------- 1 | {} -------------------------------------------------------------------------------- /models/gpt2_mm/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.0.1 2 | pytorch-ignite==0.2.1 3 | transformers==2.1.1 4 | tqdm==4.36.1 5 | 6 | -------------------------------------------------------------------------------- /teaser_memory_dialog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/comet_memory_dialog/HEAD/teaser_memory_dialog.png -------------------------------------------------------------------------------- /data/mem_dials_test.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:ecbe949172b8878737e2806bd4ba8369c5d3b473f336e4b0ddec86ec2fdf1b7a 3 | size 8252698 4 | -------------------------------------------------------------------------------- /data/mem_dials_val.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dbffd2c3a225a0bb566670e3221ecf737903048a9b42b932433cb7c76f024e29 3 | size 8364828 4 | -------------------------------------------------------------------------------- /data/mem_dials_merged.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8c42a83c9d13f358895f62332713d6c9fd19cc31c5efbb21c5e1ec9fe49a6634 3 | size 156707220 4 | -------------------------------------------------------------------------------- /data/mem_dials_train.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:79adbb17be1565988c0afec8e72ee74b2ad014c2cbc86a5f4f81d91579be7ff9 3 | size 38741569 4 | -------------------------------------------------------------------------------- /data/memory_may21_v1_100graphs.json: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dcd5307759d0b4f80842773f50af1096f8e8df04f105f012b8aa02810158fb63 3 | size 22364689 4 | -------------------------------------------------------------------------------- /data/mscoco_memory_graphs_1k.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7f45f24f7497bb29abaab75c04fe91959f0ef20b7cfa02df13ebc7ed0beeb242 3 | size 225170054 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | data/memory_may21_v1_100graphs.json filter=lfs diff=lfs merge=lfs -text 2 | data/mscoco_memory_graphs_1k.json filter=lfs diff=lfs merge=lfs -text 3 | data/mem_dials_merged.json filter=lfs diff=lfs merge=lfs -text 4 | data/mem_dials_test.json filter=lfs diff=lfs merge=lfs -text 5 | data/mem_dials_train.json filter=lfs diff=lfs merge=lfs -text 6 | data/mem_dials_val.json filter=lfs diff=lfs merge=lfs -text 7 | -------------------------------------------------------------------------------- /models/gpt2_text/run_evaluate_gpt2.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 2 | #!/bin/bash 3 | if [[ $# -lt 1 ]] 4 | then 5 | PATH_DIR=$(realpath .) 6 | else 7 | PATH_DIR=$(realpath "$1") 8 | fi 9 | 10 | 11 | python -m gpt2_dst.scripts.evaluate_dst_response \ 12 | --input_path_target="${PATH_DIR}"/gpt2_dst/data/mem_dials_test_target.txt \ 13 | --input_path_predicted="${PATH_DIR}"/gpt2_dst/results/mem_dials_test_predicted.txt \ 14 | --compute_bert_score 15 | -------------------------------------------------------------------------------- /models/gpt2_text/run_generate_gpt2.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 2 | #!/bin/bash 3 | if [[ $# -lt 1 ]] 4 | then 5 | PATH_DIR=$(realpath .) 6 | else 7 | PATH_DIR=$(realpath "$1") 8 | fi 9 | 10 | # Generate sentences (memory dialogs, multi-modal) 11 | python3 -m gpt2_dst.scripts.run_generation \ 12 | --model_type=gpt2 \ 13 | --model_name_or_path="${PATH_DIR}"/gpt2_dst/save/model_run0/checkpoint-23000 \ 14 | --num_return_sequences=1 \ 15 | --length=100 \ 16 | --stop_token='<EOS>' \ 17 | --prompts_from_file="${PATH_DIR}"/gpt2_dst/data/mem_dials_test_predict.txt \ 18 | --path_output="${PATH_DIR}"/gpt2_dst/results/mem_dials_test_predicted_run0.txt 19 | -------------------------------------------------------------------------------- /dialog_simulator/SimulatorBase.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | #!/usr/bin/env python3 5 | from Data import MemoryDialog, Goal, Frame 6 | from typing import List 7 | 8 | 9 | class SimulatorBase: 10 | def register_memory_service_api(self, memory_service_api): 11 | self.memory_service_api = memory_service_api 12 | 13 | def fit_goal_to_intent(self, args): 14 | # Define the goal to intent mapping behavior 15 | pass 16 | 17 | def is_servable(self, goal: Goal) -> bool: 18 | # Check whether this simulator can serve the input goal. 
19 | pass 20 | 21 | def generate_nlu_label(self, goal: Goal, context: MemoryDialog) -> Frame: 22 | # Need to define this behavior first e.g. as a config, a model, etc. 23 | pass 24 | 25 | def generate_uttr(self, nlu_label: Frame) -> str: 26 | pass 27 | -------------------------------------------------------------------------------- /models/gpt2_text/run_train_gpt2.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 2 | #!/bin/bash 3 | if [[ $# -lt 1 ]] 4 | then 5 | PATH_DIR=$(realpath .) 6 | else 7 | PATH_DIR=$(realpath "$1") 8 | fi 9 | 10 | # Train (multi-modal) 11 | python3 -m gpt2_dst.scripts.run_language_modeling \ 12 | --output_dir="${PATH_DIR}"/gpt2_dst/save/model_run2 \ 13 | --model_type=gpt2 \ 14 | --model_name_or_path=gpt2 \ 15 | --line_by_line \ 16 | --add_special_tokens="${PATH_DIR}"/gpt2_dst/data/mem_special_tokens.json \ 17 | --do_train \ 18 | --train_data_file="${PATH_DIR}"/gpt2_dst/data/mem_dials_train_target.txt \ 19 | --do_eval --eval_all_checkpoints \ 20 | --eval_data_file="${PATH_DIR}"/gpt2_dst/data/mem_dials_val_target.txt \ 21 | --num_train_epochs=10 \ 22 | --overwrite_output_dir \ 23 | --per_gpu_train_batch_size=4 \ 24 | --per_gpu_eval_batch_size=4 25 | -------------------------------------------------------------------------------- /models/gpt2_mm/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 ICTNLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /dialog_simulator/DummyMemoryDialogModel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 
2 | 3 | #!/usr/bin/env python3 4 | from constants import API_CALL_TYPE, TurnSpeaker, DialogAct 5 | from Data import Turn, Frame, ActAttributes, MemoryDialog, APIResponse, APIRequest 6 | from typing import Dict, Tuple 7 | # NOTE: assumed import; the base class is expected to live in MemoryDialogModel.py. 8 | from MemoryDialogModel import MemoryDialogModelBase 9 | class DummyMemoryDialogModel(MemoryDialogModelBase): 10 | def __init__(self, *args, **kwargs): 11 | super(DummyMemoryDialogModel, self).__init__(*args, **kwargs) 12 | 13 | def predict_api_call(self, query: str) -> Dict: 14 | return { 15 | "call_type": API_CALL_TYPE.SEARCH, 16 | "dialog_act": DialogAct.UNKNOWN, 17 | "slot_values": {}, 18 | "request_slots": [], 19 | "memories": [], 20 | } 21 | 22 | def predict_assistant_response( 23 | self, query: str, api_response: APIResponse, memory_dialog: MemoryDialog 24 | ): 25 | 26 | response_str = ( 27 | "User asked:" 28 | + query 29 | + ". Dialog history: " 30 | + str(memory_dialog) 31 | + ". API response:" 32 | + str(api_response) 33 | ) 34 | 35 | return { 36 | "uttr": response_str, 37 | "dialog_act": DialogAct.UNKNOWN, 38 | "slot_values": {}, 39 | "request_slots": [], 40 | "memories": api_response.results.get("retrieved_memories"), 41 | } 42 | -------------------------------------------------------------------------------- /dialog_simulator/get_user_utterances.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | #!/usr/bin/env python3 5 | """ 6 | Extracts user utterances from the merged memory dialog data 7 | and writes them to a TSV file. 8 | """ 9 | import os 10 | import json 11 | import csv 12 | import random 13 | import pickle 14 | import numpy as np 15 | from utils import load_data_pickle 16 | 17 | 18 | if __name__ == "__main__": 19 | random.seed(0) 20 | np.random.seed(0) 21 | 22 | # Input / output paths 23 | path_in_pickle = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/final_data/mem_dials_merged.p" 24 | path_out_tsv = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/final_data/user_utterances.tsv" 25 | 26 | mm_dialogs = [] 27 | mm_dialogs.extend(load_data_pickle(path_in_pickle)) 28 | 29 | # Output 30 | print("Total: %d dialogs" % len(mm_dialogs)) 31 | 32 | with open(path_out_tsv, "w", newline="") as csvfile: 33 | writer = csv.writer(csvfile, delimiter="\t", quotechar="'") 34 | writer.writerow(["dialog_id", "turn_id", "user_utterance"]) 35 | 36 | for i, mm_dialog in enumerate(mm_dialogs): 37 | user_turns = mm_dialog.dialog.user_turns 38 | dialog_id = mm_dialog.dialog.idx 39 | 40 | for j, user_turn in enumerate(user_turns): 41 | user_uttr = user_turn.frames[-1].uttr 42 | 43 | if user_uttr not in set(["N/A", "NA"]): 44 | row = [dialog_id, j, user_uttr] 45 | writer.writerow(row) 46 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to comet\_memory\_dialog 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | ... (in particular how this is synced with internal changes to the project) 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `main`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. 
If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Meta's open source projects. 21 | 22 | Complete your CLA here: <https://code.facebook.com/cla> 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * 4 spaces for indentation rather than tabs 34 | * 80 character line length 35 | * ... 36 | 37 | ## License 38 | By contributing to comet\_memory\_dialog, you agree that your contributions will be licensed 39 | under the LICENSE file in the root directory of this source tree. 40 | -------------------------------------------------------------------------------- /models/gpt2_text/gpt2_dst/scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 2 | #!/usr/bin/env python3 3 | """ 4 | Scripts for evaluating the GPT-2 DST model predictions. 5 | 6 | First, we parse the line-by-line stringified format into 7 | the structured DST output. 8 | 9 | We then run the main DST Evaluation script to get results. 10 | """ 11 | import argparse 12 | import json 13 | from gpt2_dst.utils.convert import parse_flattened_results_from_file 14 | from utils.evaluate_dst import evaluate_from_flat_list 15 | 16 | 17 | if __name__ == "__main__": 18 | # Parse input args 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument( 21 | "--input_path_target", help="path for target, line-separated format (.txt)" 22 | ) 23 | parser.add_argument( 24 | "--input_path_predicted", 25 | help="path for model prediction output, line-separated format (.txt)", 26 | ) 27 | parser.add_argument( 28 | "--output_path_report", help="path for saving evaluation summary (.json)" 29 | ) 30 | 31 | args = parser.parse_args() 32 | input_path_target = args.input_path_target 33 | input_path_predicted = args.input_path_predicted 34 | output_path_report = args.output_path_report 35 | 36 | # Convert the data from the GPT-2 friendly format to JSON 37 | list_target = parse_flattened_results_from_file(input_path_target) 38 | list_predicted = parse_flattened_results_from_file(input_path_predicted) 39 | 40 | # Evaluate 41 | report = evaluate_from_flat_list(list_target, list_predicted) 42 | 43 | # Save report 44 | with open(output_path_report, "w") as f_out: 45 | json.dump(report, f_out) 46 | -------------------------------------------------------------------------------- /models/gpt2_text/gpt2_dst/scripts/get_best_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 2 | #! /usr/bin/env python 3 | """ 4 | Gets the best model given all the checkpoints. 
5 | 6 | Author(s): Satwik Kottur 7 | """ 8 | 9 | from __future__ import absolute_import, division, print_function, unicode_literals 10 | import argparse 11 | import os 12 | import re 13 | 14 | 15 | def main(args): 16 | for folder_name in args["model_checkpoint_folder"]: 17 | listing = [ii for ii in os.listdir(folder_name) if "checkpoint-" in ii] 18 | valid_metrics = {} 19 | for checkpoint_name in listing: 20 | checkpoint_folder = os.path.join(folder_name, checkpoint_name) 21 | eval_path = os.path.join(checkpoint_folder, "eval_results.txt") 22 | epoch_search = re.search(r"checkpoint-(\d*)", checkpoint_name) 23 | with open(eval_path, "r") as file_id: 24 | result = [ii.strip("\n") for ii in file_id.readlines()][0] 25 | perplexity_search = re.search(r"([0-9\.]+)", result) 26 | 27 | # NOTE: Does not handle error conditions. 28 | if perplexity_search is None or epoch_search is None: 29 | print(f"Missing epoch: {checkpoint_name}") 30 | continue 31 | 32 | perplexity = float(perplexity_search.group(1)) 33 | epoch = int(epoch_search.group(1)) 34 | valid_metrics[epoch] = perplexity 35 | 36 | best_epoch, _ = sorted(valid_metrics.items(), key=lambda x: x[1])[0] 37 | best_folder = os.path.join(folder_name, f"checkpoint-{best_epoch}") 38 | print(best_folder) 39 | print("." * 50) 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser(description=__doc__) 44 | parser.add_argument( 45 | "--model_checkpoint_folder", 46 | nargs="+", 47 | required=True, 48 | help="List of model checkpoint folders", 49 | ) 50 | 51 | try: 52 | parsed_args = vars(parser.parse_args()) 53 | except (IOError) as msg: 54 | parser.error(str(msg)) 55 | main(parsed_args) 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Ignore the data folder. 132 | data 133 | -------------------------------------------------------------------------------- /models/gpt2_text/gpt2_dst/scripts/preprocess_input.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 4 | Scripts for converting the main SIMMC datasets (.JSON format) 5 | into the line-by-line stringified format (and back). 6 | 7 | The reformatted data is used as input for the GPT-2 based 8 | DST model baseline. 9 | """ 10 | from gpt2_dst.utils.convert import convert_json_to_flattened 11 | import argparse 12 | 13 | if __name__ == "__main__": 14 | # Parse input args 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument( 17 | "--input_path_json", help="input path to the original dialog data" 18 | ) 19 | parser.add_argument("--output_path_predict", help="output path for model input") 20 | parser.add_argument("--output_path_target", help="output path for full target") 21 | parser.add_argument( 22 | "--input_path_special_tokens", 23 | help="input path for special tokens. blank if not provided", 24 | default="", 25 | ) 26 | parser.add_argument( 27 | "--output_path_special_tokens", 28 | help="output path for special tokens. 
blank if not saving", 29 | default="", 30 | ) 31 | parser.add_argument( 32 | "--len_context", 33 | help="# of turns to include as dialog context", 34 | type=int, 35 | default=2, 36 | ) 37 | parser.add_argument( 38 | "--use_multimodal_contexts", 39 | help="determine whether to use the multimodal contexts each turn", 40 | type=int, 41 | default=1, 42 | ) 43 | parser.add_argument( 44 | "--no_belief_states", 45 | dest="use_belief_states", 46 | action="store_false", 47 | default=True, 48 | help="determine whether to use belief state for each turn", 49 | ) 50 | 51 | args = parser.parse_args() 52 | input_path_json = args.input_path_json 53 | output_path_predict = args.output_path_predict 54 | output_path_target = args.output_path_target 55 | input_path_special_tokens = args.input_path_special_tokens 56 | output_path_special_tokens = args.output_path_special_tokens 57 | len_context = args.len_context 58 | use_multimodal_contexts = bool(args.use_multimodal_contexts) 59 | 60 | # DEBUG: 61 | print("Belief states: {}".format(args.use_belief_states)) 62 | 63 | # Convert the data into GPT-2 friendly format 64 | convert_json_to_flattened( 65 | input_path_json, 66 | output_path_predict, 67 | output_path_target, 68 | input_path_special_tokens=input_path_special_tokens, 69 | output_path_special_tokens=output_path_special_tokens, 70 | len_context=len_context, 71 | use_multimodal_contexts=use_multimodal_contexts, 72 | use_belief_states=args.use_belief_states, 73 | ) 74 | -------------------------------------------------------------------------------- /models/gpt2_mm/run_me.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 2 | FEATURES="butd" 3 | LOG_PATH="logs/" 4 | MODE="train" 5 | GPU_ID=0 6 | 7 | # Test flags. 8 | MODEL_EPOCH=6 9 | OUTPUT_RESULT_FILE="$LOG_PATH/model_ep${MODEL_EPOCH}_generate.json" 10 | 11 | # Visual features. 12 | FEATURE_PATH="data/memory_features/butd_10w_features/" 13 | VISUAL_FEATURE_SIZE=2053 14 | VISUAL_FEATURE_WIDTH=10 15 | 16 | case $MODE in 17 | "train") 18 | echo "Training.." 19 | # Train Memory Dialog Model. 20 | CUDA_VISIBLE_DEVICES=$GPU_ID \ 21 | python train.py --log_path $LOG_PATH \ 22 | --train_path "data/gpt2_data/mem_dials_gpt2_train.json" \ 23 | --valid_path "data/gpt2_data/mem_dials_gpt2_val.json" \ 24 | --special_tokens_path "data/gpt2_data/mem_dials_gpt2_special_tokens.json" \ 25 | --train_batch_size 8 \ 26 | --predict_belief_state \ 27 | --n_epochs 20 \ 28 | --feature_path $FEATURE_PATH \ 29 | --visual_feature_size $VISUAL_FEATURE_SIZE \ 30 | --visual_feature_width $VISUAL_FEATURE_WIDTH 31 | ;; 32 | 33 | "generate") 34 | # Generate responses from Memory Dialog Model. 35 | CUDA_VISIBLE_DEVICES=$GPU_ID \ 36 | python generate.py \ 37 | --model_checkpoint $LOG_PATH \ 38 | --model_epoch $MODEL_EPOCH \ 39 | --test_set "data/gpt2_data/mem_dials_gpt2_test.json" \ 40 | --special_tokens_path "data/gpt2_data/mem_dials_gpt2_special_tokens.json" \ 41 | --feature_path $FEATURE_PATH \ 42 | --visual_feature_size $VISUAL_FEATURE_SIZE \ 43 | --visual_feature_width $VISUAL_FEATURE_WIDTH \ 44 | --output $OUTPUT_RESULT_FILE \ 45 | --max_len 100 46 | ;; 47 | 48 | "compile") 49 | # Compile results and create JSON files to run standard evaluation. 
50 | python utils/create_result_jsons.py \ 51 | --memory_test_json "data/mem_dials_test.json" \ 52 | --model_output_json $OUTPUT_RESULT_FILE 53 | ;; 54 | esac 55 | 56 | 57 | FEATURE_PATH="/data/img_feats1.0/visdial_img_feat.lmdb" 58 | # Extracting visual features (BUTD features). 59 | # python utils/extract_memory_features.py \ 60 | # --input_dialog_json data/mem_dials_merged.json \ 61 | # --input_memory_json \ 62 | # data/memory_may21_v1_100graphs.json \ 63 | # data/mscoco_memory_graphs_1k.json \ 64 | # --input_feature_path $FEATURE_PATH \ 65 | # --max_bboxes 10 \ 66 | # --feature_save_path data/memory_features/butd_10w_features/ \ 67 | # --feature_type butd 68 | 69 | 70 | # Preprocessing the dataset. 71 | # python utils/preprocess_memory_dataset.py \ 72 | # --train_json_path "data/mem_dials_train.json" \ 73 | # --unseen_json_path \ 74 | # "data/mem_dials_val.json" \ 75 | # "data/mem_dials_test.json" \ 76 | # --save_folder "data/gpt2_data/" 77 | -------------------------------------------------------------------------------- /models/gpt2_text/gpt2_dst/scripts/evaluate_response.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 2 | #!/usr/bin/env python3 3 | """ 4 | Scripts for evaluating the GPT-2 DST model predictions. 5 | 6 | First, we parse the line-by-line stringified format into responses 7 | and compute BLEU score. 8 | """ 9 | import argparse 10 | import json 11 | from gpt2_dst.utils.convert import parse_flattened_results_from_file 12 | from utils.evaluate_dst import evaluate_from_flat_list 13 | 14 | import nltk 15 | import numpy as np 16 | 17 | 18 | def normalize_sentence(sentence): 19 | """Normalize the sentences and tokenize.""" 20 | return nltk.tokenize.word_tokenize(sentence.lower()) 21 | 22 | 23 | def parse_response_from_file(input_path): 24 | """Parses the response from a flattened file. 25 | 26 | Args: 27 | input_path: Path to read the responses from. 28 | """ 29 | lines = [] 30 | with open(input_path, "r") as file_id: 31 | for ii in file_id.readlines(): 32 | split_line = ii.split("<EOB>", 1) 33 | lines.append( 34 | (split_line[0].strip("\n"), split_line[1].strip("\n").strip("<EOS>")) 35 | ) 36 | return lines 37 | 38 | 39 | if __name__ == "__main__": 40 | # Parse input args 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument( 43 | "--input_path_target", help="path for target, line-separated format (.txt)" 44 | ) 45 | parser.add_argument( 46 | "--input_path_predicted", 47 | help="path for model prediction output, line-separated format (.txt)", 48 | ) 49 | parser.add_argument( 50 | "--output_path_report", help="path for saving evaluation summary (.json)" 51 | ) 52 | 53 | args = parser.parse_args() 54 | input_path_target = args.input_path_target 55 | input_path_predicted = args.input_path_predicted 56 | output_path_report = args.output_path_report 57 | 58 | # Convert the data from the GPT-2 friendly format to JSON 59 | list_target = parse_response_from_file(input_path_target) 60 | list_predicted = parse_response_from_file(input_path_predicted) 61 | 62 | # Compute BLEU scores. 63 | bleu_scores = [] 64 | # Smoothing function. 65 | chencherry = nltk.translate.bleu_score.SmoothingFunction() 66 | 67 | for response, gt_response in zip(list_predicted, list_target): 68 | assert response[0] == gt_response[0], "Input contexts do not match!" 
69 | bleu_score = nltk.translate.bleu_score.sentence_bleu( 70 | [normalize_sentence(gt_response[1])], 71 | normalize_sentence(response[1]), 72 | smoothing_function=chencherry.method7, 73 | ) 74 | bleu_scores.append(bleu_score) 75 | print( 76 | "BLEU score: {} +- {}".format( 77 | np.mean(bleu_scores), np.std(bleu_scores) / np.sqrt(len(bleu_scores)) 78 | ) 79 | ) 80 | -------------------------------------------------------------------------------- /models/gpt2_text/run_preprocess_gpt2.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 2 | #!/bin/bash 3 | if [[ $# -lt 1 ]] 4 | then 5 | PATH_DIR=$(realpath .) 6 | PATH_DATA_DIR=$(realpath ../dialog_simulator/final_data) 7 | else 8 | PATH_DIR=$(realpath "$1") 9 | PATH_DATA_DIR=$(realpath "$2") 10 | fi 11 | 12 | # Train split 13 | python3 -m gpt2_dst.scripts.preprocess_input \ 14 | --input_path_json="${PATH_DATA_DIR}"/mem_dials_train_v2.json \ 15 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/mem_dials_train_predict.txt \ 16 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/mem_dials_train_target.txt \ 17 | --len_context=2 \ 18 | --use_multimodal_contexts=1 \ 19 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/mem_special_tokens.json 20 | 21 | # --use_multimodal_contexts=1 \ 22 | # Dev split 23 | python3 -m gpt2_dst.scripts.preprocess_input \ 24 | --input_path_json="${PATH_DATA_DIR}"/mem_dials_val_v2.json \ 25 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/mem_dials_val_predict.txt \ 26 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/mem_dials_val_target.txt \ 27 | --len_context=2 \ 28 | --use_multimodal_contexts=1 \ 29 | --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/mem_special_tokens.json \ 30 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/mem_special_tokens.json \ 31 | 32 | # Devtest split 33 | python3 -m gpt2_dst.scripts.preprocess_input \ 34 | --input_path_json="${PATH_DATA_DIR}"/mem_dials_test_v2.json \ 35 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/mem_dials_test_predict.txt \ 36 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/mem_dials_test_target.txt \ 37 | --len_context=2 \ 38 | --use_multimodal_contexts=1 \ 39 | --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/mem_special_tokens.json \ 40 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/mem_special_tokens.json \ 41 | 42 | # Test split 43 | # python3 -m gpt2_dst.scripts.preprocess_input \ 44 | # --input_path_json="${PATH_DATA_DIR}"/mem_dials_test.json \ 45 | # --output_path_predict="${PATH_DIR}"/gpt2_dst/data/mem_dials_test_predict.txt \ 46 | # --output_path_target="${PATH_DIR}"/gpt2_dst/data/mem_dials_test_target.txt \ 47 | # --len_context=2 \ 48 | # --use_multimodal_contexts=1 \ 49 | # --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/mem_special_tokens.json \ 50 | # --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/mem_special_tokens.json \ 51 | 52 | # Mini split 53 | # python3 -m gpt2_dst.scripts.preprocess_input \ 54 | # --input_path_json="${PATH_DATA_DIR}"/mem_dials_mini.json \ 55 | # --output_path_predict="${PATH_DIR}"/gpt2_dst/data/mem_dials_mini_predict.txt \ 56 | # --output_path_target="${PATH_DIR}"/gpt2_dst/data/mem_dials_mini_target.txt \ 57 | # --len_context=2 \ 58 | # --use_multimodal_contexts=1 \ 59 | # --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/mem_special_tokens.json \ 60 | # --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/mem_special_tokens.json \ 61 | 
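# Usage sketch (the paths below are illustrative, not from the repo): with no
# arguments, PATH_DIR defaults to the current directory and PATH_DATA_DIR to
# ../dialog_simulator/final_data; to override the defaults, pass both directories:
#   bash run_preprocess_gpt2.sh /path/to/models/gpt2_text /path/to/dialog_simulator/final_data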
-------------------------------------------------------------------------------- /dialog_simulator/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | #!/usr/bin/env python3 5 | from enum import Enum 6 | 7 | 8 | class GoalType(Enum): 9 | UNKNOWN = "unknown" 10 | SEARCH = "search" 11 | REFINE_SEARCH = "refine_search" 12 | GET_RELATED = "get_related" 13 | GET_INFO = "get_info" 14 | GET_AGGREGATED_INFO = "get_aggregated_info" 15 | SHARE = "share" 16 | CHITCHAT = "chitchat" 17 | 18 | 19 | class DialogAct(Enum): 20 | UNKNOWN = "unknown" 21 | 22 | INFORM_GET = "INFORM:GET" 23 | INFORM_REFINE = "INFORM:REFINE" 24 | INFORM_PREFER = "INFORM:PREFER" 25 | INFORM_DISPREFER = "INFORM:DISPREFER" 26 | INFORM_SHARE = "INFORM:SHARE" 27 | INFORM_DISAMBIGUATE = "INFORM:DISAMBIGUATE" 28 | INFORM_CHITCHAT = "INFORM:CHITCHAT" 29 | 30 | REQUEST_GET = "REQUEST:GET" 31 | REQUEST_REFINE = "REQUEST:REFINE" 32 | REQUEST_PREFER = "REQUEST:PREFER" 33 | REQUEST_DISPREFER = "REQUEST:DISPREFER" 34 | REQUEST_SHARE = "REQUEST:SHARE" 35 | REQUEST_DISAMBIGUATE = "REQUEST:DISAMBIGUATE" 36 | 37 | CONFIRM_GET = "CONFIRM:GET" 38 | CONFIRM_REFINE = "CONFIRM:REFINE" 39 | CONFIRM_PREFER = "CONFIRM:PREFER" 40 | CONFIRM_DISPREFER = "CONFIRM:DISPREFER" 41 | CONFIRM_SHARE = "CONFIRM:SHARE" 42 | CONFIRM_DISAMBIGUATE = "CONFIRM:DISAMBIGUATE" 43 | 44 | PROMPT_GET = "PROMPT:GET" 45 | PROMPT_REFINE = "PROMPT:REFINE" 46 | PROMPT_PREFER = "PROMPT:PREFER" 47 | PROMPT_DISPREFER = "PROMPT:DISPREFER" 48 | PROMPT_SHARE = "PROMPT:SHARE" 49 | PROMPT_DISAMBIGUATE = "PROMPT:DISAMBIGUATE" 50 | 51 | ASK_GET = "ASK:GET" 52 | ASK_REFINE = "ASK:REFINE" 53 | ASK_PREFER = "ASK:PREFER" 54 | ASK_DISPREFER = "ASK:DISPREFER" 55 | ASK_SHARE = "ASK:SHARE" 56 | ASK_DISAMBIGUATE = "ASK:DISAMBIGUATE" 57 | 58 | 59 | class GoalMemoryRefType(Enum): 60 | PREV_TURN = "PREV_TURN" 61 | DIALOG = "DIALOG" 62 | GRAPH = "GRAPH" 63 | NOT_SPECIFIED = "Not Specified" 64 | 65 | 66 | class ObjectRefType(Enum): 67 | R1 = "R1" # Unique object in the scene 68 | R2 = "R2" # Object in the dialog history, same view point 69 | R3 = "R3" # Object in the dialog history, previous view point 70 | NOT_SPECIFIED = "Not Specified" 71 | 72 | 73 | class API_STATUS(Enum): 74 | SEARCH_FOUND = "Search Found" 75 | SEARCH_NOT_FOUND = "Search Not Found" 76 | INFO_FOUND = "Info Found" 77 | INFO_NOT_FOUND = "Info Not Found" 78 | SHARED = "Shared" 79 | 80 | 81 | class API_CALL_TYPE(Enum): 82 | SEARCH = "Search" 83 | REFINE_SEARCH = "Refine Search" 84 | GET_INFO = "Get Info" 85 | SHARE = "Share" 86 | GET_RELATED = "Get Related" 87 | UNDEFINED = "Undefined" 88 | 89 | 90 | class TurnSpeaker(Enum): 91 | USER = "User" 92 | ASSISTANT = "Assistant" 93 | 94 | 95 | numeric_slots = {"time"} 96 | 97 | non_visual_slots = { 98 | "location", 99 | "time", 100 | } 101 | 102 | visual_slots = {"participant", "activity"} 103 | 104 | all_slots = {"time", "location", "participant", "activity"} 105 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, 
sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at <opensource-conduct@fb.com>. All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /models/gpt2_mm/README.md: -------------------------------------------------------------------------------- 1 | # GPT-2 (MM) 2 | This is the code for the GPT-2 model used in [Navigating Connected Memories with a Task-oriented Dialog System][code]. It is based on the AAAI2020-DSTC8-AVSD paper [Bridging Text and Video: A Universal Multimodal Transformer for Video-Audio Scene-Aware Dialog](https://arxiv.org/abs/2002.00163). 3 | 4 | 5 | ## How to Run 6 | 7 | **Requirements** 8 | 9 | ``` 10 | Python 3.6 11 | torch==1.0.1 12 | pytorch-ignite==0.2.1 13 | transformers==2.1.1 14 | tqdm==4.36.1 15 | ``` 16 | 17 | ```shell 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | **Data** 22 | 23 | Create a soft link to the `../../data` folder in the main folder here as `data/`. 24 | Please see `run_me.sh` for an example of how to run the code. 25 | 26 | 27 | 28 | 29 | **Step 1: Preprocess the dataset** 30 | 31 | ```shell 32 | # Preprocessing the dataset. 33 | python utils/preprocess_memory_dataset.py \ 34 | --train_json_path "data/mem_dials_train.json" \ 35 | --unseen_json_path \ 36 | "data/mem_dials_val.json" \ 37 | "data/mem_dials_test.json" \ 38 | --save_folder "data/gpt2_data/" 39 | ``` 40 | 41 | **Step 2: Extracting the image features** 42 | 43 | We use this [repository](https://github.com/vmurahari3/visdial-bert#download-preprocessed-data) to download the image features. 44 | 45 | ```shell 46 | FEATURE_PATH="/data/img_feats1.0/visdial_img_feat.lmdb" 47 | # Extracting visual features (BUTD features). 48 | python utils/extract_memory_features.py \ 49 | --input_dialog_json data/mem_dials_merged.json \ 50 | --input_memory_json \ 51 | data/memory_may21_v1_100graphs.json \ 52 | data/mscoco_memory_graphs_1k.json \ 53 | --input_feature_path $FEATURE_PATH \ 54 | --max_bboxes 10 \ 55 | --feature_save_path data/memory_features/butd_10w_features/ \ 56 | --feature_type butd 57 | ``` 58 | 59 | **Training** 60 | 61 | ```shell 62 | FEATURES="butd" 63 | LOG_PATH="logs/" 64 | # Visual features. 
65 | FEATURE_PATH="data/memory_features/butd_10w_features/" 66 | VISUAL_FEATURE_SIZE=2053 67 | VISUAL_FEATURE_WIDTH=10 68 | 69 | python train.py --log_path $LOG_PATH \ 70 | --train_path "data/gpt2_data/mem_dials_gpt2_train.json" \ 71 | --valid_path "data/gpt2_data/mem_dials_gpt2_val.json" \ 72 | --special_tokens_path "data/gpt2_data/mem_dials_gpt2_special_tokens.json" \ 73 | --train_batch_size 8 \ 74 | --predict_belief_state \ 75 | --n_epochs 20 \ 76 | --feature_path $FEATURE_PATH \ 77 | --visual_feature_size $VISUAL_FEATURE_SIZE \ 78 | --visual_feature_width $VISUAL_FEATURE_WIDTH 79 | ``` 80 | 81 | **Evaluation** 82 | 83 | ```shell 84 | python generate.py \ 85 | --model_checkpoint $LOG_PATH \ 86 | --model_epoch $MODEL_EPOCH \ 87 | --test_set "data/gpt2_data/mem_dials_gpt2_test.json" \ 88 | --special_tokens_path "data/gpt2_data/mem_dials_gpt2_special_tokens.json" \ 89 | --feature_path $FEATURE_PATH \ 90 | --visual_feature_size $VISUAL_FEATURE_SIZE \ 91 | --visual_feature_width $VISUAL_FEATURE_WIDTH \ 92 | --output $OUTPUT_RESULT_FILE \ 93 | --max_len 100 94 | ``` 95 | 96 | **Compiling Results** 97 | 98 | ```shell 99 | python utils/create_result_jsons.py \ 100 | --memory_test_json "data/mem_dials_test.json" \ 101 | --model_output_json $OUTPUT_RESULT_FILE 102 | ``` 103 | 104 | 105 | ## Citation 106 | 107 | If you use this code in your research, please cite our paper and the original AAAI 2020 DSTC8 workshop 108 | paper. 109 | 110 | ``` 111 | @inproceedings{moon-kottur-2022-navigating, 112 | title = "Navigating Connected Memories with a Task-oriented Dialog System", 113 | author = "Moon, Seungwhan and 114 | Kottur, Satwik and 115 | Geramifard, Alborz and 116 | Damavandi, Babak", 117 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing", 118 | month = dec, 119 | year = "2022", 120 | address = "Online and Abu Dhabi, United Arab Emirates", 121 | publisher = "Association for Computational Linguistics", 122 | } 123 | ``` 124 | 125 | ``` 126 | @article{li2020bridging, 127 | title={Bridging Text and Video: A Universal Multimodal Transformer for Video-Audio Scene-Aware Dialog}, 128 | author={Zekang Li and Zongjia Li and Jinchao Zhang and Yang Feng and Cheng Niu and Jie Zhou}, 129 | year={2020}, 130 | eprint={2002.00163}, 131 | archivePrefix={arXiv}, 132 | journal={CoRR}, 133 | primaryClass={cs.CL} 134 | } 135 | ``` 136 | 137 | [code]:https://github.com/facebookresearch/comet_memory_dialog -------------------------------------------------------------------------------- /dialog_simulator/memories/mini_set/mini_set_0_memory_graph.json: -------------------------------------------------------------------------------- 1 | { 2 | "memory_graph_id":0, 3 | "memories":[ 4 | { 5 | "memory_id":0, 6 | "time":"2021-04-10 10:00:00", 7 | "start_time":"2021-04-10 10:00:00", 8 | "end_time":"2021-04-10 10:30:00", 9 | "narrations":"fun day for skiing.", 10 | "media":[ 11 | { 12 | "media_id":1000, 13 | "type":"video" 14 | } 15 | ], 16 | "location":{ 17 | "gps":{ 18 | "lat":40.00, 19 | "lon":100.00 20 | }, 21 | "geo_tag":{ 22 | "place":"Summit at Snoqualmie", 23 | "city":"Seattle", 24 | "state":"Washington", 25 | "country":"USA" 26 | } 27 | }, 28 | "participant":[ 29 | { 30 | "name":"John", 31 | "memory_graph_id":1 32 | }, 33 | { 34 | "name":"Mary", 35 | "memory_graph_id":2 36 | } 37 | ], 38 | "activity":[ 39 | { 40 | "activity_name":"skiing" 41 | } 42 | ], 43 | "object":[ 44 | 45 | ] 46 | }, 47 | { 48 | "memory_id":1, 49 | "time":"2021-03-10 10:00:00", 50 | "start_time":"2021-03-10 
10:00:00", 51 | "end_time":"2021-03-10 10:30:00", 52 | "narrations":"fun baseball day.", 53 | "media":[ 54 | { 55 | "media_id":1001, 56 | "type":"video" 57 | } 58 | ], 59 | "location":{ 60 | "gps":{ 61 | "lat":41.00, 62 | "lon":110.00 63 | }, 64 | "geo_tag":{ 65 | "place":"T-Mobile Park", 66 | "city":"Seattle", 67 | "state":"Washington", 68 | "country":"USA" 69 | } 70 | }, 71 | "participant":[ 72 | { 73 | "name":"Mary", 74 | "memory_graph_id":2 75 | }, 76 | { 77 | "name":"Jane", 78 | "memory_graph_id":3 79 | } 80 | ], 81 | "activity":[ 82 | { 83 | "activity_name":"baseball" 84 | } 85 | ], 86 | "object":[ 87 | 88 | ] 89 | }, 90 | { 91 | "memory_id":3, 92 | "time":"2020-03-10 10:00:00", 93 | "start_time":"2020-03-10 10:00:00", 94 | "end_time":"2020-03-10 10:30:00", 95 | "narrations":"fun soccer day.", 96 | "media":[ 97 | { 98 | "media_id":1002, 99 | "type":"video" 100 | } 101 | ], 102 | "location":{ 103 | "gps":{ 104 | "lat":41.00, 105 | "lon":110.00 106 | }, 107 | "geo_tag":{ 108 | "place":"CenturyLink Park", 109 | "city":"Seattle", 110 | "state":"Washington", 111 | "country":"USA" 112 | } 113 | }, 114 | "participant":[ 115 | { 116 | "name":"John", 117 | "memory_graph_id":1 118 | }, 119 | { 120 | "name":"Jane", 121 | "memory_graph_id":3 122 | } 123 | ], 124 | "activity":[ 125 | { 126 | "activity_name":"soccer" 127 | } 128 | ], 129 | "object":[ 130 | 131 | ] 132 | }, 133 | { 134 | "memory_id":4, 135 | "time":"2020-05-10 10:00:00", 136 | "start_time":"2020-05-10 10:00:00", 137 | "end_time":"2020-05-10 10:30:00", 138 | "narrations":"fun skiing day.", 139 | "media":[ 140 | { 141 | "media_id":1002, 142 | "type":"video" 143 | } 144 | ], 145 | "location":{ 146 | "gps":{ 147 | "lat":45.00, 148 | "lon":115.00 149 | }, 150 | "geo_tag":{ 151 | "place":"Stubier", 152 | "city":"Innsbruck", 153 | "state":"", 154 | "country":"Austria" 155 | } 156 | }, 157 | "participant":[ 158 | { 159 | "name":"Jane", 160 | "memory_graph_id":3 161 | } 162 | ], 163 | "activity":[ 164 | { 165 | "activity_name":"skiing" 166 | }, 167 | { 168 | "activity_name":"hiking" 169 | } 170 | ], 171 | "object":[ 172 | 173 | ] 174 | } 175 | ] 176 | } 177 | -------------------------------------------------------------------------------- /dialog_simulator/InteractiveDialogHandler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | #!/usr/bin/env python3 5 | import random 6 | import json 7 | from MemoryDialogModel import PilotMemoryDialogModel 8 | from Data import MemoryGraph, MemoryDialog, Turn 9 | from MemoryServiceAPI import MemoryServiceAPI 10 | 11 | import sys 12 | 13 | sys.path.append("/Users/shanemoon/workspace/memory_dialog/models/") 14 | from gpt2_dst.scripts.run_generation import load_model 15 | 16 | 17 | class InteractiveDialogHandler: 18 | def __init__(self, *args, **kwargs): 19 | self.model = kwargs.pop("model", None) 20 | self.memory_graph = kwargs.pop("memory_graph", None) 21 | self.api = kwargs.pop("api", None) 22 | 23 | # Start an empty dialog data 24 | self.memory_dialog = MemoryDialog(memory_graph=self.memory_graph) 25 | self.memory_dialog.initialize() 26 | 27 | def execute_turn(self, user_query: str) -> Turn: 28 | """ 29 | Given user_query, construct an API call, 30 | get the API response, and return an Assistant Turn. 
31 | """ 32 | 33 | # Construct the API request 34 | try: 35 | user_turn, api_request = self.model.construct_api_request( 36 | user_query, self.memory_dialog 37 | ) 38 | print("============== API Request ==============") 39 | print(api_request) 40 | print("=========================================\n") 41 | 42 | # Call API to get responses back 43 | api_response = self.api.call_api(api_request) 44 | print("============== API Response ==============") 45 | print(api_response) 46 | print("==========================================\n") 47 | 48 | # Update the display based on the API results 49 | self.model.update_display(api_response) 50 | 51 | # Generate an Assistant response based on the API response 52 | assistant_turn = self.model.construct_assistant_response( 53 | user_query, api_request, api_response, self.memory_dialog 54 | ) 55 | print("============== Assistant Response ==============") 56 | print(assistant_turn) 57 | print("================================================\n") 58 | 59 | # Update the memory_dialog with the new user and assistant turns 60 | self.memory_dialog.dialog.add_user_turn(user_turn) 61 | self.memory_dialog.dialog.add_asst_turn(assistant_turn) 62 | 63 | # Update the model 64 | self.model.prev_asst_uttr = assistant_turn.frames[-1].uttr 65 | self.model.turn_id += 1 66 | 67 | return assistant_turn 68 | 69 | except Exception: 70 | return None 71 | 72 | def run_loop_command_prompt(self): 73 | 74 | while True: 75 | print() 76 | user_query = input(">> Enter your query (or type quit): ") 77 | if user_query == "quit": 78 | break 79 | 80 | response = self.execute_turn(user_query=user_query) 81 | 82 | 83 | if __name__ == "__main__": 84 | # Define paths 85 | # path_memory_graph_list = '/Users/shanemoon/workspace/memory_dialog/dialog_simulator/memories/final/mscoco_memory_graphs_1k.json' 86 | path_memory_graph_list = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/memories/final/mscoco_memory_graphs_mini.json" 87 | path_model = ( 88 | "/Users/shanemoon/workspace/memory_dialog/models/gpt2_dst/save/model_v2" 89 | ) 90 | path_parameter_ontology = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/final_data/all_parameter_ontology.json" 91 | 92 | # Hyperparameters for the demo 93 | random_memory_graph = False 94 | 95 | # Load parameters 96 | memory_graph_list = json.load(open(path_memory_graph_list, "r")) 97 | memory_graph_bank = {} 98 | 99 | for memory_graph in memory_graph_list: 100 | memory_graph_id = memory_graph["memory_graph_id"] 101 | 102 | for i in range(len(memory_graph["memories"])): 103 | memory_graph["memories"][i]["memory_graph_id"] = memory_graph_id 104 | 105 | memory_graph_bank[memory_graph_id] = memory_graph 106 | 107 | parameter_ontology = json.load(open(path_parameter_ontology, "r")) 108 | 109 | # Select a Memory Graph 110 | if random_memory_graph: 111 | memory_graph = MemoryGraph( 112 | data=memory_graph_bank[random.choice(list(memory_graph_bank.keys()))] 113 | ) 114 | 115 | else: 116 | memory_graph_id = "RbXAfFDz8r72" 117 | memory_graph = MemoryGraph(data=memory_graph_bank[memory_graph_id]) 118 | 119 | # Load the model parameters 120 | gpt2_model, tokenizer, length = load_model( 121 | model_type="gpt2", model_name_or_path=path_model, device="cpu", length=150 122 | ) 123 | 124 | # Instantiate the dialog handler 125 | model = PilotMemoryDialogModel( 126 | model=gpt2_model, 127 | tokenizer=tokenizer, 128 | length=length, 129 | parameter_ontology=parameter_ontology, 130 | ) 131 | 132 | api = MemoryServiceAPI() 133 | dialog_handler = 
InteractiveDialogHandler( 134 | model=model, memory_graph=memory_graph, api=api 135 | ) 136 | 137 | # Run loop 138 | dialog_handler.run_loop_command_prompt() 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Navigating Connected Memories with a Task-oriented Dialog System 2 | 3 | This repository contains the code to reproduce results from the following paper: 4 | 5 | **Navigating Connected Memories with a Task-oriented Dialog System** 6 | Seungwhan Moon\*, Satwik Kottur\*, Alborz Geramifard, Babak Damavandi 7 | [[PDF][paper_pdf]][[Github][github]] 8 | *Empirical Methods in Natural Language Processing (EMNLP), 2022* 9 | \*=equal contribution 10 | 11 | ### Abstract 12 | 13 | 
14 | ![Teaser Figure](teaser_memory_dialog.png) 19 | 
20 | 21 | 22 | Recent years have seen an increasing trend in the volume of personal media captured by users, thanks 23 | to the advent of smartphones and smart glasses, resulting in large media collections. 24 | Despite conversation being an intuitive human-computer interface, current efforts focus mostly 25 | on single-shot natural language based media retrieval to help users query their media and 26 | re-live their memories. This severely limits the search functionality as users can neither ask 27 | follow-up queries nor obtain information without first formulating a single-turn query. 28 | 29 | In this work, we propose *dialogs for connected memories* as a powerful tool to empower 30 | users to search their media collection through a multi-turn, interactive conversation. 31 | Towards this, we collect a new task-oriented dialog dataset COMET, which contains $11.5k$ 32 | user↔assistant dialogs (totalling $103k$ utterances), grounded in simulated personal memory graphs. 33 | We employ a resource-efficient, two-phase data collection pipeline that uses: 34 | (1) a novel multimodal dialog simulator that generates synthetic dialog flows grounded in 35 | memory graphs, and, 36 | (2) manual paraphrasing to obtain natural language utterances. 37 | We analyze COMET, formulate four main tasks to benchmark meaningful progress, and adopt 38 | state-of-the-art language models as strong baselines, in order to highlight the 39 | multimodal challenges captured by our dataset. 40 | Our code & data will be made publicly available. 41 | 42 | 43 | 44 | ### Code Structure 45 | 46 | The code is organized into two folders: 47 | 48 | **A. Multimodal Dialog Simulator** (`dialog_simulator/`): 49 | Conditioned on the memory graphs generated, the multimodal dialog simulator produces synthetic 50 | dialog flows between a user and an assistant. 51 | These flows are later paraphrased using human annotators to obtain natural language utterances. 52 | 53 | * `AssistantSimulator.py` 54 | * `Data.py` 55 | * `DummyMemoryDialogModel.py` 56 | * `GoalGenerator.py` 57 | * `InteractiveDialogHandler.py` 58 | * `MemoryDialogModel.py` 59 | * `MemoryDialogSimulator.py` 60 | * `MemoryServiceAPI.py` 61 | * `SimulatorBase.py` 62 | * `UserSimulator.py` 63 | * `constants.py` 64 | * `get_user_utterances.py` 65 | * `main.py` 66 | * `merge_data_json.py` 67 | * `merge_synth_and_appen.py` 68 | * `utils.py` 69 | 70 | 71 | **B. Memory-grounded Dialog Models** (`models/`): 72 | 73 | There are two types of models used in this work: 74 | 75 | 1. Text-only GPT-2 model: Memories are represented using their ids. 76 | * `run_preprocess_gpt2.sh`: Preprocessing the memory dialog dataset to make it ingestible for GPT-2 model training 77 | * `run_train_gpt2.sh`: Trains GPT-2 model (text-only) 78 | * `gpt2_dst/`: Folder with GPT-2 model 79 | * `run_evaluate_gpt2.sh`: Contains commands to evaluate a trained GPT-2 model on memory dialogs 80 | * `run_evaluate.sh`: Contains commands to evaluate output prediction JSON of any model on memory dialogs 81 | * `utils/`: Additional utility functions to train and evaluate the GPT-2 model 82 | 83 | 2. Multimodal GPT-2 model: Memories are represented using their image features. 84 | * `run_me.sh`: Contains commands to train, evaluate, and compile the results for GPT-2 (mm). 85 | * `utils/`: Additional utility functions to train and evaluate the GPT-2 model (mm). 
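As a quick orientation, the sketch below loads one of the released splits; the top-level keys (`dialogue_data`, `split`, `year`, `domain`) follow the JSON written by `dialog_simulator/merge_data_json.py`. Note that the files under `data/` are Git LFS pointers, so fetch the actual content first (e.g., `git lfs pull`).

```python
import json

# Minimal sketch: load a released split and count dialogs.
# Top-level keys follow dialog_simulator/merge_data_json.py.
with open("data/mem_dials_train.json") as f:
    dataset = json.load(f)

print(dataset["split"], dataset["domain"], dataset["year"])
print("Number of dialogs:", len(dataset["dialogue_data"]))
```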
86 | 87 | 88 | Please reach out to [Satwik Kottur][satwik_link] (skottur@fb.com) 89 | or [Seungwhan Moon][shane_link] (shanemoon@fb.com) for questions related to this repository. 90 | 91 | 92 | If you find this repository useful, please cite our work: 93 | 94 | ``` 95 | @inproceedings{moon-kottur-2022-navigating, 96 | title = "Navigating Connected Memories with a Task-oriented Dialog System", 97 | author = "Moon, Seungwhan and 98 | Kottur, Satwik and 99 | Geramifard, Alborz and 100 | Damavandi, Babak", 101 | booktitle = "Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing", 102 | month = dec, 103 | year = "2022", 104 | address = "Online and Abu Dhabi, United Arab Emirates", 105 | publisher = "Association for Computational Linguistics", 106 | } 107 | ``` 108 | 109 | 110 | ### LICENSE 111 | *The majority of comet\_memory\_dialog is licensed under CC-BY-NC; however, 112 | portions of the project are available under separate 113 | license terms: https://github.com/ictnlp/DSTC8-AVSD is licensed 114 | under the MIT license.* 115 | 116 | [paper_pdf]: 117 | [github]:https://github.com/facebookresearch/comet_memory_dialog 118 | [curated_lists]: https://drive.google.com/drive/folders/1V4RqUR0oSr2wwI4-ukx_V3NlP9IUHKoT?usp=sharing 119 | [satwik_link]: https://satwikkottur.github.io/ 120 | [shane_link]: https://shanemoon.com/
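For reference, here is a minimal sketch for inspecting a simulated memory graph; it assumes the schema illustrated by `dialog_simulator/memories/mini_set/mini_set_0_memory_graph.json` (the larger released graph files are assumed to follow the same layout).

```python
import json

# Minimal sketch, assuming the mini_set memory graph schema.
with open("dialog_simulator/memories/mini_set/mini_set_0_memory_graph.json") as f:
    graph = json.load(f)

print("Memory graph:", graph["memory_graph_id"])
for memory in graph["memories"]:
    place = memory["location"]["geo_tag"]["place"]
    people = ", ".join(p["name"] for p in memory["participant"])
    activities = ", ".join(a["activity_name"] for a in memory["activity"])
    print(f"- Memory {memory['memory_id']}: {activities} at {place} with {people}")
```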
8 | """ 9 | import os 10 | import json 11 | import csv 12 | import random 13 | import pickle 14 | import numpy as np 15 | from utils import load_data_pickle 16 | 17 | 18 | if __name__ == "__main__": 19 | random.seed(0) 20 | np.random.seed(0) 21 | 22 | # Paths for merge 23 | paths_to_merge = [ 24 | #'/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/pilot_1_mem_dials_merged.p', 25 | "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/pilot_2_mem_dials_merged.p", 26 | "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_1_mem_dials_merged.p", 27 | "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_2_mem_dials_merged.p", 28 | "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_3_mem_dials_merged.p", 29 | "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_4_mem_dials_merged.p", 30 | ] 31 | 32 | path_out_json = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/final_data/mem_dials_merged.json" 33 | path_out_pickle = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/final_data/mem_dials_merged.p" 34 | 35 | mm_dialogs = [] 36 | 37 | for path_in_pickle in paths_to_merge: 38 | 39 | # Load original synth 40 | mm_dialogs.extend(load_data_pickle(path_in_pickle)) 41 | 42 | # Output 43 | print("Total: %d dialogs" % len(mm_dialogs)) 44 | 45 | json.dump( 46 | { 47 | "dialogue_data": [mm_d.to_dict() for mm_d in mm_dialogs], 48 | "split": "all", 49 | "year": 2021, 50 | "domain": "memory", 51 | }, 52 | open(path_out_json, "w"), 53 | indent=4, 54 | ) 55 | 56 | pickle.dump(mm_dialogs, open(path_out_pickle, "wb")) 57 | 58 | # Split 59 | r_train = 0.85 60 | r_dev = 0.10 61 | r_devtest = 0.04 62 | r_test = 0.01 63 | r_mini = 0.001 64 | 65 | path_out_train_json = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/final_data/mem_dials_train.json" 66 | path_out_dev_json = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/final_data/mem_dials_dev.json" 67 | path_out_devtest_json = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/final_data/mem_dials_devtest.json" 68 | path_out_test_json = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/final_data/mem_dials_test.json" 69 | path_out_mini_json = "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/final_data/mem_dials_mini.json" 70 | 71 | n_dialogs = len(mm_dialogs) 72 | indices = np.arange(n_dialogs) 73 | np.random.shuffle(indices) 74 | n_train = int(n_dialogs * r_train) 75 | n_dev = int(n_dialogs * r_dev) 76 | n_devtest = int(n_dialogs * r_devtest) 77 | n_test = int(n_dialogs * r_test) 78 | n_mini = int(n_dialogs * r_mini) 79 | 80 | train_indices = indices[:n_train] 81 | dev_indices = indices[n_train : n_train + n_dev] 82 | devtest_indices = indices[n_train + n_dev : n_train + n_dev + n_devtest] 83 | test_indices = indices[n_train + n_dev + n_devtest :] 84 | mini_indices = test_indices[:n_mini] 85 | 86 | mm_dialogs_train = [mm_d for i, mm_d in enumerate(mm_dialogs) if i in train_indices] 87 | mm_dialogs_dev = [mm_d for i, mm_d in enumerate(mm_dialogs) if i in dev_indices] 88 | mm_dialogs_devtest = [ 89 | mm_d for i, mm_d in enumerate(mm_dialogs) if i in devtest_indices 90 | ] 91 | mm_dialogs_test = [mm_d for i, mm_d in enumerate(mm_dialogs) if i in test_indices] 92 | mm_dialogs_mini = [mm_d for i, mm_d in enumerate(mm_dialogs) if i in mini_indices] 93 | 94 | json.dump( 95 | { 96 | "dialogue_data": [mm_d.to_dict() for mm_d in mm_dialogs_train], 97 | "split": "train", 98 | "year": 2021, 99 | 
"domain": "memory", 100 | }, 101 | open(path_out_train_json, "w"), 102 | indent=4, 103 | ) 104 | 105 | json.dump( 106 | { 107 | "dialogue_data": [mm_d.to_dict() for mm_d in mm_dialogs_dev], 108 | "split": "dev", 109 | "year": 2021, 110 | "domain": "memory", 111 | }, 112 | open(path_out_dev_json, "w"), 113 | indent=4, 114 | ) 115 | 116 | json.dump( 117 | { 118 | "dialogue_data": [mm_d.to_dict() for mm_d in mm_dialogs_devtest], 119 | "split": "devtest", 120 | "year": 2021, 121 | "domain": "memory", 122 | }, 123 | open(path_out_devtest_json, "w"), 124 | indent=4, 125 | ) 126 | 127 | json.dump( 128 | { 129 | "dialogue_data": [mm_d.to_dict() for mm_d in mm_dialogs_test], 130 | "split": "test", 131 | "year": 2021, 132 | "domain": "memory", 133 | }, 134 | open(path_out_test_json, "w"), 135 | indent=4, 136 | ) 137 | 138 | json.dump( 139 | { 140 | "dialogue_data": [mm_d.to_dict() for mm_d in mm_dialogs_mini], 141 | "split": "mini", 142 | "year": 2021, 143 | "domain": "memory", 144 | }, 145 | open(path_out_mini_json, "w"), 146 | indent=4, 147 | ) 148 | -------------------------------------------------------------------------------- /models/gpt2_mm/utils/create_result_jsons.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | """ 3 | Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 4 | 5 | Create API and MM-DST result JSONS from model result file. 6 | 7 | Author(s): Satwik Kottur 8 | """ 9 | 10 | from __future__ import absolute_import, division, print_function, unicode_literals 11 | 12 | import argparse 13 | import collections 14 | import copy 15 | import json 16 | import ast 17 | import re 18 | 19 | 20 | def parse_flattened_result(to_parse): 21 | """ 22 | Parse out the belief state from the raw text. 23 | Return an empty list if the belief state can't be parsed 24 | 25 | Input: 26 | - A single of flattened result 27 | e.g. 'User: Show me something else => Belief State : DA:REQUEST ...' 28 | 29 | Output: 30 | - Parsed result in a JSON format, where the format is: 31 | [ 32 | { 33 | 'act': # e.g. 'DA:REQUEST', 34 | 'slots': [ 35 | slot_name, 36 | slot_value 37 | ] 38 | }, ... # End of a frame 39 | ] # End of a dialog 40 | """ 41 | dialog_act_regex = re.compile(r"([\w:?.?]*) *\[(.*)\] *\(([^\]]*)\) *\<([^\]]*)\>") 42 | slot_regex = re.compile(r"([A-Za-z0-9_.-:]*) *= *(\[([^\]]*)\]|[^,]*)") 43 | request_regex = re.compile(r"([A-Za-z0-9_.-:]+)") 44 | object_regex = re.compile(r"([A-Za-z0-9]+)") 45 | 46 | belief = [] 47 | 48 | # Parse 49 | to_parse = to_parse.strip() 50 | # to_parse: 'DIALOG_ACT_1 : [ SLOT_NAME = SLOT_VALUE, ... ] ...' 51 | for dialog_act in dialog_act_regex.finditer(to_parse): 52 | d = { 53 | "act": dialog_act.group(1), 54 | "slots": {}, 55 | "request_slots": [], 56 | "memories": [], 57 | } 58 | for slot in slot_regex.finditer(dialog_act.group(2)): 59 | # If parsing python list eval it else keep unique string. 60 | slot_name = slot.group(1).strip() 61 | slot_values = slot.group(2).strip() 62 | # If there are nones, replace them with Nones and later remove them. 
63 | if re.match('\[.*\]', slot_values): 64 | try: 65 | slot_values = slot_values.replace("none", "None") 66 | parsed_slot_values = ast.literal_eval(slot_values) 67 | d["slots"][slot_name] = [ii for ii in parsed_slot_values if ii] 68 | except: 69 | # If error when parsing the slots add empty string 70 | print(f"Error parsing: {to_parse}") 71 | d["slots"][slot_name] = "" 72 | else: 73 | d["slots"][slot_name] = slot_values 74 | 75 | for request_slot in request_regex.finditer(dialog_act.group(3)): 76 | d["request_slots"].append(request_slot.group(1).strip()) 77 | for object_id in object_regex.finditer(dialog_act.group(4)): 78 | d["memories"].append(object_id.group(1).strip()) 79 | if d != {}: 80 | belief.append(d) 81 | return belief 82 | 83 | 84 | def create_result_jsons(results, test_data): 85 | """Creates two JSON files from results. 86 | 87 | Args: 88 | results: List of generated results from the model. 89 | test_data: Raw JSON test file. 90 | 91 | Returns: 92 | response_results: Dict containing response results 93 | dst_results: Dict containing DST results 94 | """ 95 | dst_results = copy.deepcopy(test_data) 96 | response_results = collections.defaultdict(list) 97 | dst_pool = {} 98 | for instance in results: 99 | dialog_id = instance["dialog_id"] 100 | turn_id = instance["turn_id"] 101 | if instance["type"] == "API": 102 | index = (dialog_id, turn_id) 103 | dst_pool[index] = instance 104 | else: 105 | if dialog_id not in response_results: 106 | response_results[dialog_id] = { 107 | "dialog_id": dialog_id, 108 | "predictions": [], 109 | } 110 | response_results[dialog_id]["predictions"].append( 111 | { 112 | "turn_id": turn_id, 113 | "response": instance["model_prediction"], 114 | } 115 | ) 116 | num_missing = 0 117 | num_present = 0 118 | 119 | for dialog_datum in dst_results["dialogue_data"]: 120 | del dialog_datum["mentioned_memory_ids"] 121 | del dialog_datum["memory_graph_id"] 122 | dialog_id = dialog_datum["dialogue_idx"] 123 | for datum in dialog_datum["dialogue"]: 124 | turn_id = datum["turn_idx"] 125 | index = (dialog_id, turn_id) 126 | if index in dst_pool: 127 | model_pred_datum = dst_pool[index] 128 | model_pred = model_pred_datum["model_prediction"].strip(" ") 129 | parsed_result = parse_flattened_result(model_pred) 130 | datum["transcript_annotated"] = parsed_result 131 | num_present += 1 132 | else: 133 | del datum["transcript_annotated"] 134 | print(f"Missing! -- {index}") 135 | num_missing += 1 136 | print(f"Missing: {num_missing} Present: {num_present}") 137 | return list(response_results.values()), dst_results 138 | 139 | 140 | def main(args): 141 | with open(args["memory_test_json"], "r") as file_id: 142 | test_data = json.load(file_id) 143 | with open(args["model_output_json"], "r") as file_id: 144 | results = json.load(file_id) 145 | response_results, dst_results = create_result_jsons(results, test_data) 146 | 147 | # Save the results. 
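    # Output paths are derived from the model output file, e.g., a
    # (hypothetical) "preds.json" yields "preds_response_results.json"
    # and "preds_dst_results.json".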
148 | response_results_path = args["model_output_json"].replace( 149 | ".json", "_response_results.json" 150 | ) 151 | with open(response_results_path, "w") as file_id: 152 | json.dump(response_results, file_id) 153 | dst_results_path = args["model_output_json"].replace(".json", "_dst_results.json") 154 | with open(dst_results_path, "w") as file_id: 155 | json.dump(dst_results, file_id) 156 | 157 | 158 | if __name__ == "__main__": 159 | parser = argparse.ArgumentParser(description=__doc__) 160 | parser.add_argument( 161 | "--memory_test_json", 162 | required=True, 163 | help="JSON file for test data", 164 | ) 165 | parser.add_argument( 166 | "--model_output_json", required=True, help="JSON file with model outputs" 167 | ) 168 | 169 | try: 170 | parsed_args = vars(parser.parse_args()) 171 | except (IOError) as msg: 172 | parser.error(str(msg)) 173 | main(parsed_args) 174 | -------------------------------------------------------------------------------- /models/gpt2_text/utils/response_evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 3 | 4 | Script evaluates response generation using GT responses. 5 | 6 | Expected JSON format: 7 | 8 | [ 9 | "dialog_id": , 10 | "predictions": [ 11 | { 12 | "turn_id": , 13 | "response": , 14 | } 15 | ... 16 | ] 17 | ... 18 | ] 19 | 20 | Author(s): Satwik Kottur 21 | """ 22 | 23 | from __future__ import absolute_import, division, print_function, unicode_literals 24 | 25 | import argparse 26 | import json 27 | 28 | import nltk 29 | import numpy as np 30 | import tqdm 31 | 32 | 33 | def normalize_sentence(sentence): 34 | """Normalize the sentences and tokenize.""" 35 | return nltk.tokenize.word_tokenize(sentence.lower()) 36 | 37 | 38 | def evaluate_response_generation( 39 | gt_responses, 40 | model_responses, 41 | single_round_eval=False, 42 | record_instance_results=None, 43 | compute_bert_score=False, 44 | ): 45 | """Evaluates response generation using the raw data and model predictions. 46 | 47 | Args: 48 | gt_responses: Ground truth responses. 49 | model_responses: Generated responses. 50 | single_round_eval: Evaluate only for the last turn. 51 | record_instance_results: Save path for instance level metrics. 52 | """ 53 | gt_responses_pool = {ii["dialogue_idx"]: ii for ii in gt_responses["dialogue_data"]} 54 | bleu_scores = [] 55 | # Smoothing function. 56 | chencherry = nltk.translate.bleu_score.SmoothingFunction() 57 | 58 | # Lazy initialization for bert score. 59 | if compute_bert_score: 60 | import bert_score 61 | 62 | bert_scorer = bert_score.BERTScorer(lang="en") 63 | bert_scores = [] 64 | 65 | num_evaluations = 0 66 | for model_datum in tqdm.tqdm(model_responses, desc="Evaluating"): 67 | dialog_id = model_datum["dialog_id"] 68 | num_gt_rounds = len(gt_responses_pool[dialog_id]["dialogue"]) 69 | for round_datum in model_datum["predictions"]: 70 | round_id = round_datum["turn_id"] 71 | # Skip if single_round_eval and this is not the last round. 
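            # (Turn ids are zero-based, so the final round is the one with
            # round_id == num_gt_rounds - 1.)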
72 | if single_round_eval and round_id != num_gt_rounds - 1: 73 | continue 74 | 75 | response = round_datum["response"] 76 | gt_datum = gt_responses_pool[dialog_id]["dialogue"][round_id] 77 | gt_response = gt_datum["system_transcript"] 78 | try: 79 | gt_response_clean = normalize_sentence(gt_response) 80 | response_clean = normalize_sentence(response) 81 | bleu_score = nltk.translate.bleu_score.sentence_bleu( 82 | [gt_response_clean], 83 | response_clean, 84 | smoothing_function=chencherry.method7, 85 | ) 86 | bleu_scores.append(bleu_score) 87 | 88 | if compute_bert_score: 89 | _, _, bert_f1 = bert_scorer.score( 90 | [" ".join(response_clean)], [" ".join(gt_response_clean)] 91 | ) 92 | bert_scores.append(bert_f1.item()) 93 | except: 94 | print(f"Model: {response} -> GT: {gt_response}") 95 | 96 | # Add the result to datum and save it back. 97 | if record_instance_results: 98 | round_datum["bleu"] = bleu_score 99 | round_datum["response_len"] = len(normalize_sentence(gt_response)) 100 | if compute_bert_score: 101 | round_datum["bert_score"] = bert_f1 102 | 103 | print("#Instances evaluated BLEU: {}".format(len(bleu_scores))) 104 | if record_instance_results: 105 | print(f"Saving per instance results: {record_instance_results}") 106 | with open(record_instance_results, "w") as file_id: 107 | json.dump(model_responses, file_id) 108 | 109 | bleu_str_mean = np.mean(bleu_scores) 110 | bleu_str_err = np.std(bleu_scores) / np.sqrt(len(bleu_scores)) 111 | if compute_bert_score: 112 | bert_score_mean = np.mean(bert_scores) 113 | bert_score_err = np.std(bert_scores) / np.sqrt(len(bert_scores)) 114 | else: 115 | bert_score_mean, bert_score_err = None, None 116 | return bleu_str_mean, bleu_str_err, bert_score_mean, bert_score_err 117 | 118 | 119 | def main(args): 120 | print("Reading: {}".format(args["data_json_path"])) 121 | with open(args["data_json_path"], "r") as file_id: 122 | gt_responses = json.load(file_id) 123 | print("Reading: {}".format(args["model_response_path"])) 124 | with open(args["model_response_path"], "r") as file_id: 125 | model_responses = json.load(file_id) 126 | 127 | if args["record_instance_results"]: 128 | instance_results_path = args["model_response_path"].replace( 129 | ".json", "_results.json" 130 | ) 131 | else: 132 | instance_results_path = None 133 | 134 | bleu_score, bleu_std_err, bert_score, bert_score_err = evaluate_response_generation( 135 | gt_responses, 136 | model_responses, 137 | args["single_round_evaluation"], 138 | instance_results_path, 139 | args["compute_bert_score"], 140 | ) 141 | print(f"BLEU Score: {bleu_score:.4f} +- {bleu_std_err}") 142 | if args["compute_bert_score"]: 143 | print(f"BERT Score: {bert_score:.4f} +- {bert_score_err}") 144 | report = { 145 | "bleu_score": bleu_score, 146 | "bleu_std_err": bleu_std_err, 147 | "bert_score": bert_score, 148 | "bert_score_err": bert_score_err, 149 | } 150 | return report 151 | 152 | 153 | if __name__ == "__main__": 154 | parser = argparse.ArgumentParser(description="Response Generation Evaluation") 155 | parser.add_argument( 156 | "--data_json_path", 157 | default="data/mem_dials_devtest.json", 158 | help="Data with gold responses", 159 | ) 160 | parser.add_argument( 161 | "--model_response_path", default=None, help="Responses generated by the model" 162 | ) 163 | parser.add_argument( 164 | "--single_round_evaluation", 165 | dest="single_round_evaluation", 166 | action="store_true", 167 | default=False, 168 | help="Single round evaluation for hidden split", 169 | ) 170 | parser.add_argument( 171 | 
"--record_instance_results", 172 | dest="record_instance_results", 173 | action="store_true", 174 | default=False, 175 | help="Records per instance results and save it back", 176 | ) 177 | parser.add_argument( 178 | "--compute_bert_score", 179 | dest="compute_bert_score", 180 | action="store_true", 181 | default=False, 182 | help="Compute BERT score along with BLEU-4", 183 | ) 184 | try: 185 | parsed_args = vars(parser.parse_args()) 186 | except (IOError) as msg: 187 | parser.error(str(msg)) 188 | main(parsed_args) 189 | -------------------------------------------------------------------------------- /models/gpt2_text/gpt2_dst/scripts/reformat_dst_response_outputs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved 4 | 5 | Scripts for evaluating the GPT-2 DST model predictions. 6 | 7 | First, we parse the line-by-line stringified format into responses 8 | and compute BLEU score. 9 | """ 10 | import argparse 11 | import ast 12 | import copy 13 | import json 14 | import re 15 | 16 | import numpy as np 17 | import tqdm 18 | from gpt2_dst.utils.convert import parse_flattened_result 19 | 20 | 21 | def convert_slots_to_dict(api_call_json): 22 | """Converts the slots from list of lists to a dict. 23 | 24 | Args: 25 | api_call_json: JSON containing the parsed API call 26 | """ 27 | for frame_ind, frame in enumerate(api_call_json): 28 | slot_dict = {} 29 | for slot_name, slot_value in frame["slots"]: 30 | if re.match("\[.*\]", slot_value): 31 | try: 32 | slot_dict[slot_name] = ast.literal_eval(slot_value) 33 | except: 34 | # If error when parsing the slots add empty string 35 | print(f"Error parsing: {slot_value} -> {frame}") 36 | slot_dict[slot_name] = "" 37 | else: 38 | slot_dict[slot_name] = slot_value 39 | frame["slots"] = slot_dict 40 | return api_call_json 41 | 42 | 43 | def parse_results_from_file(input_path, turn_info, original_data): 44 | """Parse targets from a flattened file to create response, dst evaluation files. 45 | 46 | Args: 47 | input_path: Path to read the responses from. 48 | turn_info: List of dialog, turn info. 49 | original_data: Original JSON target. 50 | 51 | Returns: 52 | dst_json: JSON file with DST results 53 | responses_json: JSON file with responses 54 | """ 55 | # Collate all lines to ensure they start with either or . 56 | with open(input_path, "r") as file_id: 57 | lines = [ii.strip() for ii in file_id.readlines()] 58 | 59 | fixed_lines = [] 60 | current_line = "" 61 | for line in lines: 62 | if line[:6] == "" or line[:8] == "": 63 | fixed_lines.append(line) 64 | else: 65 | fixed_lines[-1] += line 66 | print(f"Collating: {len(lines)} -> {len(fixed_lines)}") 67 | lines = fixed_lines 68 | 69 | # Identify API call string and response in each line. 70 | assert len(lines) == len(turn_info), "#lines and #turn_info do not match!" 71 | responses_json = {} 72 | dst_pool = {} 73 | for line_ind, line in enumerate(lines): 74 | dialog_id, turn_id, prediction_type = turn_info[line_ind] 75 | if prediction_type == "api_call": 76 | api_call_json = parse_flattened_result(line.split("")[0] + "") 77 | # Convert slots from list of list to dicts. 78 | api_call_json = convert_slots_to_dict(api_call_json) 79 | dst_index = (dialog_id, turn_id) 80 | assert dst_index not in dst_pool, "Result already exists!" 81 | dst_pool[dst_index] = api_call_json 82 | # Check if memories are integers, else skip. 
83 | for frame_info in api_call_json: 84 | memories = [] 85 | for ii in frame_info["memories"]: 86 | try: 87 | ii_int = int(ii) 88 | memories.append(ii) 89 | except: 90 | pass 91 | frame_info["memories"] = memories 92 | 93 | elif prediction_type == "response": 94 | response_str = line.split("")[-1].strip() 95 | if dialog_id not in responses_json: 96 | responses_json[dialog_id] = { 97 | "dialog_id": dialog_id, 98 | "predictions": [], 99 | } 100 | responses_json[dialog_id]["predictions"].append( 101 | { 102 | "turn_id": turn_id, 103 | "response": response_str, 104 | } 105 | ) 106 | 107 | else: 108 | raise ValueError(f"Invalid prediction_type: {prediction_type}!") 109 | responses_json = list(responses_json.values()) 110 | 111 | num_missing = 0 112 | num_present = 0 113 | dst_json = copy.deepcopy(original_data) 114 | for dialog_datum in dst_json["dialogue_data"]: 115 | del dialog_datum["mentioned_memory_ids"] 116 | del dialog_datum["memory_graph_id"] 117 | dialog_id = dialog_datum["dialogue_idx"] 118 | for datum in dialog_datum["dialogue"]: 119 | del datum["transcript_annotated"] 120 | turn_id = datum["turn_idx"] 121 | index = (dialog_id, turn_id) 122 | if index in dst_pool: 123 | datum["transcript_annotated"] = dst_pool[index] 124 | num_present += 1 125 | else: 126 | print(f"Missing! -- {index}") 127 | num_missing += 1 128 | print(f"Missing: {num_missing} Present: {num_present}") 129 | return dst_json, responses_json 130 | 131 | 132 | if __name__ == "__main__": 133 | # Parse input args 134 | parser = argparse.ArgumentParser() 135 | parser.add_argument( 136 | "--input_target_json", required=True, help="Path to target JSON file" 137 | ) 138 | parser.add_argument( 139 | "--input_dialog_ids", 140 | required=True, 141 | help="Path for dialog, turn ids for input (.txt)", 142 | ) 143 | parser.add_argument( 144 | "--input_path_predicted", 145 | required=True, 146 | help="path for model prediction output, line-separated format (.txt)", 147 | ) 148 | parser.add_argument( 149 | "--output_path_report", 150 | required=True, 151 | help="Path to save evaluation summary (dst and response) (.json)", 152 | ) 153 | args = parser.parse_args() 154 | 155 | input_path_predicted = args.input_path_predicted 156 | output_path_report = args.output_path_report 157 | # Read the input target JSON file. 158 | with open(args.input_target_json, "r") as file_id: 159 | original_data = json.load(file_id) 160 | 161 | # Read the dialog and turn ids. 162 | with open(args.input_dialog_ids, "r") as file_id: 163 | turn_info = [ast.literal_eval(ii.strip("\n")) for ii in file_id.readlines()] 164 | # Convert the data from the GPT-2 friendly format to JSON formats. 165 | dst_json, responses_json = parse_results_from_file( 166 | input_path_predicted, turn_info, original_data 167 | ) 168 | 169 | # Saving both the DST and response JSON. 
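    # Both output files reuse the report path stem, e.g., a (hypothetical)
    # "report.json" yields "report_dst_results.json" and
    # "report_response_results.json".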
170 |     dst_json_path = args.output_path_report.replace(".json", "_dst_results.json")
171 |     print(f"Saving DST results: {dst_json_path}")
172 |     with open(dst_json_path, "w") as file_id:
173 |         json.dump(dst_json, file_id)
174 |     responses_json_path = args.output_path_report.replace(
175 |         ".json", "_response_results.json"
176 |     )
177 |     print(f"Saving responses: {responses_json_path}")
178 |     with open(responses_json_path, "w") as file_id:
179 |         json.dump(responses_json, file_id)
180 | 
--------------------------------------------------------------------------------
/dialog_simulator/merge_synth_and_appen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved.
2 | 
3 | 
4 | #!/usr/bin/env python3
5 | """
6 | Description: merges the synthetically generated dialogs (.json, .p)
7 | and the comma-separated Appen paraphrase annotations (.csv)
8 | to output the merged dialogs in both .json and .p formats
9 | """
10 | import os
11 | import json
12 | import csv
13 | import random
14 | import pickle
15 | from utils import load_data_pickle
16 | 
17 | 
18 | if __name__ == "__main__":
19 |     # Parameters for generation
20 |     path_tuples = [
21 |         # Pilot 1: 50 dialogs
22 |         # [
23 |         #     '/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/pilot_1_mem_dials.p',
24 |         #     '/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/paraphrased_0622.csv',
25 |         #     '/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/pilot_1_mem_dials_merged.json',
26 |         #     '/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/pilot_1_mem_dials_merged.p',
27 |         # ],
28 |         # Pilot 2: 450 dialogs
29 |         [
30 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/pilot_2_mem_dials.p",
31 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/paraphrased_0622.csv",
32 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/pilot_2_mem_dials_merged.json",
33 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/pilot_2_mem_dials_merged.p",
34 |         ],
35 |         # Batch 1: 2000 dialogs
36 |         [
37 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_1_mem_dials.p",
38 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/paraphrased_0622.csv",
39 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_1_mem_dials_merged.json",
40 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_1_mem_dials_merged.p",
41 |         ],
42 |         # Batch 2: 500 dialogs
43 |         [
44 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_2_mem_dials.p",
45 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/paraphrased_0622.csv",
46 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_2_mem_dials_merged.json",
47 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_2_mem_dials_merged.p",
48 |         ],
49 |         # Batch 3: 2000 dialogs
50 |         [
51 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_3_mem_dials.p",
52 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/paraphrased_0622.csv",
53 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_3_mem_dials_merged.json",
54 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_3_mem_dials_merged.p",
55 |         ],
56 |         # Batch 4: 6000 dialogs
57 |         [
58 |             "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_4_mem_dials.p",
59 | "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/paraphrased_0622.csv", 60 | "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_4_mem_dials_merged.json", 61 | "/Users/shanemoon/workspace/memory_dialog/dialog_simulator/results/batch_4_mem_dials_merged.p", 62 | ], 63 | ] 64 | 65 | for path_tuple in path_tuples: 66 | path_in_synth = path_tuple[0] 67 | path_in_appen = path_tuple[1] 68 | path_out_json = path_tuple[2] 69 | path_out_pickle = path_tuple[3] 70 | 71 | # Load original synth 72 | original_dialogs = load_data_pickle(path_in_synth) 73 | mm_dialogs = [] 74 | 75 | # Load paraphrased 76 | fieldname_to_turn_idx = { 77 | "turn0_paraphrase": 0, 78 | "turn1_paraphrase": 1, 79 | "turn2_paraphrase": 2, 80 | "turn3_paraphrase": 3, 81 | "turn4_paraphrase": 4, 82 | "turn5_paraphrase": 5, 83 | "turn6_paraphrase": 6, 84 | "turn7_paraphrase": 7, 85 | "turn8_paraphrase": 8, 86 | "turn9_paraphrase": 9, 87 | "turn10_paraphrase": 10, 88 | "turn11_paraphrase": 11, 89 | "turn12_paraphrase": 12, 90 | "turn13_paraphrase": 13, 91 | "turn14_paraphrase": 14, 92 | "turn15_paraphrase": 15, 93 | "turn16_paraphrase": 16, 94 | "turn17_paraphrase": 17, 95 | "turn18_paraphrase": 18, 96 | "turn19_paraphrase": 19, 97 | "turn20_paraphrase": 20, 98 | "turn21_paraphrase": 21, 99 | "turn22_paraphrase": 22, 100 | "turn23_paraphrase": 23, 101 | } 102 | COL_DIALOG_ID = 88 103 | 104 | turn_idx_to_col = {} 105 | dialog_id_to_utter = {} 106 | 107 | with open(path_in_appen, "r", encoding="mac_roman") as f: 108 | reader = csv.reader(f, delimiter=",", quotechar='"') 109 | for i, line in enumerate(reader): 110 | if i == 0: 111 | for col_id, fieldname in enumerate(line): 112 | 113 | if fieldname in fieldname_to_turn_idx: 114 | turn_idx = fieldname_to_turn_idx[fieldname] 115 | turn_idx_to_col[turn_idx] = col_id 116 | 117 | else: 118 | dialog_id = int(line[COL_DIALOG_ID]) 119 | dialog_id_to_utter[dialog_id] = [] 120 | 121 | for turn_idx in range(len(turn_idx_to_col)): 122 | if turn_idx in turn_idx_to_col: 123 | 124 | utter = line[turn_idx_to_col[turn_idx]] 125 | utter = utter.strip() 126 | 127 | if utter != "": 128 | dialog_id_to_utter[dialog_id].append(utter) 129 | 130 | else: 131 | if turn_idx < 16: 132 | print( 133 | "Check dialog id %d, turn %d" 134 | % (dialog_id, turn_idx) 135 | ) 136 | 137 | # Merge 138 | for i, mm_d in enumerate(original_dialogs): 139 | d = mm_d.dialog 140 | dialog_id = d.idx 141 | 142 | if dialog_id not in dialog_id_to_utter: 143 | print("Dialog %d is missing." % dialog_id) 144 | continue 145 | 146 | mm_dialogs.append(mm_d) 147 | n_rounds = int(len(dialog_id_to_utter[dialog_id]) / 2) 148 | 149 | # TODO: discarding the utterances with missing paraphrases for now 150 | # Causes: residuals & incompletes from annotations, etc. 
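            # Each round pairs one user and one assistant turn: the paraphrase
            # columns alternate, with the user utterance at index 2*j and the
            # assistant utterance at index 2*j + 1 (see below).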
151 | mm_dialogs[-1].dialog.user_turns = mm_dialogs[-1].dialog.user_turns[ 152 | :n_rounds 153 | ] 154 | mm_dialogs[-1].dialog.asst_turns = mm_dialogs[-1].dialog.asst_turns[ 155 | :n_rounds 156 | ] 157 | 158 | for j in range(n_rounds): 159 | 160 | try: 161 | user_turn = d.user_turns[j] 162 | asst_turn = d.asst_turns[j] 163 | 164 | user_turn_idx = j * 2 165 | asst_turn_idx = j * 2 + 1 166 | 167 | user_paraphrase = dialog_id_to_utter[dialog_id][user_turn_idx] 168 | asst_paraphrase = dialog_id_to_utter[dialog_id][asst_turn_idx] 169 | 170 | mm_dialogs[-1].dialog.user_turns[j].frames[ 171 | -1 172 | ].uttr = user_paraphrase 173 | mm_dialogs[-1].dialog.asst_turns[j].frames[ 174 | -1 175 | ].uttr = asst_paraphrase 176 | 177 | except: 178 | print("Missing rounds %d from dialog %d" % (j, dialog_id)) 179 | print(len(dialog_id_to_utter[dialog_id])) 180 | print(len(d.user_turns)) 181 | 182 | # Output 183 | print("Outputting JSON file at %s..." % path_out_json) 184 | json.dump( 185 | {"dialogue_data": [mm_d.to_dict() for mm_d in mm_dialogs]}, 186 | open(path_out_json, "w"), 187 | indent=4, 188 | ) 189 | 190 | pickle.dump(mm_dialogs, open(path_out_pickle, "wb")) 191 | -------------------------------------------------------------------------------- /dialog_simulator/MemoryDialogSimulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | #!/usr/bin/env python3 5 | import json, random, traceback, os 6 | from typing import List, Tuple 7 | from constants import TurnSpeaker, DialogAct, API_STATUS 8 | from Data import Dialog, MemoryDialog, MemoryGraph, Turn, Goal 9 | from UserSimulator import PilotUserSimulator 10 | from AssistantSimulator import PilotAssistantSimulator 11 | from GoalGenerator import RuleBasedGoalGenerator 12 | from MemoryServiceAPI import MemoryServiceAPI 13 | from utils import build_parameter_ontology 14 | 15 | random.seed(0) 16 | 17 | 18 | class MemoryDialogSimulator: 19 | def __init__(self, *args, **kwargs): 20 | # Initialize user simulator, assistant simulator, memory_graphs etc. 
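        # All components below fall back to defaults (pilot user/assistant
        # simulators, a rule-based goal generator, an empty memory graph
        # bank) unless explicitly overridden via keyword arguments.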
21 | self.domain = kwargs.pop("domain") 22 | self._memory_service_api = kwargs.pop("memory_service_api", MemoryServiceAPI()) 23 | self._user_simulator = kwargs.pop("user_simulator", PilotUserSimulator()) 24 | self._assistant_simulator = kwargs.pop( 25 | "assistant_simulator", PilotAssistantSimulator() 26 | ) 27 | self._goal_generator = kwargs.pop( 28 | "goal_generator", RuleBasedGoalGenerator(domain=self.domain) 29 | ) 30 | self._memory_graph_bank = kwargs.pop("memory_graph_bank", {}) 31 | 32 | self._user_simulator.register_memory_service_api(self._memory_service_api) 33 | self._assistant_simulator.register_memory_service_api(self._memory_service_api) 34 | 35 | def set_user_simulator(self, user_simulator): 36 | self._user_simulator = user_simulator 37 | 38 | def set_assistant_simulator(self, assistant_simulator): 39 | self._assistant_simulator = assistant_simulator 40 | 41 | def set_goal_generator(self, goal_generator): 42 | self._goal_generator = goal_generator 43 | 44 | def set_memory_service_api(self, memory_service_api): 45 | self._memory_service_api = memory_service_api 46 | 47 | def sample_goals(self, memory_graph, goal_config) -> List[Goal]: 48 | return self._goal_generator.sample_goals( 49 | memory_graph=memory_graph, goal_config=goal_config 50 | ) 51 | 52 | def sample_memory_graph(self) -> MemoryGraph: 53 | if self._memory_graph_bank == {}: 54 | # Empty memory graph 55 | return MemoryGraph() 56 | 57 | # Randomly sample a memory 58 | # TODO: allow for more organized way of sampling memories 59 | memory_graph_id = random.choice(list(self._memory_graph_bank.keys())) 60 | memory_graph = self._memory_graph_bank[memory_graph_id] 61 | 62 | return MemoryGraph(data=memory_graph) 63 | 64 | def batch_generate_dialog_flows( 65 | self, 66 | n_dialogs: int, 67 | n_max_turns: int, 68 | start_dialog_idx: int, 69 | goal_config: dict = {}, 70 | ) -> List[MemoryGraph]: 71 | 72 | # Batch generate multiple dialogs using the same simulators 73 | memory_dialogs = [] 74 | 75 | for i in range(n_dialogs): 76 | # Continue until generation is successful 77 | generation_success = False 78 | 79 | while not generation_success: 80 | try: 81 | # Sample a memory graph (user) 82 | memory_graph = self.sample_memory_graph() 83 | 84 | # Create an empty memory dialog 85 | memory_dialog = MemoryDialog(memory_graph=memory_graph) 86 | 87 | # Generate Goal Config 88 | goal_config["parameter_ontology"] = build_parameter_ontology( 89 | memory_dialog.memory_graph, 90 | self._memory_service_api.metadata, 91 | self.domain, 92 | ) 93 | 94 | # Sample goals for this dialog 95 | goals = self.sample_goals( 96 | memory_graph=memory_dialog.memory_graph, goal_config=goal_config 97 | ) 98 | 99 | # Generate dialog flow 100 | memory_dialog = self.generate_dialog_flow( 101 | goals, memory_dialog, n_max_turns 102 | ) 103 | memory_dialog.dialog.idx = start_dialog_idx + i 104 | 105 | # If everything is successful, append to memory_dialogs 106 | generation_success = True 107 | memory_dialogs.append(memory_dialog) 108 | 109 | except: 110 | # TODO: Make a more robust abort strategy 111 | print("** Error in generating dialog. Ignoring this one. 
**") 112 | traceback.print_exc() 113 | print() 114 | 115 | return memory_dialogs 116 | 117 | def generate_dialog_flow( 118 | self, 119 | goals: List[Goal], 120 | memory_dialog: MemoryDialog, 121 | n_max_turns: int, 122 | initialize=True, 123 | ) -> MemoryDialog: 124 | 125 | if initialize: 126 | # Initialize memory_dialog 127 | memory_dialog.initialize() 128 | 129 | # Iterate and generate a dialog turn by turn 130 | i = 0 131 | while not goals == [] and i < n_max_turns: 132 | 133 | # Pick a goal 134 | current_goal = goals.pop(0) 135 | goal_met = False 136 | print("Goal:", current_goal) 137 | 138 | while not goal_met and i < n_max_turns: 139 | 140 | # Generate a turn 141 | memory_dialog = self.generate_turn(current_goal, memory_dialog) 142 | 143 | # End of a turn: update dialog & goals 144 | i += 1 145 | goal_met = memory_dialog.is_goal_met(current_goal) 146 | 147 | is_valid_dialog = self.validate_dialog(memory_dialog) 148 | if not is_valid_dialog: 149 | # If something is not right about this dialog, abort. 150 | # TODO: abort gracefully 151 | assert False 152 | 153 | return memory_dialog 154 | 155 | def generate_turn(self, goal: Goal, memory_dialog: MemoryDialog) -> MemoryDialog: 156 | 157 | # TODO: extend it for multiple frames per turn 158 | 159 | # (1) Generate a User turn, given a target goal and a memory_dialog 160 | # Generate dialog act and slots 161 | user_frame = self._user_simulator.execute_turn(goal, memory_dialog) 162 | 163 | # Template based utterance generation 164 | user_frame = self._user_simulator.generate_uttr(user_frame, goal) 165 | 166 | # Instantiate a user turn, and update the memory_dialog 167 | user_turn = Turn([user_frame], TurnSpeaker.USER, goal) 168 | memory_dialog.dialog.add_user_turn(user_turn) 169 | print("U:", user_turn) 170 | 171 | # (2) Generate a Assistant turn, given a target goal and a memory_dialog 172 | # Generate dialog act and slots 173 | asst_frame, api_request, api_result = self._assistant_simulator.execute_turn( 174 | goal, memory_dialog 175 | ) 176 | 177 | # Template based utterance generation 178 | asst_frame = self._assistant_simulator.generate_uttr(asst_frame, goal) 179 | 180 | # Instantiate a user turn, and update the memory_dialog 181 | asst_turn = Turn([asst_frame], TurnSpeaker.ASSISTANT, goal) 182 | memory_dialog.dialog.add_asst_turn(asst_turn) 183 | print("A:", asst_turn) 184 | 185 | # Add goals and api_calls 186 | memory_dialog.dialog.add_goal(goal) 187 | memory_dialog.dialog.add_api_call(api_request) 188 | memory_dialog.dialog.add_api_result(api_result) 189 | 190 | return memory_dialog 191 | 192 | def validate_dialog(self, memory_dialog: MemoryDialog) -> bool: 193 | # Check for any undesirable traits of a dialog 194 | n_turns = len(memory_dialog.dialog.asst_turns) 195 | 196 | # (1) Multiple sharing of the same memory 197 | set_shared_memory_ids = set() 198 | for user_turn in memory_dialog.dialog.user_turns: 199 | # TODO: Handle multiple frames per turn 200 | dialog_act = user_turn.frames[-1].dialog_act 201 | 202 | if dialog_act == DialogAct.REQUEST_SHARE: 203 | memories_to_share = user_turn.frames[-1].act_attributes.memories 204 | for m in memories_to_share: 205 | memory_id = m.data["memory_id"] 206 | if memory_id in set_shared_memory_ids: 207 | # If this memory_id is already shared, abort 208 | return False 209 | set_shared_memory_ids.add(memory_id) 210 | 211 | # (2) Too frequent search fails 212 | n_search_fails = 0 213 | for api_result in memory_dialog.dialog.api_results: 214 | status = api_result.status 215 | if status == 
API_STATUS.SEARCH_NOT_FOUND: 216 | n_search_fails += 1 217 | 218 | if (n_turns <= 4 and n_search_fails >= 2) or ( 219 | n_turns > 4 and n_search_fails >= 3 220 | ): 221 | return False 222 | 223 | # Otherwise, this dialog is good. 224 | return True 225 | -------------------------------------------------------------------------------- /dialog_simulator/GoalGenerator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | #!/usr/bin/env python3 5 | import random 6 | from constants import ( 7 | GoalType, 8 | GoalMemoryRefType, 9 | numeric_slots, 10 | non_visual_slots, 11 | visual_slots, 12 | all_slots, 13 | ) 14 | from Data import Goal, GoalParameter, MemoryTime 15 | from utils import weighted_choice 16 | import copy 17 | 18 | random.seed(0) 19 | 20 | 21 | class RuleBasedGoalGenerator: 22 | def __init__(self, *args, **kwargs): 23 | self.non_visual_slots = non_visual_slots 24 | self.visual_slots = visual_slots 25 | self.all_slots = all_slots 26 | 27 | def sample_goals(self, *args, **kwargs): 28 | memory_graph = kwargs.pop("memory_graph", None) 29 | goal_config = kwargs.pop("goal_config", {}) 30 | n_min_goals = goal_config.get("n_min_goals", 3) 31 | n_max_goals = goal_config.get("n_max_goals", 5) 32 | n_goals = random.randint(n_min_goals, n_max_goals) 33 | 34 | goal_type_list = [ 35 | GoalType.SEARCH, 36 | GoalType.REFINE_SEARCH, 37 | GoalType.GET_RELATED, 38 | GoalType.GET_INFO, 39 | GoalType.GET_AGGREGATED_INFO, 40 | GoalType.SHARE, 41 | GoalType.CHITCHAT, 42 | ] 43 | goal_type_list_weights_start = [ 44 | 1, 45 | 0, 46 | 0, 47 | 0, 48 | 0, 49 | 0, 50 | 0, 51 | # 1, 0, 0, 0, 1, 0, 0, 52 | ] 53 | 54 | goal_type_list_weights_mid = [ 55 | 0.8, 56 | 1.1, 57 | 1.7, 58 | 1.1, 59 | 0, 60 | 0.1, 61 | 0, 62 | # 1, 0.8, 0.8, 1, 1, 0.5, 0.5, 63 | ] 64 | 65 | goal_type_list_weights_end = [ 66 | 0.3, 67 | 0.5, 68 | 0.6, 69 | 0.5, 70 | 0, 71 | 3, 72 | 0, 73 | # 0.5, 0.5, 0.5, 0.5, 0.5, 3, 1, 74 | ] 75 | 76 | # Randomly sample from the goal type list 77 | # For now, we enforce the goals to start with BROWSE 78 | # and end with ADD_TO_CART 79 | # TODO: allow for a more flexible way of generating 80 | # goal types 81 | goal_types = ( 82 | random.choices( 83 | population=goal_type_list, weights=goal_type_list_weights_start, k=1 84 | ) 85 | + random.choices( 86 | population=goal_type_list, 87 | weights=goal_type_list_weights_mid, 88 | k=n_goals - 2, 89 | ) 90 | + random.choices( 91 | population=goal_type_list, weights=goal_type_list_weights_end, k=1 92 | ) 93 | ) 94 | 95 | # Make a complete goal with an accompanying set of goal parameters 96 | # for each goal_type 97 | goals = [] 98 | for goal_type in goal_types: 99 | # For now, we pass in a random set of goal_parameters 100 | goal_parameters = self.sample_goal_parameters( 101 | goal_type, memory_graph, goal_config 102 | ) 103 | goals.append(Goal(goal_type=goal_type, goal_parameters=goal_parameters)) 104 | 105 | return goals 106 | 107 | def sample_goal_parameters(self, goal_type, memory_graph, goal_config): 108 | # Sample goal parameters according to the input sample 109 | 110 | # TODO: IMPLEMENT ** 111 | goal_parameters = [] 112 | parameter_ontology = goal_config["parameter_ontology"] 113 | 114 | # (1) Pick a search filter 115 | search_filter = {} 116 | 117 | if goal_type in set( 118 | [GoalType.SEARCH, GoalType.REFINE_SEARCH, GoalType.GET_RELATED] 119 | ): 120 | 121 | if goal_type == GoalType.GET_RELATED: 122 | n_slots = 
weighted_choice(population=[1, 2], weights=[0.93, 0.07]) 123 | else: 124 | n_slots = weighted_choice(population=[1, 2], weights=[0.75, 0.25]) 125 | 126 | # Candidate slots: exclude a few slots that 127 | # are semantically infeasible 128 | # **** TODO ****: confirm that there is no slot to exclude 129 | candidate_slots = self.all_slots - set([""]) 130 | 131 | search_filter_slots = random.choices( 132 | population=list(candidate_slots), k=n_slots 133 | ) 134 | 135 | for search_filter_slot in search_filter_slots: 136 | # We first randomly assign a value for a randomly selected slot 137 | if search_filter_slot == "time": 138 | # Instead of choosing a specific datetime, 139 | # search by year or month instead. 140 | random_datetime = MemoryTime( 141 | str_datetime=random.choice( 142 | parameter_ontology["all"].get(search_filter_slot) 143 | ) 144 | ) 145 | 146 | if random.random() > 0.1: 147 | search_filter_value = str(MemoryTime(year=random_datetime.year)) 148 | 149 | else: 150 | search_filter_value = str( 151 | MemoryTime( 152 | year=random_datetime.year, month=random_datetime.month 153 | ) 154 | ) 155 | 156 | if goal_type == GoalType.GET_RELATED: 157 | # A special value for refine_search: 'next' and 'prev' 158 | # e.g. "where did we go next?" 159 | if random.random() > 0.3: 160 | search_filter_value = random.choice( 161 | ["right after", "right before", "on the same day"] 162 | ) 163 | 164 | elif search_filter_slot == "location": 165 | # TODO: Instead of choosing a specific location, 166 | # occasionally search with a coarser query. 167 | search_filter_value = random.choice( 168 | parameter_ontology["all"].get(search_filter_slot) 169 | ) 170 | 171 | if random.random() > 0.7: 172 | search_filter_value = copy.deepcopy(search_filter_value) 173 | search_filter_value["geo_tag"].get("place") 174 | 175 | else: 176 | # TODO: handle subsampling of participants & activities 177 | search_filter_value = random.choice( 178 | parameter_ontology["all"].get(search_filter_slot) 179 | ) 180 | 181 | if search_filter_value != "": 182 | search_filter[search_filter_slot] = search_filter_value 183 | 184 | # (2) Pick an object reference type 185 | object_reference_type = GoalMemoryRefType.NOT_SPECIFIED 186 | 187 | if goal_type in set([GoalType.GET_RELATED, GoalType.GET_INFO, GoalType.SHARE]): 188 | 189 | object_reference_type = weighted_choice( 190 | population=[ 191 | GoalMemoryRefType.PREV_TURN, 192 | GoalMemoryRefType.DIALOG, 193 | GoalMemoryRefType.GRAPH, 194 | ], 195 | weights=[0.8, 0.2, 0.0], 196 | ) 197 | 198 | # (3) Pick slots to request (e.g. 
in questions) 199 | request_slots = [] 200 | 201 | if goal_type in set([GoalType.GET_INFO]): 202 | # We randomly sample slots to ask 203 | # ****** TODO *******: make sure it's not asking about 204 | # the parameters that were already in search filter 205 | 206 | ask_from_visual_slot = random.random() > 0.9 207 | 208 | if ask_from_visual_slot: 209 | # ask about visual_slots (rare): people, activity 210 | n_request_slots = 1 211 | request_slots.extend( 212 | random.sample(self.non_visual_slots, n_request_slots) 213 | ) 214 | 215 | else: 216 | # ask about non_visual_slots: time, location 217 | n_request_slots = weighted_choice(population=[1, 2], weights=[0.8, 0.2]) 218 | request_slots.extend( 219 | random.sample(self.non_visual_slots, n_request_slots) 220 | ) 221 | 222 | elif goal_type in set([GoalType.GET_RELATED]): 223 | # We randomly sample slots to ask 224 | # iff search_filter is empty 225 | if len(search_filter) == 0: 226 | n_request_slots = weighted_choice(population=[0, 1], weights=[0.4, 0.6]) 227 | request_slots.extend(random.sample(self.all_slots, n_request_slots)) 228 | 229 | elif goal_type in set([GoalType.GET_AGGREGATED_INFO]): 230 | # ****** TODO ******* 231 | pass 232 | 233 | # (4) Compile it into a goal parameter 234 | goal_parameter = GoalParameter( 235 | filter=search_filter, 236 | reference_type=object_reference_type, 237 | request_slots=request_slots, 238 | ) 239 | goal_parameters.append(goal_parameter) 240 | 241 | return goal_parameters 242 | -------------------------------------------------------------------------------- /models/gpt2_mm/dataset_memory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. 2 | 3 | 4 | # coding: utf-8 5 | """Dataset Loader for Memory Dialogs. 6 | 7 | Author(s): noctli, skottur 8 | (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 9 | """ 10 | 11 | import json 12 | import logging 13 | import os 14 | import pickle 15 | import re 16 | from itertools import chain 17 | 18 | import numpy as np 19 | import torch 20 | import torch.utils.data 21 | import tqdm 22 | 23 | from dataset import tokenize 24 | from torch.utils.data import Dataset 25 | 26 | 27 | # from train import SPECIAL_TOKENS, MODEL_INPUTS, PADDED_INPUTS 28 | # SPECIAL_TOKENS = ["", "", "", "", "