├── .gitattributes ├── .gitignore ├── .lfsconfig ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── DSTC9_SIMMC_RESULTS.md ├── LICENSE ├── README.md ├── SUBMISSION_INSTRUCTIONS.md ├── TASK_INPUTS.md ├── data ├── README.md ├── simmc_fashion │ ├── fashion_dev_dials.json │ ├── fashion_dev_dials_retrieval_candidates.json │ ├── fashion_devtest_dials.json │ ├── fashion_devtest_dials_api_calls_teststd_format_private.json │ ├── fashion_devtest_dials_api_calls_teststd_format_public.json │ ├── fashion_devtest_dials_retrieval_candidates.json │ ├── fashion_devtest_dials_retrieval_candidates_teststd_format_private.json │ ├── fashion_devtest_dials_retrieval_candidates_teststd_format_public.json │ ├── fashion_devtest_dials_teststd_format_public.json │ ├── fashion_metadata.json │ ├── fashion_teststd_dials_api_calls.json │ ├── fashion_teststd_dials_public.json │ ├── fashion_teststd_dials_retrieval_candidates_public.json │ ├── fashion_train_dials.json │ └── fashion_train_dials_retrieval_candidates.json └── simmc_furniture │ ├── furniture_dev_dials.json │ ├── furniture_dev_dials_retrieval_candidates.json │ ├── furniture_devtest_dials.json │ ├── furniture_devtest_dials_api_calls_teststd_format_private.json │ ├── furniture_devtest_dials_api_calls_teststd_format_public.json │ ├── furniture_devtest_dials_retrieval_candidates.json │ ├── furniture_devtest_dials_retrieval_candidates_teststd_format_private.json │ ├── furniture_devtest_dials_retrieval_candidates_teststd_format_public.json │ ├── furniture_devtest_dials_teststd_format_public.json │ ├── furniture_metadata.csv │ ├── furniture_screenshot_map.json │ ├── furniture_screenshots_part_0.zip │ ├── furniture_screenshots_part_1.zip │ ├── furniture_screenshots_part_2.zip │ ├── furniture_screenshots_part_3.zip │ ├── furniture_screenshots_part_4.zip │ ├── furniture_teststd_dials_api_calls.json │ ├── furniture_teststd_dials_public.json │ ├── furniture_teststd_dials_retrieval_candidates_public.json │ ├── furniture_train_dials.json │ └── furniture_train_dials_retrieval_candidates.json ├── figures ├── simmc_dstc9_results_summary.png └── teaser.png ├── mm_action_prediction ├── README.md ├── eval_simmc_agent.py ├── loaders │ ├── __init__.py │ ├── loader_base.py │ ├── loader_simmc.py │ └── loader_vocabulary.py ├── models │ ├── __init__.py │ ├── action_executor.py │ ├── assistant.py │ ├── carousel_embedder.py │ ├── decoder.py │ ├── encoders │ │ ├── __init__.py │ │ ├── hierarchical_recurrent.py │ │ ├── history_agnostic.py │ │ ├── memory_network.py │ │ └── tf_idf_encoder.py │ ├── fashion_model_metainfo.json │ ├── furniture_model_metainfo.json │ ├── positional_encoding.py │ ├── self_attention.py │ └── user_memory_embedder.py ├── options.py ├── scripts │ ├── preprocess_simmc.sh │ ├── train_all_simmc_models.sh │ └── train_simmc_model.sh ├── tools │ ├── action_evaluation.py │ ├── build_multimodal_inputs.py │ ├── data_support.py │ ├── embed_fashion_assets.py │ ├── embed_furniture_assets.py │ ├── extract_actions.py │ ├── extract_actions_fashion.py │ ├── extract_attribute_vocabulary.py │ ├── extract_vocabulary.py │ ├── response_evaluation.py │ ├── retrieval_evaluation.py │ ├── rnn_support.py │ ├── support.py │ ├── torch_support.py │ └── weight_init.py └── train_simmc_agent.py ├── mm_dst ├── README.md ├── gpt2_dst │ ├── scripts │ │ ├── evaluate.py │ │ ├── preprocess_input.py │ │ ├── run_generation.py │ │ └── run_language_modeling.py │ └── utils │ │ └── convert.py ├── run_evaluate.sh ├── run_evaluate_gpt2.sh ├── run_generate_gpt2.sh ├── run_preprocess_gpt2.sh ├── run_train_gpt2.sh └── 
utils │ └── evaluate_dst.py └── mm_response_generation └── README.md /.gitattributes: -------------------------------------------------------------------------------- 1 | data/simmc_fashion/fashion_dev_dials.json filter=lfs diff=lfs merge=lfs -text 2 | data/simmc_fashion/fashion_devtest_dials.json filter=lfs diff=lfs merge=lfs -text 3 | data/simmc_fashion/fashion_test_dials.json filter=lfs diff=lfs merge=lfs -text 4 | data/simmc_fashion/fashion_train_dials.json filter=lfs diff=lfs merge=lfs -text 5 | data/simmc_furniture/furniture_dev_dials.json filter=lfs diff=lfs merge=lfs -text 6 | data/simmc_furniture/furniture_devtest_dials.json filter=lfs diff=lfs merge=lfs -text 7 | data/simmc_furniture/furniture_test_dials.json filter=lfs diff=lfs merge=lfs -text 8 | data/simmc_furniture/furniture_train_dials.json filter=lfs diff=lfs merge=lfs -text 9 | data/simmc_fashion/fashion_metadata.json filter=lfs diff=lfs merge=lfs -text 10 | data/simmc_furniture/furniture_metadata.csv filter=lfs diff=lfs merge=lfs -text 11 | data/simmc_fashion/fashion_dev_dials_retrieval_candidates.json filter=lfs diff=lfs merge=lfs -text 12 | data/simmc_fashion/fashion_devtest_dials_retrieval_candidates.json filter=lfs diff=lfs merge=lfs -text 13 | data/simmc_fashion/fashion_train_dials_retrieval_candidates.json filter=lfs diff=lfs merge=lfs -text 14 | data/simmc_furniture/furniture_train_dials_retrieval_candidates.json filter=lfs diff=lfs merge=lfs -text 15 | data/simmc_furniture/furniture_dev_dials_retrieval_candidates.json filter=lfs diff=lfs merge=lfs -text 16 | data/simmc_furniture/furniture_devtest_dials_retrieval_candidates.json filter=lfs diff=lfs merge=lfs -text 17 | data/simmc_fashion/fashion_devtest_dials_teststd_format_public.json filter=lfs diff=lfs merge=lfs -text 18 | data/simmc_furniture/furniture_devtest_dials_teststd_format_public.json filter=lfs diff=lfs merge=lfs -text 19 | data/simmc_fashion/fashion_devtest_dials_retrieval_candidates_teststd_format_private.json filter=lfs diff=lfs merge=lfs -text 20 | data/simmc_fashion/fashion_devtest_dials_retrieval_candidates_teststd_format_public.json filter=lfs diff=lfs merge=lfs -text 21 | data/simmc_fashion/fashion_devtest_dials_api_calls_teststd_format_private.json filter=lfs diff=lfs merge=lfs -text 22 | data/simmc_fashion/fashion_devtest_dials_api_calls_teststd_format_public.json filter=lfs diff=lfs merge=lfs -text 23 | data/simmc_furniture/furniture_devtest_dials_retrieval_candidates_teststd_format_private.json filter=lfs diff=lfs merge=lfs -text 24 | data/simmc_furniture/furniture_devtest_dials_retrieval_candidates_teststd_format_public.json filter=lfs diff=lfs merge=lfs -text 25 | data/simmc_furniture/furniture_devtest_dials_api_calls_teststd_format_private.json filter=lfs diff=lfs merge=lfs -text 26 | data/simmc_furniture/furniture_devtest_dials_api_calls_teststd_format_public.json filter=lfs diff=lfs merge=lfs -text 27 | data/simmc_fashion/fashion_teststd_dials_api_calls.json filter=lfs diff=lfs merge=lfs -text 28 | data/simmc_fashion/fashion_teststd_dials_public.json filter=lfs diff=lfs merge=lfs -text 29 | data/simmc_fashion/fashion_teststd_dials_retrieval_candidates_public.json filter=lfs diff=lfs merge=lfs -text 30 | data/simmc_furniture/furniture_teststd_dials_api_calls.json filter=lfs diff=lfs merge=lfs -text 31 | data/simmc_furniture/furniture_teststd_dials_public.json filter=lfs diff=lfs merge=lfs -text 32 | data/simmc_furniture/furniture_teststd_dials_retrieval_candidates_public.json filter=lfs diff=lfs merge=lfs -text 33 | 
data/simmc_furniture/furniture_screenshots_part_0.zip filter=lfs diff=lfs merge=lfs -text 34 | data/simmc_furniture/furniture_screenshots_part_1.zip filter=lfs diff=lfs merge=lfs -text 35 | data/simmc_furniture/furniture_screenshots_part_2.zip filter=lfs diff=lfs merge=lfs -text 36 | data/simmc_furniture/furniture_screenshots_part_3.zip filter=lfs diff=lfs merge=lfs -text 37 | data/simmc_furniture/furniture_screenshots_part_4.zip filter=lfs diff=lfs merge=lfs -text 38 | data/simmc_furniture/furniture_screenshot_map.json filter=lfs diff=lfs merge=lfs -text 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /.lfsconfig: -------------------------------------------------------------------------------- 1 | [lfs] 2 | fetchexclude = data/simmc_furniture/*.zip 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to SIMMC 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * 4 spaces for indentation rather than tabs 31 | * 120 character line length 32 | 33 | ## License 34 | By contributing to this project, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 36 | -------------------------------------------------------------------------------- /DSTC9_SIMMC_RESULTS.md: -------------------------------------------------------------------------------- 1 | ## SIMMC Track Results (DSTC9), 2020 2 | 3 | ### TL;DR 4 | The first edition of the Situated and Interactive Multimodal Conversations (SIMMC) Track at the Dialog State Tracking Challenge (DSTC) 9 came to a successful end! 5 | The challenge saw a total of **11** model entries from **5** teams across the world, setting a new state-of-the-art in all three subtasks: 6 | 7 | 1. 
**Subtask 1 (Assistant API Prediction)** 8 | * Action accuracy increased by 3 points (79.3% to 82.5%) 9 | * Action Attribute accuracy increased by ~10 points (63.7% to 73.9%) 10 | 2. **Subtask 2 (Assistant Response Generation)** 11 | * BLEU score increased by 0.067 points (0.061 to 0.128) 12 | * Recall@1 increased by 45 points (7.2% to 52.6%) 13 | 3. **Subtask 3 (Dialog State Tracking)** 14 | * Slot F-1 increased by 16.7 points (62.4% to 79.1%) 15 | * Intent F-1 increased by 16 points (62.1% to 78.1%) 16 | 17 | Congratulations to the winners and runners-up, and a big thanks to all the participants! 18 | 19 | ### Details 20 | We launched the SIMMC challenge in June, 2020. Please checkout the 21 | [main page][1], [paper][2], and the [Facebook AI blog][3] for more details. 22 | 23 | Please checkout the challenge page, paper, and the Facebook AI blog for further details about the challenge, dataset, tasks, and the evaluation metrics. 24 | 25 | **Results Summary** 26 | 27 |
28 | ![DSTC9 SIMMC Results Summary](./figures/simmc_dstc9_results_summary.png) 29 | *Summary on Test-Std split, average of Furniture and Fashion.* 30 |
31 | 32 | 33 | | Team | Affiliation | Github | 34 | |:----:|-----------------------------------------------------|:----:| 35 | | 1 | National Taiwan University, Taiwan | [Link][4] | 36 | | 2 | Sogang University, South Korea | [Link][5] | 37 | | 3 | Sogang University, South Korea | [Link][6] | 38 | | 4 | Institute for Infocomm Research, A-STAR, Singapore | [Link][7] | 39 | | 5 | LINKS Foundation and Politecnico di Torino, Italy | [Link][8] | 40 | 41 | 42 | **Detailed Results** 43 | 44 | All the detailed results are [here][9]. 45 | 46 | 47 | [1]: https://github.com/facebookresearch/simmc 48 | [2]: https://arxiv.org/abs/2006.01460 49 | [3]: https://ai.facebook.com/blog/simmc-a-data-set-for-developing-next-generation-shopping-assistants/ 50 | [4]: https://github.com/billkunghappy/DSTC_TRACK4_ENTER 51 | [5]: https://github.com/inkoon/simmc 52 | [6]: https://github.com/boychaboy/simmc 53 | [7]: https://github.com/i2r-simmc/i2r-simmc-2020 54 | [8]: https://github.com/D2KLab/dstc9-SIMMC 55 | [9]: https://docs.google.com/spreadsheets/d/e/2PACX-1vRPfjuesfrMrDoDZ34uNB8zDH2XutHc_ScXvao4PzUaCPXPM_uIu5hkJ2FSoByepgdEyk35Ti8lHha-/pubhtml?gid=1354274332&single=true&widget=true&headers=false 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Situated Interactive MultiModal Conversations (SIMMC) Challenge 2020 2 | 3 | Welcome to the Situated Interactive Multimodal Conversations (SIMMC) Track for [DSTC9][dstc9] 2020. 4 | 5 | The SIMMC challenge aims to lay the foundations for the real-world assistant agents that can handle multimodal inputs, and perform multimodal actions. 6 | We thus focus on **task-oriented** dialogs that encompass a **situated** multimodal user context in the form of a co-observed image or virtual reality (VR) environment. 7 | The context is **dynamically** updated on each turn based on the user input and the assistant action. 8 | Our challenge focuses on our SIMMC datasets, both of which are shopping domains: 9 | (a) furniture (grounded in a shared virtual environment) and, 10 | (b) fashion (grounded in an evolving set of images). 11 | 12 | **Organizers**: Ahmad Beirami, Eunjoon Cho, Paul A. Crook, Ankita De, Alborz Geramifard, Satwik Kottur, Seungwhan Moon, Shivani Poddar, Rajen Subba 13 | 14 |
15 | ![Example from SIMMC](./figures/teaser.png) 16 | *Example from SIMMC-Furniture Dataset* 17 |
18 | 19 | 20 | 21 | ### Latest News 22 | 23 | * **[Apr 15, 2021]** Released screenshots for SIMMC-Furniture 24 | ([part 0][screenshot_link_0], [part 1][screenshot_link_1], [part 2][screenshot_link_2]). 25 | Also released improved API calls with newer heuristics as SIMMC v1.2 ([PR][screenshot_pr]). 26 | * **[Dec 29, 2020]** Fixed the errors in text spans for both SIMMC-Furniture and SIMMC-Fashion, released new JSON files as SIMMC v1.1 ([PR][span_fix_pr]). 27 | * **[Sept 28, 2020]** Test-Std data released, End of Challenge Phase 1. 28 | * **[July 8, 2020]** Evaluation scripts and code to train baselines for Sub-Task #1, Sub-Task #2 released. 29 | * **[June 22, 2020]** Challenge announcement. Training / development datasets (SIMMC v1.0) are released. 30 | 31 | **Note:** DSTC9 SIMMC Challenge was conducted on SIMMC v1.0. Thus all the results and baseline performances are on SIMMC v1.0. 32 | 33 | 34 | ## Important Links 35 | 36 | * [Task Description Paper][simmc_arxiv] 37 | * [Challenge Registration](https://forms.gle/jdT79eBeySHVoa1QA) 38 | * [Data Formats](data/README.md) 39 | * **Baseline Details**: [MM Action Prediction](mm_action_prediction/README.md), [MM Response Generation](mm_response_generation/README.md), [MM-DST](mm_dst/README.md) 40 | * [Challenge Instructions](#challenge-instructions) 41 | * [Submission Instructions](SUBMISSION_INSTRUCTIONS.md) 42 | 43 | 44 | ## Timeline 45 | 46 | | **Date** | **Milestone** | 47 | | :--: | :-- | 48 | | June 22, 2020 | Training & development data released | 49 | | Sept 28, 2020 | Test-Std data released, End of Challenge Phase 1 | 50 | | Oct 5, 2020 | Entry submission deadline, End of Challenge Phase 2 | 51 | | Oct 12, 2020 | [Final results announced](DSTC9_SIMMC_RESULTS.md) | 52 | 53 | 54 | ## Track Description 55 | 56 | ### Tasks and Metrics 57 | 58 | We present three sub-tasks primarily aimed at replicating human-assistant actions in order to enable rich and interactive shopping scenarios. 
59 | 60 | | Sub-Task #1 | [Multimodal Action Prediction](mm_action_prediction) | 61 | |---------|---------------------------------------------------------------------------------------------------------------------------------------| 62 | | Goal | To predict the correct Assistant API action(s) (classification) | 63 | | Input | Current user utterance, Dialog context, Multimodal context | 64 | | Output | Structural API (action & arguments) | 65 | | Metrics | Action Accuracy, Attribute Accuracy, Action Perplexity | 66 | 67 | | Sub-Task #2 | [Multimodal Dialog Response Generation & Retrieval](mm_response_generation) | 68 | |---------|---------------------------------------------------------------------------------------------------------------------------------------| 69 | | Goal | To generate Assistant responses or retrieve from a candidate pool | 70 | | Input | Current user utterance, Dialog context, Multimodal context, (Ground-truth API Calls) | 71 | | Output | Assistant response utterance | 72 | | Metrics | Generation: BLEU-4, Retrieval: MRR, R@1, R@5, R@10, Mean Rank | 73 | 74 | | Sub-Task #3 | [Multimodal Dialog State Tracking (MM-DST)](mm_dst) | 75 | |---------|---------------------------------------------------------------------------------------------------------------------------------------| 76 | | Goal | To track user belief states across multiple turns | 77 | | Input | Current user utterance, Dialogue context, Multimodal context | 78 | | Output | Belief state for current user utterance | 79 | | Metrics | Slot F1, Intent F1 | 80 | 81 | Please check the [task input](./TASK_INPUTS.md) file for a full description of inputs 82 | for each subtask. 83 | 84 | ### Evaluation 85 | 86 | For the DSTC9 SIMMC Track, we will do a two phase evaluation as follows. 87 | 88 | **Challenge Period 1**: 89 | Participants will evaluate the model performance on the provided `devtest` set. 90 | At the end of Challenge Period 1 (Sept 28), we ask participants to submit their model prediction results and a link to their code repository. 91 | 92 | **Challenge Period 2**: 93 | A `test-std` set will be released on Sept 28 for the participants who submitted the results for the Challenge Period 1. 94 | We ask participants to submit their model predictions on the `test-std` set by Oct 5. 95 | We will announce the final results and the winners on Oct 12. 96 | 97 | 98 | ## Challenge Instructions 99 | 100 | ### (1) Challenge Registration 101 | 102 | * Fill out [this form](https://forms.gle/jdT79eBeySHVoa1QA) to register at DSTC9. Check “**Track 4: Visually Grounded Dialog Track**” along with other tracks you are participating in. 103 | 104 | ### (2) Download Datasets and Code 105 | 106 | * Irrespective of participation in the challenge, we'd like to encourge those interested in this dataset to complete this [optional survey](https://oculus.qualtrics.com/jfe/form/SV_1AlazoSV7iwepZH). This will also help us communicate any future updates on the codebase, the datasets, and the challenge track. 107 | 108 | * Git clone our repository to download the datasets and the code. You may use the provided baselines as a starting point to develop your models. 109 | ``` 110 | $ git lfs install 111 | $ git clone https://github.com/facebookresearch/simmc.git 112 | ``` 113 | 114 | ### (3) Reporting Results for Challenge Phase 1 115 | * Submit your model prediction results on the `devtest` set, following the [submission instructions](./SUBMISSION_INSTRUCTIONS.md). 
116 | * We will release the `test-std` set (with ground-truth labels hidden) on Sept 28. 117 | 118 | ### (4) Reporting Results for Challenge Phase 2 119 | * Submit your model prediction results on the `test-std` set, following the [submission instructions](./SUBMISSION_INSTRUCTIONS.md). 120 | * We will evaluate the participants’ model predictions using the same evaluation script for Phase 1, and announce the results. 121 | 122 | 123 | ## Contact 124 | 125 | ### Questions related to SIMMC Track, Data, and Baselines 126 | Please contact simmc@fb.com, or leave comments in the Github repository. 127 | 128 | ### DSTC Mailing List 129 | If you want to get the latest updates about DSTC9, join the [DSTC mailing list](https://groups.google.com/a/dstc.community/forum/#!forum/list/join). 130 | 131 | 132 | ## Citations 133 | 134 | If you want to publish experimental results with our datasets or use the baseline models, please cite the following articles: 135 | ``` 136 | @article{moon2020situated, 137 | title={Situated and Interactive Multimodal Conversations}, 138 | author={Moon, Seungwhan and Kottur, Satwik and Crook, Paul A and De, Ankita and Poddar, Shivani and Levin, Theodore and Whitney, David and Difranco, Daniel and Beirami, Ahmad and Cho, Eunjoon and Subba, Rajen and Geramifard, Alborz}, 139 | journal={arXiv preprint arXiv:2006.01460}, 140 | year={2020} 141 | } 142 | 143 | @article{crook2019simmc, 144 | title={SIMMC: Situated Interactive Multi-Modal Conversational Data Collection And Evaluation Platform}, 145 | author={Crook, Paul A and Poddar, Shivani and De, Ankita and Shafi, Semir and Whitney, David and Geramifard, Alborz and Subba, Rajen}, 146 | journal={arXiv preprint arXiv:1911.02690}, 147 | year={2019} 148 | } 149 | ``` 150 | **NOTE**: The [paper][simmc_arxiv] above describes in detail the datasets, the NLU/NLG/Coref annotations, and some of the baselines we provide in this challenge. The paper reports the results from an earlier version of the dataset and with different train-dev-test splits, hence the baseline performances on the challenge resources will be slightly different. 151 | 152 | ## License 153 | 154 | SIMMC is released under [CC-BY-NC-SA-4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode), see [LICENSE](LICENSE) for details. 155 | 156 | 157 | [dstc9]:https://sites.google.com/dstc.community/dstc9/home 158 | [simmc_arxiv]:https://arxiv.org/abs/2006.01460 159 | [screenshot_link_0]:./data/simmc_furniture/furniture_screenshots_part_0.zip 160 | [screenshot_link_1]:./data/simmc_furniture/furniture_screenshots_part_1.zip 161 | [screenshot_link_2]:./data/simmc_furniture/furniture_screenshots_part_2.zip 162 | [span_fix_pr]:https://github.com/facebookresearch/simmc/pull/54 163 | [screenshot_pr]:https://github.com/facebookresearch/simmc/pull/60 164 | -------------------------------------------------------------------------------- /SUBMISSION_INSTRUCTIONS.md: -------------------------------------------------------------------------------- 1 | # Final Evaluation 2 | 3 | Below we describe how the participants can submit their results, and how the winner(s) will be announced. 4 | 5 | ## Evaluation Dataset 6 | 7 | Final evaluation for the SIMMC DSTC9 track will be on the `test-std` split, different from the `devtest` split. Each test instance in `test-std` contains only `K` number of rounds (not necessarily the entire dialog), where we release the user utterances from `1` to `K` rounds, and system utterances from `1` to `K-1` utterances. 
Please refer to [this table](./TASK_INPUTS.md) that lists the set of allowed inputs for each subtask. 8 | 9 | For subtask 1, evaluation is on the assistant action (API call) for `K`th round. 10 | For subtask 2, evaluation is on the assistant utterance generation for `K`th round. 11 | For subtask 3, evaluation is on dialog state prediction based on user utterances from `1` through `K`. 12 | 13 | For subtasks 1 and 2 there are 1.2K predictions (1 per dialogue). For subtask 3 there are mean(`K`) * number of dialogues predictions. 14 | 15 | We provide: 16 | 17 | * **`devtest`, in the `test-std` format**: to give participants an early heads-up on how the `test-std` dataset will look like, we re-formatted the already-released `devtest` set in the format of the `test-std` file. Please ensure that your script and model are compatible and can run on [fashion_devtest_dials_teststd_format_public.json](./data/simmc_fashion/fashion_devtest_dials_teststd_format_public.json) and [furniture_devtest_dials_teststd_format_public.json](./data/simmc_furniture/furniture_devtest_dials_teststd_format_public.json). Please note that the Evaluation Phase 1 is on the entire `devtest` set. 18 | 19 | * **`test-std`**: In the [main data folder](./data), we release the `test-std` dataset for Evalaution Phase 2: Please check out `./data/simmc_{domain}/{domain}_teststd_dials{_|_api_calls_|_retrieval_candidates_}public.json`, and report the prediction results on those following the instructions below. 20 | 21 | 22 | ## Evaluation Criteria 23 | 24 | | **Subtask** | **Evaluation** | **Metric Priority List** | 25 | | :-- | :-- | :-- | 26 | | Subtask 1 (Multimodal Assistant API Prediction) | On assistant action (API call) for `K`th round | Action Accuracy, Attribute Accuracy, Action Perplexity | 27 | | Subtask 2 (Multimodal Assistant Response Generation) | On assistant utterance generation for `K`th round | * Generative category: BLEU-4
* Retrieval category: MRR, R@1, R@5, R@10, Mean Rank | 28 | | Subtask 3 (Multimodal Dialog State Tracking) | On dialog state based on user utterances from 1 through `K` | Slot F1, Intent F1 | 29 | 30 | **Separate winners** will be announced for each subtask based on the respective performance, with the exception of subtask 2 (response generation) that will have two winners based on two categories -- generative metrics and retrieval metrics. 31 | 32 | Rules to select the winner for each subtask (and categories) are given below: 33 | 34 | * For each subtask, we enforce a **priority over the respective metrics** (shown above) to highlight the model behavior desired by this challenge 35 | 36 | * The entry with the most favorable (higher or lower) performance on the metric will be labelled as a winner candidate. Further, all other entries within one standard error of this candidate’s performance will also be considered as candidates. If there are more than one candidate according to the metric, we will move to the next metric in the priority list and repeat this process until we have a single winner candidate, which would be declared as the "**subtask winner**". 37 | 38 | * In case of multiple candidates even after running through the list of metrics in the priority order, all of them will be declared as "**joint subtask winners**". 39 | 40 | **NOTE**: Only entries that are able to open-sourced their code will be considered for the final evaluation. In all other cases, we can only give “honorable mentions” depending on the devtest performance and cannot declare them as winners of any subtask. 41 | 42 | 43 | ## Submission Format 44 | 45 | Participants must submit the model prediction results in JSON format that can be scored with the automatic scripts provided for that sub-task. 
Specifically, please name your JSON output as follows (format for subtask1 and 2 is given in the respective READMEs): 46 | 47 | ``` 48 | 49 | dstc9-simmc-teststd-{domain}-subtask-1.json 50 | 51 | 52 | dstc9-simmc-teststd-{domain}-subtask-2-generation.json 53 | dstc9-simmc-teststd-{domain}-subtask-2-retrieval.json 54 | 55 | 56 | dstc9-simmc-teststd-{domain}-subtask-3.txt (line-separated output) 57 | or 58 | dstc9-simmc-teststd-{domain}-subtask-3.json (JSON format) 59 | ``` 60 | 61 | The SIMMC organizers will then evaluate them internally using the following scripts: 62 | 63 | ``` 64 | 65 | python tools/action_evaluation.py \ 66 | --action_json_path={PATH_TO_API_CALLS} \ 67 | --model_output_path={PATH_TO_MODEL_PREDICTIONS} \ 68 | --single_round_evaluation 69 | 70 | 71 | python tools/response_evaluation.py \ 72 | --data_json_path={PATH_TO_GOLD_RESPONSES} \ 73 | --model_response_path={PATH_TO_MODEL_RESPONSES} \ 74 | --single_round_evaluation 75 | 76 | 77 | python tools/retrieval_evaluation.py \ 78 | --retrieval_json_path={PATH_TO_GROUNDTRUTH_RETRIEVAL} \ 79 | --model_score_path={PATH_TO_MODEL_CANDIDATE_SCORES} \ 80 | --single_round_evaluation 81 | 82 | 83 | (line-by-line evaluation) 84 | python -m gpt2_dst.scripts.evaluate \ 85 | --input_path_target={PATH_TO_GROUNDTRUTH_TARGET} \ 86 | --input_path_predicted={PATH_TO_MODEL_PREDICTIONS} \ 87 | --output_path_report={PATH_TO_REPORT} 88 | 89 | (Or, dialog level evaluation) 90 | python -m utils.evaluate_dst \ 91 | --input_path_target={PATH_TO_GROUNDTRUTH_TARGET} \ 92 | --input_path_predicted={PATH_TO_MODEL_PREDICTIONS} \ 93 | --output_path_report={PATH_TO_REPORT} 94 | ``` 95 | 96 | ## Submission Instructions and Timeline 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 |
| **Date** | **Who** | **Action** |
| :-- | :-- | :-- |
| Before Sept 28th 2020 | Each Team | Each participating team should create a repository, e.g. on github.com, that can be made public under a permissive open source license (MIT License preferred). The repository doesn’t need to be publicly viewable at that time. Before Sept 28th, tag a repository commit that contains both runnable code and model parameter files that are the team’s entries for all sub-tasks attempted. Tag the commit with `dstc9-simmc-entry`. Models (model parameter files) and code should have associated date-time stamps that are before Sept 27 23:59:59 anywhere on Earth. |
| Sept 28th 2020 | SIMMC Organizers | Test-Std data released (during US Pacific coast working hours). |
| Before Oct 5th 2020 | Each Team | Generate test data predictions using the code & model versions tagged previously with `dstc9-simmc-entry`. For each sub-task attempted, create a PR and check-in to the team’s repository where: (a) the PR/check-in contains an output directory with the model output in JSON format that can be scored with the automatic scripts provided for that sub-task; (b) the PR comments contain a short technical summary of the model; and (c) the commit is tagged with `dstc9-simmc-test-subtask-{N}`, where `{N}` is the sub-task number. |
| By Oct 5th 2020 | Each Team | Make the team repository public under a permissive open source license (MIT License is preferred). Email the SIMMC Organizers a link to the repository at simmc@fb.com. |
| Oct 5th - Oct 12th 2020 | SIMMC Organizers | SIMMC organizers validate the sub-task results. |
| Oct 12th 2020 | SIMMC Organizers | Publish anonymized team rankings on the SIMMC track GitHub and email each team their anonymized team identity. |
| Post Oct 12th 2020 | SIMMC Organizers | Our plan is to write up a challenge summary paper. In it, we may conduct error analysis of the results and may look to extend the submitted results, e.g. possibly with human scoring. |
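The tagging workflow above can also be scripted with plain `git`. The sketch below is illustrative only: it assumes a remote named `origin`, a hypothetical `output/` directory holding the prediction files, and uses the fashion domain / sub-task 2 file name from the Submission Format section as an example.

```
# Tag the commit that contains the final code and model parameter files (entry deadline).
git tag dstc9-simmc-entry
git push origin dstc9-simmc-entry

# Later: check in the test-std predictions for a sub-task and tag that commit.
git add output/dstc9-simmc-teststd-fashion-subtask-2-generation.json
git commit -m "Add test-std predictions for sub-task 2 (generation)"
git tag dstc9-simmc-test-subtask-2
git push origin HEAD --tags
```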
158 | -------------------------------------------------------------------------------- /TASK_INPUTS.md: -------------------------------------------------------------------------------- 1 | **Allowed Inputs** 2 | 3 | * The guideline below shows the input fields that are allowed (default) and disallowed (marked as 'X') at **inference time**, for each subtask. 4 | * Participants are free to use any of the fields below during **training** though as additional supervision signals, and *e.g.* at the inference time use the reconstructed / predicted values instead. 5 | 6 | 7 | | Key | Subtask #1
<br>(API Prediction) | Subtask #2<br>(Response Generation) | Subtask #3<br>(MM-DST) | 8 | |:---|:---:|:---:|:---:| 9 | |**JSON File (Turn Level Input Fields)**| | | | 10 | | `belief_state` | ✗ | ✗ | ✗<br>(prediction target) | 11 | | `domain` | 12 | |`state_graph_0`| ✗ | ✗ | ✗ | 13 | |`state_graph_1`| ✗ | ✗ | ✗ | 14 | |`state_graph_2`| ✗ | ✗ | ✗ | 15 | |`system_transcript`<br>(current turn) | ✗ | ✗<br>(prediction target) | ✗ | 16 | |`system_transcript`<br>(previous turns)| | | | 17 | |`system_transcript_annotated`| ✗ | ✗ | ✗ | 18 | |`system_turn_label`| ✗ | ✗ | ✗ | 19 | |`transcript`| | | | 20 | | `transcript_annotated` | ✗ | ✗ | ✗ | 21 | |`turn_idx`| | | | 22 | |`turn_label`| ✗ | ✗ | ✗ | 23 | |`visual_objects`| | | | 24 | |`raw_assistant_keystrokes`| ✗ | ✗ | ✗ | 25 | |**JSON File (Dialog Level Input Fields)**| | | | 26 | |`dialogue_coref_map`| ✗ | ✗ | ✗ | 27 | | `dialogue_idx` | 28 | | `domains` | 29 | |**API Call File**| | | | 30 | |`action`<br>(current turn)| ✗<br>(prediction target) | | ✗ | 31 | |`action`<br>(previous turns)| | | | 32 | |`action_supervision`<br>(current turn)| ✗ | | ✗ | 33 | |`action_supervision`<br>
(previous turns)| | | | 34 | |`focus_images (Fashion)`| | | | 35 | |`carousel_state (Furniture)`| | | | 36 | |`action_output_state(Furniture)`| ✗ | | ✗ | 37 | |**Metadata Files**| | | | 38 | |`fashion_metadata.json`| | | | 39 | |`furniture_metadata.csv`| | | | 40 | 41 | **Notes** 42 | 43 | `transcript_annotated` provides the detailed structural intents, slots and values for each USER turn, including the text spans. `system_transcript_annotated` provides the similar information for ASSISTANT turns. 44 | 45 | `turn_label` expands `transcript_annotated` with the coreference labels annotated as well. `objects` field in `turn_label` includes a list of objects referred to in each turn - each marked with a local index throughout the dialog (`obj_idx`) and `obj_type`. `system_turn_label` provides the similar information for ASSISTANT turns. 46 | 47 | `belief_state` provides the intents, slots, and values, where their slots and values are cumulative throughout the dialog whenever applicable. Each slot name is prepended with its domain name, e.g. `{domain}-{slot_name}`. Specifically, we include an object slot called `{domain}-O` whose values are `OBJECT_{local_idx}`. For instance, a `belief_state` with `act: DA:REQUEST:ADD_TO_CART:CLOTHING` with a slot `[[‘fashion-O’, ‘OBJECT_2’], [‘fashion-O’, ‘OBJECT_3’]]` would annotate a user belief state with the intention of adding objects 2 and 3 to the cart. 48 | 49 | The entire catalog information is stored in either `fashion_metadata.json` or `furniture_metadata.csv`. The API calls provide the state of the carousel (`furniture`) or focus item (`fashion`) after the ground truth API / actions have been called. By using these two, one should be able to retrieve the entire information about the catalog items that are potentially described in the system response. 50 | 51 | For more details, please refer to the full description in the [data README document](https://github.com/facebookresearch/simmc/tree/master/data). 52 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # SIMMC Datasets 2 | 3 | ## Summary 4 | 5 | Our challenge focuses on two SIMMC datasets, both in the shopping domain: 6 | (a) furniture (grounded in a shared virtual environment) and, 7 | (b) fashion (grounded in an evolving set of images). 8 | 9 | Both datasets were collected through the SIMMC Platform, an extension to ParlAI for multimodal conversational data collection and system evaluation that allows human annotators to each play the role of either the assistant or the user. 10 | 11 | The following papers describe in detail the dataset, the collection platform, and the NLU/NLG/Coref annotations we provide: 12 | 13 | Seungwhan Moon*, Satwik Kottur*, Paul A. Crook^, Ankita De^, Shivani Poddar^, Theodore Levin, David Whitney, Daniel Difranco, Ahmad Beirami, Eunjoon Cho, Rajen Subba, Alborz Geramifard. ["Situated and Interactive Multimodal Conversations"](https://arxiv.org/pdf/2006.01460.pdf) (2020). 14 | 15 | Paul A. Crook*, Shivani Poddar*, Ankita De, Semir Shafi, David Whitney, Alborz Geramifard, Rajen Subba. ["SIMMC: Situated Interactive Multi-Modal Conversational Data Collection And Evaluation Platform"](https://arxiv.org/pdf/1911.02690.pdf) (2020). 
16 | 17 | If you want to publish experimental results with our datasets or use the baseline models, please cite the following articles: 18 | ``` 19 | @article{moon2020situated, 20 | title={Situated and Interactive Multimodal Conversations}, 21 | author={Moon, Seungwhan and Kottur, Satwik and Crook, Paul A and De, Ankita and Poddar, Shivani and Levin, Theodore and Whitney, David and Difranco, Daniel and Beirami, Ahmad and Cho, Eunjoon and Subba, Rajen and Geramifard, Alborz}, 22 | journal={arXiv preprint arXiv:2006.01460}, 23 | year={2020} 24 | } 25 | 26 | @article{crook2019simmc, 27 | title={SIMMC: Situated Interactive Multi-Modal Conversational Data Collection And Evaluation Platform}, 28 | author={Crook, Paul A and Poddar, Shivani and De, Ankita and Shafi, Semir and Whitney, David and Geramifard, Alborz and Subba, Rajen}, 29 | journal={arXiv preprint arXiv:1911.02690}, 30 | year={2019} 31 | ``` 32 | 33 | ### Dataset Splits 34 | 35 | We randomly split each of our SIMMC-Furniture and SIMMC-Fashion datasets into four components: 36 | 37 | | **Split** | **Furniture** | **Fashion** | 38 | | :--: | :--: | :--: | 39 | | Train (60%) | 3839 | 3929 | 40 | | Dev (10%) | 640 | 655 | 41 | | Test-Dev (15%) | 960 | 982 | 42 | | Test-Std (15%) | 960 | 983 | 43 | 44 | **NOTE** 45 | * **Dev** is for hyperparameter selection and other modeling choices. 46 | * **Test-Dev** is the publicly available test set to measure model performance and report results outside the challenge. 47 | * **Test-Std** is used as the main test set for evaluation for Challenge Phase 2 (to be released on Sept 28). 48 | 49 | ## Download the Datasets 50 | We are hosting our datasets in this Github Repository (with [Git LFS](https://git-lfs.github.com/)). 51 | First, install Git LFS 52 | ``` 53 | $ git lfs install 54 | ``` 55 | 56 | Clone our repository to download both the dataset and the code: 57 | ``` 58 | $ git clone https://github.com/facebookresearch/simmc.git 59 | ``` 60 | 61 | ## Overview of the Dataset Repository 62 | 63 | The data are made available for each `domain` (`simmc_furniture` | `simmc_fashion`) in the following files: 64 | ``` 65 | [Main Data] 66 | - full dialogs: ./{domain}/{train|dev|devtest|test}_dials.json 67 | - list of dialog IDs per split: ./{domain}/{train|dev|devtest|test}_dialog_ids 68 | 69 | [Metadata] 70 | - Fashion metadta: ./simmc_fashion/fashion_metadata.json 71 | - Furniture metadata: ./simmc_furniture/furniture_metadata.csv 72 | - images: ./simmc-furniture/figures/{object_id}.png 73 | ``` 74 | **NOTE**: The test set will be made available after DSTC9. 75 | 76 | ## Data Format 77 | 78 | For each `{train|dev|devtest}` split, the JSON data (`./{domain}/{train|dev|devtest}_dials.json` 79 | ) is formatted as follows: 80 | 81 | 82 | ``` 83 | { 84 | "split": support.extract_split_from_filename(json_path), 85 | "version": 1.0, 86 | "year": 2020, 87 | "domain": FLAGS.domain, 88 | "dialogue_data": [ 89 | { 90 | “dialogue”: [ 91 | { 92 | “belief_state”: [ 93 | { 94 | “act”: , 95 | “slots”: [ 96 | [ slot_name, slot_value ], // end of a slot name-value pair 97 | ... 98 | ] 99 | }, // end of an act-slot pair 100 | ... 101 | ], 102 | “domain”: , 103 | “raw_assistant_keystrokes”: , 104 | “state_graph_{idx}”: , 105 | “syste_belief_state”: , 106 | “system_transcript”: , 107 | “system_transcript_annotated”: , 108 | “transcript”: , 109 | “transcript_annotated”: , 110 | “turn_idx”: , 111 | “turn_label”: [ ], 112 | “visual_objects”: 113 | }, // end of a turn (always sorted by turn_idx) 114 | ... 
115 | ], 116 | “dialogue_coref_map”: { 117 | // map from object_id to local_id re-indexed for each dialog 118 | : 119 | }, 120 | “dialogue_idx”: , 121 | “domains”: [ ] 122 | } 123 | ] 124 | } 125 | ``` 126 | The data can be processed with respective data readers / preprocessing scripts for each sub-task (please refer to the respective README documents). Each sub-task will describe which fields can be used as input. 127 | 128 | **NOTES** 129 | 130 | `transcript_annotated` provides the detailed structural intents, slots and values for each USER turn, including the text spans. `system_transcript_annotated` provides the similar information for ASSISTANT turns. 131 | 132 | `turn_label` expands `transcript_annotated` with the coreference labels annotated as well. `objects` field in `turn_label` includes a list of objects referred to in each turn - each marked with a local index throughout the dialog (`obj_idx`) and `obj_type`. `system_turn_label` provides the similar information for ASSISTANT turns. 133 | 134 | `belief_state` provides the intents, slots, and values, where their slots and values are cumulative throughout the dialog whenever applicable. Each slot name is prepended with its domain name, e.g. `{domain}-{slot_name}`. Specifically, we include an object slot called `{domain}-O` whose values are `OBJECT_{local_idx}`. For instance, a `belief_state` with `act: DA:REQUEST:ADD_TO_CART:CLOTHING` with a slot `[[‘fashion-O’, ‘OBJECT_2’], [‘fashion-O’, ‘OBJECT_3’]]` would annotate a user belief state with the intention of adding objects 2 and 3 to the cart. 135 | 136 | `visual_objects` refer to the list of objects and their visual attributes that are shown to the user at each given turn (via a VR environment or an image). 137 | ``` 138 | { 139 | obj_name: { 140 | attribute_name: or attribute_values 141 | } 142 | } 143 | ``` 144 | 145 | `state_graph_{idx}` refers to the graph representation of the cumulative dialog and the multimodal contexts known to the user, each at a different phase during the dialog (e.g. via a multimodal action of showing items, an assistant providing information, a user providing preferences, etc.). 146 | - state_graph_0: initial state before the start of the user utterance 147 | - state_graph_1: state modified after the user utterance 148 | - state_graph_2: final state modified after the assistant utterance & assistant action. 149 | 150 | Participants may use this information for inspection, or as additional training signals for some of the sub-tasks (but not at inference time). `belief_state`, `system_beilef_state`, and `visual_objects` provide the same information. 151 | 152 | Each state graph is represented as follows: 153 | ``` 154 | { 155 | obj_name: { 156 | attribute_name: or attribute_values 157 | } 158 | } 159 | ``` 160 | 161 | `raw_assistant_keystrokes` are the raw UI interactions made by the human Assistant (wizard) using the Unity interface during data collection. We distil target actions for the action prediction task (sub-task #1) from these raw keystrokes and NLU/NLG annotation 162 | 163 | We also release the metadata for each object referred in the dialog data: 164 | ``` 165 | 166 | { 167 | object_id: { 168 | “metadata”: {dict}, 169 | “url”: source image 170 | }, // end of an object 171 | } 172 | 173 | 174 | columns: 175 | - product_name 176 | - product_description 177 | - product_thumbnail_image_url 178 | - material 179 | - color 180 | - obj ({object_id}.zip) 181 | ... 
182 | ``` 183 | Attributes for each object either pulled from the original sources or annotated manually. Note that some of the catalog-specific attributes (e.g. availableSizes, brand, etc.) were randomly and synthetically generated. 184 | 185 | Each item in a catalog metadata has a unique ` object_id`. `dialog_coref_map` defines the mapping from the `local_idx` (local to each dialog), to its canonical `object_id` reference, for each dialog. This `local_idx` is used in `belief_state` as an object slot. For example, given a `dialog_coref_map = {0: 123, 1: 234, 2: 345}` -- the belief state: `{‘act’: ‘DA:REQUEST:ADD_TO_CART’, ‘slots’: [‘O’: ‘OBJECT_2’]}` would indicate this particular dialog act performed upon `OBJECT_2` (`2 == local_idx`), which has a canonical reference to an object with `object_id: 345`. We are including this information in case you want to refer to the additional information provided in the `metadata.{json|csv}` file. 186 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_dev_dials.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a561e0bb90475bf84ca69f21f3d7c66b04853ac92eb36112704d9fd9895e718b 3 | size 10390760 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_dev_dials_retrieval_candidates.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d9c03dc2f1f0dc9f3309502b232c141876827aa2dd88a629105d1dd1cf7e8a95 3 | size 2429407 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_devtest_dials.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:18a9a2bb2fefb8eb64f5cca705a107b4135e0c1bcba7687446fd8678449ec335 3 | size 16193739 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_devtest_dials_api_calls_teststd_format_private.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:39a1cb2e4b0c973603ff09deab01e2513205d5005b207838100b236f195e46c2 3 | size 283653 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_devtest_dials_api_calls_teststd_format_public.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f3b0446ff012ecedd77731fb4c698b1f6e2d6642e37f7e2aa623e7ea56c66c2f 3 | size 375695 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_devtest_dials_retrieval_candidates.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5698e5c7c42ed3d4d90fa90eae4520adb62a3352af7ea1e37bb209c49604a50b 3 | size 3795628 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_devtest_dials_retrieval_candidates_teststd_format_private.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:4cc04ea10b4c48d3836e13ce0d02de9773bea4450ad78abbcae9bb28aa5e452f 3 | size 2476135 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_devtest_dials_retrieval_candidates_teststd_format_public.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4cc04ea10b4c48d3836e13ce0d02de9773bea4450ad78abbcae9bb28aa5e452f 3 | size 2476135 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_devtest_dials_teststd_format_public.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:dda4c8bc91823d8a5b81b0f2999713052da51fce69e2edc881e94f73b91277c6 3 | size 3579702 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_metadata.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:56831bbdc79ccc351ed6231b6ead64c9788c925a036131b7caab00e688285207 3 | size 1633149 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_teststd_dials_api_calls.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d53fe269ec10204d828095abf5365478374913998e96ca491eb6b31c4c15a3a0 3 | size 283588 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_teststd_dials_public.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7d095d00c832156a795defc041ccd2d83b78bf30f00f66f0a6a6e1e287ee547b 3 | size 3529872 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_teststd_dials_retrieval_candidates_public.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e0e684f5b2678e9cd88401efde162f66038b7ed9f21b4f133fe6fae027f75803 3 | size 2456265 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_train_dials.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:efde37cb1f00009e5ed63d70d5576d8ae5d7e1a5c922e1031f20e2aa9911a723 3 | size 62887336 4 | -------------------------------------------------------------------------------- /data/simmc_fashion/fashion_train_dials_retrieval_candidates.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a41841ffe583347198dd8f2784d61da22e19c7df4d2e08d1b9debdb2f3535738 3 | size 16193506 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_dev_dials.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fc8d2cb184d7599221b1dff0ee8da16f26c0c0b8bf1d2451cfcfe1db4634ef5c 3 | size 19192164 4 | -------------------------------------------------------------------------------- 
/data/simmc_furniture/furniture_dev_dials_retrieval_candidates.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:131b8174a03d8538861904846a9815fd4aa3e5ecdd24b73149f06b4e2e5c5d61 3 | size 3350259 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_devtest_dials.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:22d46eb16d2b476be1b318dd78c00ada335dd8cef897fcb4bdd6fc4ab8ea1ae5 3 | size 30443325 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_devtest_dials_api_calls_teststd_format_private.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bda19380611dac455e3a80e53e040a4f219e45c5971bab53fa886b18550edd3e 3 | size 7420590 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_devtest_dials_api_calls_teststd_format_public.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:98039fd0e199affb59e930922e734792bc8444ca1f3a10d4036130610c5de9ae 3 | size 9793941 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_devtest_dials_retrieval_candidates.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:b2ed1c55b743c599cef5e0282f5b229f0035d20ccebf63eb7533ef4c717b1982 3 | size 5163613 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_devtest_dials_retrieval_candidates_teststd_format_private.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e28243b04f261533bfc825e39353d3389c8f282fcfda70938178e642debc9462 3 | size 3005836 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_devtest_dials_retrieval_candidates_teststd_format_public.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e28243b04f261533bfc825e39353d3389c8f282fcfda70938178e642debc9462 3 | size 3005836 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_devtest_dials_teststd_format_public.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:32082d5576147c37df8ae3e9830795bfbbb4ef16b6db05fa339ed03f8c461042 3 | size 3295308 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_metadata.csv: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:9a04d021520169425da0f7cb655091e5e825df0e9c22cc49987de66ad65787ac 3 | size 162471 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_screenshot_map.json: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:02464dc904eebadced867316cb0b22827038fbe6e1e63845d2a0226a1f3f25ac 3 | size 8815128 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_screenshots_part_0.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:593f5dcc347704cbcfdf56492d1c9914580dc16c3f23ae1a0760287534954d78 3 | size 1828315793 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_screenshots_part_1.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:0b97eff5bfc4a43e586c846cbf0ad2e4b36437f3dbb04461a904022e7c622f10 3 | size 1833003830 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_screenshots_part_2.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2e4ac61bd68841f258df921fb00477224889690915e9fc8d658b395c11e7cda8 3 | size 1824970766 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_screenshots_part_3.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:fae518466622c8556fe4d39987b14bf6188e2f9efc66fbd830170d2f43f6e7f3 3 | size 1825655585 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_screenshots_part_4.zip: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:98edd6c6a5e9972c7dbc6b71e75e11d4dd2578a8675e86538bf1e75d423f42b9 3 | size 1252560255 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_teststd_dials_api_calls.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6b81252fe09ca22462cd259e4238e9fe8cac3c5109e4c74cbdc9a40de08c4f98 3 | size 7731156 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_teststd_dials_public.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4e2dfbd7e4884cc051ce97adf708b42c5aba86b7e866ca87f7857b6413859313 3 | size 3414744 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_teststd_dials_retrieval_candidates_public.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4b781f8db72265188c4bca9134807df9b41a6fe91cabeb5140961a81008b3259 3 | size 3146285 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_train_dials.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c392fabd7d3008cc137cbde02e2efeb82f219914fbb6f27e3a394000b6f41a40 3 | size 
122377679 4 | -------------------------------------------------------------------------------- /data/simmc_furniture/furniture_train_dials_retrieval_candidates.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:55985dbeade284a040487ca833bc158585efeea09657ea06ec5080bf26c65f5f 3 | size 22596588 4 | -------------------------------------------------------------------------------- /figures/simmc_dstc9_results_summary.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/simmc/27474218ee927e757ee55eec1094813dc84acb16/figures/simmc_dstc9_results_summary.png -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/simmc/27474218ee927e757ee55eec1094813dc84acb16/figures/teaser.png -------------------------------------------------------------------------------- /mm_action_prediction/eval_simmc_agent.py: -------------------------------------------------------------------------------- 1 | """Evaluate SIMMC agent for Furniture and Fashion datasets. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import argparse 9 | import json 10 | import math 11 | import torch 12 | from tqdm import tqdm as progressbar 13 | 14 | import loaders 15 | import models 16 | from tools import support 17 | 18 | 19 | def main(args): 20 | """Evaluate model and save the results. 21 | """ 22 | # Read the checkpoint and train args. 23 | print("Loading checkpoint: {}".format(args["checkpoint"])) 24 | checkpoint = torch.load(args["checkpoint"], map_location=torch.device("cpu")) 25 | saved_args = checkpoint["args"] 26 | saved_args.update(args) 27 | # Save model outputs for teststd. 28 | if "teststd" in args["eval_data_path"].rsplit("/", 1)[1]: 29 | saved_args["save_model_output"] = True 30 | support.pretty_print_dict(saved_args) 31 | 32 | # Dataloader for evaluation. 33 | dataloader_args = { 34 | "single_pass": True, 35 | "shuffle": False, 36 | "data_read_path": args["eval_data_path"], 37 | "get_retrieval_candidates": True 38 | } 39 | dataloader_args.update(saved_args) 40 | val_loader = loaders.DataloaderSIMMC(dataloader_args) 41 | saved_args.update(val_loader.get_data_related_arguments()) 42 | 43 | # Model. 44 | wizard = models.Assistant(saved_args) 45 | # Load the checkpoint. 46 | wizard.load_state_dict(checkpoint["model_state"]) 47 | 48 | # Evaluate the SIMMC model. 49 | eval_dict, eval_outputs = evaluate_agent(wizard, val_loader, saved_args) 50 | save_path = saved_args["checkpoint"].replace(".tar", "_eval.json") 51 | print("Saving results: {}".format(save_path)) 52 | with open(save_path, "w") as file_id: 53 | json.dump(eval_dict, file_id) 54 | 55 | 56 | def evaluate_agent(wizard, val_loader, args): 57 | """Evaluate a SIMMC agent given a dataloader. 58 | 59 | Args: 60 | wizard: SIMMC model 61 | dataloader: Dataloader to use to run the model on 62 | args: Arguments for evaluation 63 | """ 64 | total_iters = int(val_loader.num_instances / args["batch_size"]) 65 | # Turn autograd off for evaluation -- light-weight and faster. 
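# The loop below makes a single pass over the eval split: teacher-forced token losses feed the perplexity computation, beam search (next_token="ARGMAX", beam_size=5) optionally feeds BLEU, and per-candidate scores feed the retrieval metrics; all batch outputs are accumulated in `matches` and aggregated after the loop.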
66 | with torch.no_grad(): 67 | wizard.eval() 68 | matches = [] 69 | for batch in progressbar(val_loader.get_batch(), total=int(total_iters)): 70 | if args["bleu_evaluation"]: 71 | mode = {"next_token": "ARGMAX", "beam_size": 5} 72 | else: 73 | mode = None 74 | batch_outputs = wizard(batch, mode) 75 | # Stringify model responses. 76 | if args["bleu_evaluation"]: 77 | batch_outputs["model_response"] = ( 78 | val_loader.stringify_beam_outputs( 79 | batch_outputs["beam_output"], batch 80 | ) 81 | ) 82 | # Remove beam output to avoid memory issues. 83 | del batch_outputs["beam_output"] 84 | matches.append(batch_outputs) 85 | wizard.train() 86 | 87 | # Compute perplexity. 88 | total_loss_sum = sum(ii["loss_sum"].item() for ii in matches) 89 | num_tokens = sum(ii["num_tokens"].item() for ii in matches) 90 | avg_loss_eval = total_loss_sum / num_tokens 91 | 92 | # Compute BLEU score. 93 | model_responses = None 94 | bleu_score = -1. 95 | if args["bleu_evaluation"]: 96 | model_responses = [jj for ii in matches for jj in ii["model_response"]] 97 | # Save the JSON file. 98 | if args.get("save_model_output", False): 99 | save_path = args["checkpoint"].replace(".tar", "_response_gen.json") 100 | with open(save_path, "w") as file_id: 101 | json.dump(model_responses, file_id) 102 | else: 103 | bleu_score = val_loader.evaluate_response_generation(model_responses) 104 | 105 | # Evaluate retrieval score. 106 | retrieval_metrics = {} 107 | if args["retrieval_evaluation"]: 108 | candidate_scores = [jj for ii in matches for jj in ii["candidate_scores"]] 109 | # Save the JSON file. 110 | if args.get("save_model_output", False): 111 | save_path = args["checkpoint"].replace(".tar", "_response_ret.json") 112 | with open(save_path, "w") as file_id: 113 | json.dump(candidate_scores, file_id) 114 | else: 115 | retrieval_metrics = val_loader.evaluate_response_retrieval( 116 | candidate_scores 117 | ) 118 | print(retrieval_metrics) 119 | 120 | # Evaluate action prediction. 121 | action_predictions = [jj for ii in matches for jj in ii["action_preds"]] 122 | # Save the JSON file. 123 | if args.get("save_model_output", False): 124 | save_path = args["checkpoint"].replace(".tar", "_action_gen.json") 125 | with open(save_path, "w") as file_id: 126 | json.dump(action_predictions, file_id) 127 | action_metrics = val_loader.evaluate_action_prediction(action_predictions) 128 | print(action_metrics["confusion_matrix"]) 129 | print_str = ( 130 | "\nEvaluation\n\tLoss: {:.2f}\n\t" 131 | "Perplexity: {:.2f}\n\tBLEU: {:.3f}\n\t" 132 | "Action: {:.2f}\n\t" 133 | "Action Perplexity: {:.2f}\n\t" 134 | "Action Attribute Accuracy: {:.2f}" 135 | ) 136 | print( 137 | print_str.format( 138 | avg_loss_eval, 139 | math.exp(avg_loss_eval), 140 | bleu_score, 141 | 100 * action_metrics["action_accuracy"], 142 | action_metrics["action_perplexity"], 143 | 100 * action_metrics["attribute_accuracy"] 144 | ) 145 | ) 146 | # Save the results to a file. 147 | eval_dict = { 148 | "loss": avg_loss_eval, 149 | "perplexity": math.exp(avg_loss_eval), 150 | "bleu": bleu_score, 151 | "action_accuracy": action_metrics["action_accuracy"], 152 | "action_perplexity": action_metrics["action_perplexity"], 153 | "action_attribute": action_metrics["attribute_accuracy"] 154 | } 155 | eval_dict.update(retrieval_metrics) 156 | eval_outputs = { 157 | "model_actions": action_predictions, 158 | "model_responses": model_responses 159 | } 160 | return eval_dict, eval_outputs 161 | 162 | 163 | if __name__ == "__main__": 164 | # Read command line options. 
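# Example invocation (illustrative only; the checkpoint path is hypothetical, the .npy file is produced by scripts/preprocess_simmc.sh):
#   python eval_simmc_agent.py \
#       --checkpoint checkpoints/furniture_hae/epoch_best.tar \
#       --eval_data_path ../data/simmc_furniture/furniture_devtest_dials_mm_inputs.npy \
#       --gpu_id 0 --batch_size 10 --domain furniture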
165 | parser = argparse.ArgumentParser() 166 | parser.add_argument("--checkpoint", required=True, help="Checkpoint to load") 167 | parser.add_argument("--batch_size", default=10, type=int, help="Batch size") 168 | parser.add_argument( 169 | "--eval_data_path", required=True, help="Evaluation data split" 170 | ) 171 | parser.add_argument("--gpu_id", type=int, default=-1) 172 | parser.add_argument( 173 | "--skip_bleu_evaluation", 174 | dest="bleu_evaluation", 175 | action="store_false", 176 | default=True, 177 | help="Use beamsearch to compute BLEU score when evaluation" 178 | ) 179 | parser.add_argument( 180 | "--skip_retrieval_evaluation", 181 | dest="retrieval_evaluation", 182 | action="store_false", 183 | default=True, 184 | help="Evaluation response generation through retrieval" 185 | ) 186 | parser.add_argument( 187 | "--domain", 188 | default=None, 189 | choices=["furniture", "fashion"], 190 | help="Domain to train the model on", 191 | ) 192 | try: 193 | args = vars(parser.parse_args()) 194 | except (IOError) as msg: 195 | parser.error(str(msg)) 196 | # Setup CUDA environment. 197 | args["use_gpu"] = support.setup_cuda_environment(args["gpu_id"]) 198 | 199 | main(args) 200 | -------------------------------------------------------------------------------- /mm_action_prediction/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | from .loader_vocabulary import Vocabulary 5 | from .loader_base import LoaderParent 6 | from .loader_simmc import DataloaderSIMMC 7 | 8 | 9 | __all__ = [ 10 | "Vocabulary", 11 | "LoaderParent", 12 | "DataloaderSIMMC", 13 | ] 14 | -------------------------------------------------------------------------------- /mm_action_prediction/loaders/loader_base.py: -------------------------------------------------------------------------------- 1 | """Parent class for data loaders. 2 | 3 | Author: Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import threading 9 | import queue 10 | 11 | import numpy as np 12 | import torch 13 | 14 | 15 | class LoaderParent: 16 | def __init__(self): 17 | """Class constructor. 18 | """ 19 | # Assert the presence of mandatory attributes to setup prefetch daemon. 20 | mandatory_attrs = ["single_pass", "shuffle", "use_gpu"] 21 | assert hasattr(self, "params"), "Params is mandatory attribute!" 22 | for attr in mandatory_attrs: 23 | assert attr in self.params, "{0} is mandatory!".format(attr) 24 | self.params["prefetch_num"] = self.params.get("prefetch_num", 1) 25 | self._setup_prefetching() 26 | 27 | def load_one_batch(self, sample_ids): 28 | """Load one batch given the sample indices. 29 | 30 | Args: 31 | sample_ids: Ids of the instances -- either train/val. 32 | 33 | Return: 34 | batch: Dictionary of features for the instance with sample_ids. 35 | """ 36 | raise NotImplementedError 37 | 38 | def _setup_prefetching(self): 39 | """Prefetches batches to save time. 40 | """ 41 | # Setup and start prefetching daemon. 42 | self._prefetch_queue = queue.Queue(maxsize=self.params["prefetch_num"]) 43 | self._prefetch_thread = threading.Thread(target=self._run_prefetch) 44 | self._prefetch_thread.daemon = True 45 | self._prefetch_thread.start() 46 | 47 | def get_batch(self): 48 | """Batch generator depending on train/eval mode. 
49 | """ 50 | while True: 51 | # Get a batch from the prefetching queue 52 | if self._prefetch_queue.empty(): 53 | pass 54 | # print('DataLoader: Waiting for data loading (IO is slow)...') 55 | batch = self._prefetch_queue.get(block=True) 56 | if batch is None: 57 | assert self.params["single_pass"], "Mode set to one pass!" 58 | return 59 | yield batch 60 | 61 | def _run_prefetch(self): 62 | batch_size = self.params["batch_size"] 63 | fetch_order = np.arange(self.num_instances) 64 | n_sample = 0 65 | while True: 66 | # Shuffle the sample order for every epoch. 67 | if n_sample == 0 and self.params["shuffle"]: 68 | fetch_order = np.random.permutation(self.num_instances) 69 | # Load batch from file 70 | # note that len(sample_ids) <= batch_size, not necessarily equal. 71 | sample_ids = fetch_order[n_sample : n_sample + batch_size] 72 | batch = self.load_one_batch(sample_ids) 73 | self._prefetch_queue.put(batch, block=True) 74 | n_sample += len(sample_ids) 75 | if n_sample >= self.num_instances: 76 | # Put in a None batch to indicate a whole pass is over. 77 | if self.params["single_pass"]: 78 | self._prefetch_queue.put(None, block=True) 79 | n_sample = 0 80 | 81 | def _ship_torch_batch(self, batch): 82 | """Ship a batch in PyTorch. 83 | 84 | Useful for cross-package dataloader. 85 | 86 | Args: 87 | batch: Dictionary of the batch. 88 | 89 | Returns: 90 | Batch members changed in place to torch Tensors (with GPU, if needed) 91 | """ 92 | for key, value in batch.items(): 93 | # Check if numpy array or list of numpy arrays. 94 | if isinstance(value, np.ndarray): 95 | batch[key] = self._ship_helper(value) 96 | elif isinstance(value, list) and isinstance(value[0], np.ndarray): 97 | for index, element in enumerate(value): 98 | batch[key][index] = self._ship_helper(element) 99 | return batch 100 | 101 | def _ship_helper(self, numpy_array): 102 | """Helper to ship numpy arrays to torch. 103 | """ 104 | # int32 get mapped to int64 and float to double 105 | if numpy_array.dtype == np.int32 or numpy_array.dtype == np.int64: 106 | new_type = torch.int64 107 | elif numpy_array.dtype == bool: 108 | new_type = torch.bool 109 | else: 110 | new_type = torch.float 111 | torch_tensor = torch.tensor(numpy_array, dtype=new_type) 112 | if self.params["use_gpu"]: 113 | torch_tensor = torch_tensor.cuda() 114 | return torch_tensor 115 | 116 | def compute_idf_features(self): 117 | """Computes idf scores based on train set. 118 | """ 119 | # Should not be invoked if mandatory fields are absent. 120 | mandatory_fields = [ 121 | "user_sent", 122 | "user_sent_len", 123 | "user_utt_id", 124 | "assist_sent", 125 | "assist_sent_len", 126 | "assist_utt_id", 127 | ] 128 | for field in mandatory_fields: 129 | assert field in self.raw_data, "{} missing!".format(field) 130 | # Get document frequency of words for both user / assistant utterances. 
131 | IDF = np.ones(self.vocab_size) 132 | num_inst, max_len = self.raw_data["user_utt_id"].shape 133 | for _, dialog_utt in enumerate(self.raw_data["user_utt_id"]): 134 | for _, utt_id in enumerate(dialog_utt): 135 | if utt_id == -1: 136 | break 137 | utt_len = self.raw_data["user_sent_len"][utt_id] 138 | utterance = self.raw_data["user_sent"][utt_id, :utt_len] 139 | IDF[np.unique(utterance)] += 1 140 | for _, dialog_utt in enumerate(self.raw_data["assist_utt_id"]): 141 | for _, utt_id in enumerate(dialog_utt): 142 | if utt_id == -1: 143 | break 144 | utt_len = self.raw_data["assist_sent_len"][utt_id] 145 | utterance = self.raw_data["assist_sent"][utt_id, :utt_len] 146 | IDF[np.unique(utterance)] += 1 147 | # Count both user and assistant utterances for the normalizer. 148 | num_utterances = (self.raw_data["user_utt_id"] != -1).sum() + ( 149 | self.raw_data["assist_utt_id"] != -1 150 | ).sum() 151 | self.IDF = np.log(num_utterances / IDF) 152 | 153 | def compute_tf_features(self, utterances, utterance_lens): 154 | """Compute TF features for either train/val/test set. 155 | 156 | Args: 157 | utterances: Utterance token ids to compute TF features over 158 | utterance_lens: Length of the utterances 159 | 160 | Returns: 161 | tf_idf_features: tf_idf features 162 | """ 163 | assert hasattr(self, "IDF"), "IDF has not been computed/loaded!" 164 | batch_size, num_rounds, max_len = utterances.shape 165 | num_utterances = batch_size * num_rounds 166 | utterances = utterances.reshape(-1, max_len) 167 | utterance_lens = utterance_lens.reshape(-1) 168 | tf_features = np.zeros((num_utterances, self.vocab_size)) 169 | for utt_id, utterance in enumerate(utterances): 170 | tokens = utterance[: utterance_lens[utt_id]] 171 | for tt in tokens: 172 | tf_features[utt_id, tt] += 1.0 / utterance_lens[utt_id] 173 | return tf_features.reshape(batch_size, num_rounds, -1) 174 | 175 | def get_data_related_arguments(self): 176 | """Get data related arguments like vocab_size, etc. 177 | 178 | Complete list: Vocab size, pad_token, start_token, end_token, 179 | num_actions, asset_feature_size (if exists). 180 | 181 | Returns: 182 | related_args: Dictionary containing the above arguments 183 | """ 184 | related_args = { 185 | "vocab_size": self.vocab_size, 186 | "pad_token": self.pad_token, 187 | "start_token": self.start_token, 188 | "end_token": self.end_token, 189 | "num_actions": self.num_actions, 190 | } 191 | if self.params["encoder"] == "pretrained_transformer": 192 | related_args["vocab_path"] = self.raw_data["vocabulary"] 193 | related_args["vocab_size"] += len(self.words.added_tokens_encoder) 194 | if hasattr(self, "asset_feature_size"): 195 | related_args["asset_feature_size"] = self.asset_feature_size 196 | return related_args 197 | 198 | @staticmethod 199 | def numpy(batch_torch): 200 | """Convert a batch into numpy arrays. 201 | 202 | Args: 203 | batch_torch: A batch with torch tensors 204 | 205 | Returns: 206 | batch_numpy: batch_torch with all tensors moved to numpy 207 | """ 208 | batch_numpy = {} 209 | for key, value in batch_torch.items(): 210 | # Check if numpy array or list of numpy arrays. 211 | if isinstance(value, torch.Tensor): 212 | batch_numpy[key] = value.cpu().numpy() 213 | elif isinstance(value, list) and isinstance(value[0], torch.Tensor): 214 | batch_numpy[key] = [None] * len(batch_torch[key]) 215 | for index, element in enumerate(value): 216 | batch_numpy[key][index] = element.cpu().numpy() 217 | else: 218 | batch_numpy[key] = value 219 | return batch_numpy 220 | 221 | @property 222 | def num_instances(self): 223 | """Number of instances in the dataloader.
223 | """ 224 | raise NotImplementedError 225 | -------------------------------------------------------------------------------- /mm_action_prediction/loaders/loader_vocabulary.py: -------------------------------------------------------------------------------- 1 | """Loads vocabulary and performs additional text processing. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | #!/usr/bin/python3 6 | 7 | from __future__ import absolute_import, division, print_function, unicode_literals 8 | 9 | import copy 10 | 11 | 12 | class Vocabulary: 13 | def __init__(self, vocabulary_path=None, immutable=False, verbose=True): 14 | """Initialize the vocabulary object given a path, else empty object. 15 | 16 | Args: 17 | vocabulary_path: List of words in a text file, one in each line. 18 | immutable: Once initialized, no new words can be added. 19 | """ 20 | self.immutable = immutable 21 | self.verbose = verbose 22 | # Read file else create empty object. 23 | if vocabulary_path is not None: 24 | if verbose: 25 | print("Reading vocabulary: {0}...".format(vocabulary_path), end="") 26 | with open(vocabulary_path, "r") as file_id: 27 | self._words = [ii.strip() for ii in file_id.readlines()] 28 | if verbose: 29 | print("done") 30 | # Setup rest of the object. 31 | self._setup_vocabulary() 32 | else: 33 | if verbose: 34 | print("Initializing empty vocabulary object..") 35 | 36 | def __contains__(self, key): 37 | """Check if a word is contained in a vocabulary. 38 | """ 39 | return key in self._words 40 | 41 | def _setup_vocabulary(self): 42 | """Sets up internal dictionaries. 43 | """ 44 | # Check whether ,, and are part of the word list. 45 | # Else add them. 46 | for special_word in ["", "", "", ""]: 47 | if special_word not in self._words: 48 | if not self.immutable: 49 | self._words.append(special_word) 50 | if self.verbose: 51 | print("Adding new word to vocabulary: {}".format(special_word)) 52 | else: 53 | if self.verbose: 54 | print("Immutable, cannot add missing {}".format(special_word)) 55 | # Create word_index and word_string dictionaries. 56 | self.word_index = {word: index for index, word in enumerate(self._words)} 57 | self.word_string = {index: word for word, index in self.word_index.items()} 58 | if self.verbose: 59 | print("Vocabulary size updated: {0}".format(len(self.word_index))) 60 | 61 | def add_new_word(self, *new_words): 62 | """Adds new words to an existing vocabulary object. 63 | 64 | Args: 65 | *new_words: List of new word(s) to be added. 66 | """ 67 | raise NotImplementedError 68 | 69 | def word(self, index): 70 | """Returns the word given the index. 71 | 72 | Args: 73 | index: Index of the word 74 | 75 | Returns: 76 | Word string for the given index. 77 | """ 78 | assert index in self.word_string, "{0} missing in vocabulary!".format(index) 79 | return self.word_string[index] 80 | 81 | def index(self, word, unk_default=False): 82 | """Returns the index given the word. 83 | 84 | Args: 85 | word: Word string. 86 | 87 | Returns: 88 | Index for the given word string. 89 | """ 90 | if not unk_default: 91 | assert word in self.word_index, "{0} missing in vocabulary!".format(word) 92 | return self.word_index[word] 93 | else: 94 | return self.word_index.get(word, self.word_index[""]) 95 | 96 | def set_vocabulary_state(self, state): 97 | """Given a state (list of words), setup the vocabulary object state. 
98 | 99 | Args: 100 | state: List of words 101 | """ 102 | self._words = copy.deepcopy(state) 103 | self._setup_vocabulary() 104 | 105 | def get_vocabulary_state(self): 106 | """Returns the vocabulary state (deepcopy). 107 | 108 | Returns: 109 | Deepcopy of list of words. 110 | """ 111 | return copy.deepcopy(self._words) 112 | 113 | def get_tensor_string(self, tensor): 114 | """Converts a tensor into a string after decoding it using vocabulary. 115 | """ 116 | pad_token = self.index("<pad>") 117 | string = " ".join( 118 | [self.word(int(ii)) for ii in tensor.squeeze() if ii != pad_token] 119 | ) 120 | return string 121 | 122 | @property 123 | def vocab_size(self): 124 | return len(self._words) 125 | -------------------------------------------------------------------------------- /mm_action_prediction/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function, unicode_literals 2 | 3 | 4 | from .assistant import Assistant 5 | from . import encoders 6 | from .decoder import GenerativeDecoder 7 | from .action_executor import ActionExecutor 8 | from .positional_encoding import PositionalEncoding 9 | from .self_attention import SelfAttention 10 | from .carousel_embedder import CarouselEmbedder 11 | from .user_memory_embedder import UserMemoryEmbedder 12 | 13 | 14 | __all__ = [ 15 | "Assistant", 16 | "GenerativeDecoder", 17 | "ActionExecutor", 18 | "PositionalEncoding", 19 | "SelfAttention", 20 | "CarouselEmbedder", 21 | "UserMemoryEmbedder", 22 | "encoders" 23 | ] 24 | -------------------------------------------------------------------------------- /mm_action_prediction/models/assistant.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Assistant Model for Furniture Genie. 3 | 4 | Author(s): Satwik Kottur 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from tools import weight_init, torch_support 11 | import models 12 | import models.encoders as encoders 13 | 14 | 15 | class Assistant(nn.Module): 16 | """SIMMC Assistant Agent. 17 | """ 18 | 19 | def __init__(self, params): 20 | super(Assistant, self).__init__() 21 | self.params = params 22 | 23 | self.encoder = encoders.ENCODER_REGISTRY[params["encoder"]](params) 24 | self.decoder = models.GenerativeDecoder(params) 25 | 26 | if params["encoder"] == "pretrained_transformer": 27 | self.decoder.word_embed_net = ( 28 | self.encoder.models.decoder.bert.embeddings.word_embeddings 29 | ) 30 | self.decoder.decoder_unit = self.encoder.models.decoder 31 | 32 | # Learn to predict and execute actions. 33 | self.action_executor = models.ActionExecutor(params) 34 | self.criterion = nn.CrossEntropyLoss(reduction="none") 35 | 36 | # Initialize weights. 37 | weight_init.weight_init(self) 38 | if params["use_gpu"]: 39 | self = self.cuda() 40 | # Sharing word embeddings across encoder and decoder. 41 | if self.params["share_embeddings"]: 42 | if hasattr(self.encoder, "word_embed_net") and hasattr( 43 | self.decoder, "word_embed_net" 44 | ): 45 | self.decoder.word_embed_net = self.encoder.word_embed_net 46 | 47 | def forward(self, batch, mode=None): 48 | """Forward propagation. 49 | 50 | Args: 51 | batch: Dict of batch input variables.
52 | mode: None for training or teaching forcing evaluation; 53 | BEAMSEARCH / SAMPLE / MAX to generate text 54 | """ 55 | outputs = self.encoder(batch) 56 | action_output = self.action_executor(batch, outputs) 57 | outputs.update(action_output) 58 | decoder_output = self.decoder(batch, outputs) 59 | if mode: 60 | generation_output = self.decoder.forward_beamsearch_multiple( 61 | batch, outputs, mode 62 | ) 63 | outputs.update(generation_output) 64 | 65 | # If evaluating by retrieval, construct fake batch for each candidate. 66 | # Inputs from batch used in decoder: 67 | # assist_in, assist_out, assist_in_len, assist_mask 68 | if self.params["retrieval_evaluation"] and not self.training: 69 | option_scores = [] 70 | batch_size, num_rounds, num_candidates, _ = batch["candidate_in"].shape 71 | replace_keys = ("assist_in", "assist_out", "assist_in_len", "assist_mask") 72 | for ii in range(num_candidates): 73 | for key in replace_keys: 74 | new_key = key.replace("assist", "candidate") 75 | batch[key] = batch[new_key][:, :, ii] 76 | decoder_output = self.decoder(batch, outputs) 77 | log_probs = torch_support.unflatten( 78 | decoder_output["loss_token"], batch_size, num_rounds 79 | ) 80 | option_scores.append(-1 * log_probs.sum(-1)) 81 | option_scores = torch.stack(option_scores, 2) 82 | outputs["candidate_scores"] = [ 83 | { 84 | "dialog_id": batch["dialog_id"][ii].item(), 85 | "candidate_scores": [ 86 | { 87 | "scores": [ 88 | float(kk) for kk in option_scores[ii, jj].cpu() 89 | ], 90 | "turn_id": jj 91 | } 92 | for jj in range(batch["dialog_len"][ii]) 93 | ] 94 | } 95 | for ii in range(batch_size) 96 | ] 97 | 98 | # Local aliases. 99 | loss_token = decoder_output["loss_token"] 100 | pad_mask = decoder_output["pad_mask"] 101 | if self.training: 102 | loss_token = loss_token.sum() / (~pad_mask).sum().item() 103 | loss_action = action_output["action_loss"] 104 | loss_action_attr = action_output["action_attr_loss"] 105 | loss_total = loss_action + loss_token + loss_action_attr 106 | return { 107 | "token": loss_token, 108 | "action": loss_action, 109 | "action_attr": loss_action_attr, 110 | "total": loss_total, 111 | } 112 | else: 113 | outputs.update( 114 | {"loss_sum": loss_token.sum(), "num_tokens": (~pad_mask).sum()} 115 | ) 116 | return outputs 117 | -------------------------------------------------------------------------------- /mm_action_prediction/models/carousel_embedder.py: -------------------------------------------------------------------------------- 1 | """Embedding carousel for action predictions. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class CarouselEmbedder(nn.Module): 13 | def __init__(self, params): 14 | super(CarouselEmbedder, self).__init__() 15 | self.params = params 16 | self.host = torch.cuda if params["use_gpu"] else torch 17 | self.positions = ["left", "center", "right", "focus", "empty"] 18 | self.occupancy_states = {} 19 | self.carousel_pos = {} 20 | for position in self.positions: 21 | pos_parameter = torch.randn(params["word_embed_size"]) 22 | if params["use_gpu"]: 23 | pos_parameter = pos_parameter.cuda() 24 | pos_parameter = nn.Parameter(pos_parameter) 25 | self.carousel_pos[position] = pos_parameter 26 | # Register the parameter for training/saving. 27 | self.register_parameter(position, pos_parameter) 28 | 29 | # Project carousel embedding to same size as encoder. 
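# Each carousel slot is an asset embedding concatenated with a learned position embedding (left/center/right/focus/empty); the linear layer below projects this to the encoder state size (hidden_size for the LSTM text encoder, word_embed_size otherwise) so the encoder state can attend over the slots with multi-head attention.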
30 | input_size = params["asset_feature_size"] + params["word_embed_size"] 31 | if params["text_encoder"] == "lstm": 32 | output_size = params["hidden_size"] 33 | else: 34 | output_size = params["word_embed_size"] 35 | self.carousel_embed_net = nn.Linear(input_size, output_size) 36 | self.carousel_attend = nn.MultiheadAttention(output_size, 1) 37 | self.carousel_mask = self._generate_carousel_mask(3) 38 | 39 | def forward(self, carousel_state, encoder_state, encoder_size): 40 | """Carousel Embedding. 41 | 42 | Args: 43 | carousel_state: State of the carousel 44 | encoder_state: State of the encoder 45 | encoder_size: (batch_size, num_rounds) 46 | 47 | Returns: 48 | new_encoder_state: 49 | """ 50 | if len(self.occupancy_states) == 0: 51 | self._setup_occupancy_states() 52 | 53 | batch_size, num_rounds = encoder_size 54 | carousel_states = [] 55 | carousel_sizes = [] 56 | for inst_id in range(batch_size): 57 | for round_id in range(num_rounds): 58 | round_datum = carousel_state[inst_id][round_id] 59 | if round_datum is None: 60 | carousel_features = self.none_features 61 | carousel_sizes.append(1) 62 | elif "focus" in round_datum: 63 | carousel_features = torch.cat( 64 | [round_datum["focus"], self.carousel_pos["focus"]] 65 | ).unsqueeze(0) 66 | carousel_features = torch.cat( 67 | [carousel_features, self.empty_feature, self.empty_feature], 68 | dim=0, 69 | ) 70 | carousel_sizes.append(1) 71 | elif "carousel" in round_datum: 72 | carousel_size = len(round_datum["carousel"]) 73 | if carousel_size < 3: 74 | all_embeds = torch.cat( 75 | [round_datum["carousel"]] 76 | + self.occupancy_embeds[carousel_size], 77 | dim=0, 78 | ) 79 | else: 80 | all_embeds = round_datum["carousel"] 81 | all_states = self.occupancy_states[carousel_size] 82 | carousel_features = torch.cat([all_embeds, all_states], -1) 83 | carousel_sizes.append(carousel_size) 84 | # Project into same feature shape. 85 | carousel_features = self.carousel_embed_net(carousel_features) 86 | carousel_states.append(carousel_features) 87 | # Shape: (L,N,E) 88 | carousel_states = torch.stack(carousel_states, dim=1) 89 | # Mask: (N,S) 90 | carousel_len = self.host.LongTensor(carousel_sizes) 91 | query = encoder_state.unsqueeze(0) 92 | attended_query, attented_wts = self.carousel_attend( 93 | query, 94 | carousel_states, 95 | carousel_states, 96 | key_padding_mask=self.carousel_mask[carousel_len - 1], 97 | ) 98 | carousel_encode = torch.cat([attended_query.squeeze(0), encoder_state], dim=-1) 99 | return carousel_encode 100 | 101 | def empty_carousel(self, carousel_state): 102 | """Check if carousel is empty in the standard representation. 103 | 104 | Args: 105 | carousel_state: Carousel state 106 | 107 | Returns: 108 | empty_carousel: Boolean (True -- empty, False -- not empty) 109 | """ 110 | return carousel_state == {"focus": None, "carousel": []} 111 | 112 | def _generate_carousel_mask(self, size): 113 | """Generates square masks for transformers to avoid peeking. 114 | """ 115 | mask = (torch.triu(torch.ones(size, size)) == 0).transpose(0, 1) 116 | if self.params["use_gpu"]: 117 | mask = mask.cuda() 118 | return mask 119 | 120 | def _setup_occupancy_states(self): 121 | """Setup carousel states and embeddings for different occupancy levels. 
122 | """ 123 | self.occupancy_states = {} 124 | self.occupancy_embeds = {} 125 | self.zero_tensor = self.host.FloatTensor(self.params["asset_feature_size"]) 126 | self.zero_tensor.fill_(0.0) 127 | for num_items in range(4): 128 | states = [self.carousel_pos[ii] for ii in self.positions[:num_items]] 129 | states += [self.carousel_pos["empty"] for ii in range(3 - num_items)] 130 | states = torch.stack(states, dim=0) 131 | self.occupancy_states[num_items] = states 132 | embeds = [self.zero_tensor for _ in range(3 - num_items)] 133 | if len(embeds): 134 | embeds = [torch.stack(embeds, dim=0)] 135 | self.occupancy_embeds[num_items] = embeds 136 | self.empty_feature = torch.cat( 137 | [self.zero_tensor, self.carousel_pos["empty"]], dim=-1 138 | ).unsqueeze(0) 139 | self.none_features = self.empty_feature.expand(3, -1) 140 | -------------------------------------------------------------------------------- /mm_action_prediction/models/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | # Class decorator to register encoders. 2 | from __future__ import absolute_import, division, print_function, unicode_literals 3 | 4 | 5 | ENCODER_REGISTRY = {} 6 | 7 | 8 | def register_encoder(encoder_name): 9 | """Register the class with the name. 10 | """ 11 | 12 | def register_encoder_class(encoder_class): 13 | if encoder_name in ENCODER_REGISTRY: 14 | raise ValueError("Cant register {0} again!".format(encoder_name)) 15 | ENCODER_REGISTRY[encoder_name] = encoder_class 16 | return encoder_class 17 | 18 | return register_encoder_class 19 | 20 | 21 | from .history_agnostic import HistoryAgnosticEncoder 22 | from .hierarchical_recurrent import HierarchicalRecurrentEncoder 23 | from .memory_network import MemoryNetworkEncoder 24 | from .tf_idf_encoder import TFIDFEncoder 25 | 26 | __all__ = [ 27 | "HistoryAgnosticEncoder", 28 | "HierarchicalRecurrentEncoder", 29 | "MemoryNetworkEncoder", 30 | "TFIDFEncoder" 31 | ] 32 | -------------------------------------------------------------------------------- /mm_action_prediction/models/encoders/hierarchical_recurrent.py: -------------------------------------------------------------------------------- 1 | """Implements hierarchical recurrent neural network encoder. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import torch.nn as nn 9 | 10 | from tools import rnn_support as rnn 11 | from tools import torch_support as support 12 | import models.encoders as encoders 13 | 14 | 15 | @encoders.register_encoder("hierarchical_recurrent") 16 | class HierarchicalRecurrentEncoder(nn.Module): 17 | def __init__(self, params): 18 | super(HierarchicalRecurrentEncoder, self).__init__() 19 | self.params = params 20 | 21 | self.word_embed_net = nn.Embedding( 22 | params["vocab_size"], params["word_embed_size"] 23 | ) 24 | encoder_input_size = params["word_embed_size"] 25 | self.encoder_unit = nn.LSTM( 26 | encoder_input_size, 27 | params["hidden_size"], 28 | params["num_layers"], 29 | batch_first=True, 30 | ) 31 | self.dialog_unit = nn.LSTM( 32 | params["hidden_size"], 33 | params["hidden_size"], 34 | params["num_layers"], 35 | batch_first=True, 36 | ) 37 | 38 | def forward(self, batch): 39 | """Forward pass through the encoder. 40 | 41 | Args: 42 | batch: Dict of batch variables. 43 | 44 | Returns: 45 | encoder_outputs: Dict of outputs from the forward pass. 46 | """ 47 | encoder_out = {} 48 | # Flatten to encode sentences. 
49 | batch_size, num_rounds, _ = batch["user_utt"].shape 50 | encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds) 51 | encoder_len = batch["user_utt_len"].reshape(-1) 52 | word_embeds_enc = self.word_embed_net(encoder_in) 53 | 54 | # Fake encoder_len to be non-zero even for utterances out of dialog. 55 | fake_encoder_len = encoder_len.eq(0).long() + encoder_len 56 | all_enc_states, enc_states = rnn.dynamic_rnn( 57 | self.encoder_unit, word_embeds_enc, fake_encoder_len, return_states=True 58 | ) 59 | encoder_out["hidden_states_all"] = all_enc_states 60 | encoder_out["hidden_state"] = enc_states 61 | 62 | utterance_enc = enc_states[0][-1] 63 | new_size = (batch_size, num_rounds, utterance_enc.shape[-1]) 64 | utterance_enc = utterance_enc.reshape(new_size) 65 | encoder_out["dialog_context"], _ = rnn.dynamic_rnn( 66 | self.dialog_unit, utterance_enc, batch["dialog_len"], return_states=True 67 | ) 68 | return encoder_out 69 | -------------------------------------------------------------------------------- /mm_action_prediction/models/encoders/history_agnostic.py: -------------------------------------------------------------------------------- 1 | """Implements seq2seq encoder that is history-agnostic. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import torch.nn as nn 9 | 10 | from tools import rnn_support as rnn 11 | from tools import torch_support as support 12 | import models 13 | import models.encoders as encoders 14 | 15 | 16 | @encoders.register_encoder("history_agnostic") 17 | class HistoryAgnosticEncoder(nn.Module): 18 | def __init__(self, params): 19 | super(HistoryAgnosticEncoder, self).__init__() 20 | self.params = params 21 | 22 | self.word_embed_net = nn.Embedding( 23 | params["vocab_size"], params["word_embed_size"] 24 | ) 25 | encoder_input_size = params["word_embed_size"] 26 | if params["text_encoder"] == "transformer": 27 | layer = nn.TransformerEncoderLayer( 28 | params["word_embed_size"], 29 | params["num_heads_transformer"], 30 | params["hidden_size_transformer"], 31 | ) 32 | self.encoder_unit = nn.TransformerEncoder( 33 | layer, params["num_layers_transformer"] 34 | ) 35 | self.pos_encoder = models.PositionalEncoding(params["word_embed_size"]) 36 | elif params["text_encoder"] == "lstm": 37 | self.encoder_unit = nn.LSTM( 38 | encoder_input_size, 39 | params["hidden_size"], 40 | params["num_layers"], 41 | batch_first=True, 42 | ) 43 | else: 44 | raise NotImplementedError("Text encoder must be transformer or LSTM!") 45 | 46 | def forward(self, batch): 47 | """Forward pass through the encoder. 48 | 49 | Args: 50 | batch: Dict of batch variables. 51 | 52 | Returns: 53 | encoder_outputs: Dict of outputs from the forward pass. 54 | """ 55 | encoder_out = {} 56 | # Flatten for history_agnostic encoder. 57 | batch_size, num_rounds, max_length = batch["user_utt"].shape 58 | encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds) 59 | encoder_len = support.flatten(batch["user_utt_len"], batch_size, num_rounds) 60 | word_embeds_enc = self.word_embed_net(encoder_in) 61 | # Text encoder: LSTM or Transformer. 
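# LSTM branch: dynamic_rnn returns per-token states plus the final hidden state. Transformer branch: positional encodings are added, a key padding mask is built from the pad token, and only per-token states are returned.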
62 | if self.params["text_encoder"] == "lstm": 63 | all_enc_states, enc_states = rnn.dynamic_rnn( 64 | self.encoder_unit, word_embeds_enc, encoder_len, return_states=True 65 | ) 66 | encoder_out["hidden_states_all"] = all_enc_states 67 | encoder_out["hidden_state"] = enc_states 68 | 69 | elif self.params["text_encoder"] == "transformer": 70 | enc_embeds = self.pos_encoder(word_embeds_enc).transpose(0, 1) 71 | enc_pad_mask = batch["user_utt"] == batch["pad_token"] 72 | enc_pad_mask = support.flatten(enc_pad_mask, batch_size, num_rounds) 73 | enc_states = self.encoder_unit( 74 | enc_embeds, src_key_padding_mask=enc_pad_mask 75 | ) 76 | encoder_out["hidden_states_all"] = enc_states.transpose(0, 1) 77 | return encoder_out 78 | -------------------------------------------------------------------------------- /mm_action_prediction/models/encoders/memory_network.py: -------------------------------------------------------------------------------- 1 | """Implements memory network encoder. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | 12 | from tools import rnn_support as rnn 13 | from tools import torch_support as support 14 | import models.encoders as encoders 15 | 16 | 17 | @encoders.register_encoder("memory_network") 18 | class MemoryNetworkEncoder(nn.Module): 19 | def __init__(self, params): 20 | super(MemoryNetworkEncoder, self).__init__() 21 | self.params = params 22 | 23 | self.word_embed_net = nn.Embedding( 24 | params["vocab_size"], params["word_embed_size"] 25 | ) 26 | encoder_input_size = params["word_embed_size"] 27 | self.encoder_unit = nn.LSTM( 28 | encoder_input_size, 29 | params["hidden_size"], 30 | params["num_layers"], 31 | batch_first=True, 32 | ) 33 | self.fact_unit = nn.LSTM( 34 | params["word_embed_size"], 35 | params["hidden_size"], 36 | params["num_layers"], 37 | batch_first=True, 38 | ) 39 | 40 | self.softmax = nn.functional.softmax 41 | self.fact_attention_net = nn.Sequential( 42 | nn.Linear(2 * params["hidden_size"], params["hidden_size"]), 43 | nn.ReLU(), 44 | nn.Linear(params["hidden_size"], 1), 45 | ) 46 | 47 | def forward(self, batch): 48 | """Forward pass through the encoder. 49 | 50 | Args: 51 | batch: Dict of batch variables. 52 | 53 | Returns: 54 | encoder_outputs: Dict of outputs from the forward pass. 55 | """ 56 | encoder_out = {} 57 | # Flatten to encode sentences. 58 | batch_size, num_rounds, _ = batch["user_utt"].shape 59 | encoder_in = support.flatten(batch["user_utt"], batch_size, num_rounds) 60 | encoder_len = batch["user_utt_len"].reshape(-1) 61 | word_embeds_enc = self.word_embed_net(encoder_in) 62 | 63 | # Fake encoder_len to be non-zero even for utterances out of dialog. 64 | fake_encoder_len = encoder_len.eq(0).long() + encoder_len 65 | all_enc_states, enc_states = rnn.dynamic_rnn( 66 | self.encoder_unit, word_embeds_enc, fake_encoder_len, return_states=True 67 | ) 68 | encoder_out["hidden_states_all"] = all_enc_states 69 | encoder_out["hidden_state"] = enc_states 70 | 71 | utterance_enc = enc_states[0][-1] 72 | batch["utterance_enc"] = support.unflatten( 73 | utterance_enc, batch_size, num_rounds 74 | ) 75 | encoder_out["dialog_context"] = self._memory_net_forward(batch) 76 | return encoder_out 77 | 78 | def _memory_net_forward(self, batch): 79 | """Forward pass for memory network to look up fact. 80 | 81 | 1. Encodes fact via fact rnn. 82 | 2. 
Computes attention with fact and utterance encoding. 83 | 3. Attended fact vector and question encoding -> new encoding. 84 | 85 | Args: 86 | batch: Dict of hist, hist_len, hidden_state 87 | """ 88 | batch_size, num_rounds, enc_time_steps = batch["fact"].shape 89 | all_ones = np.full((num_rounds, num_rounds), 1) 90 | fact_mask = np.triu(all_ones, 1) 91 | fact_mask = np.expand_dims(np.expand_dims(fact_mask, -1), 0) 92 | fact_mask = torch.BoolTensor(fact_mask) 93 | if self.params["use_gpu"]: 94 | fact_mask = fact_mask.cuda() 95 | fact_mask.requires_grad_(False) 96 | 97 | fact_in = support.flatten(batch["fact"], batch_size, num_rounds) 98 | fact_len = support.flatten(batch["fact_len"], batch_size, num_rounds) 99 | fact_embeds = self.word_embed_net(fact_in) 100 | # Encoder fact and unflatten the last hidden state. 101 | _, (hidden_state, _) = rnn.dynamic_rnn( 102 | self.fact_unit, fact_embeds, fact_len, return_states=True 103 | ) 104 | fact_encode = support.unflatten(hidden_state[-1], batch_size, num_rounds) 105 | fact_encode = fact_encode.unsqueeze(1).expand(-1, num_rounds, -1, -1) 106 | 107 | utterance_enc = batch["utterance_enc"].unsqueeze(2) 108 | utterance_enc = utterance_enc.expand(-1, -1, num_rounds, -1) 109 | # Combine, compute attention, mask, and weight the fact encodings. 110 | combined_encode = torch.cat([utterance_enc, fact_encode], dim=-1) 111 | attention = self.fact_attention_net(combined_encode) 112 | attention.masked_fill_(fact_mask, float("-Inf")) 113 | attention = self.softmax(attention, dim=2) 114 | attended_fact = (attention * fact_encode).sum(2) 115 | return attended_fact 116 | -------------------------------------------------------------------------------- /mm_action_prediction/models/encoders/tf_idf_encoder.py: -------------------------------------------------------------------------------- 1 | """Implements TF-IDF based encoder that is history-agnostic. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | import models.encoders as encoders 12 | 13 | 14 | @encoders.register_encoder("tf_idf") 15 | class TFIDFEncoder(nn.Module): 16 | def __init__(self, params): 17 | super(TFIDFEncoder, self).__init__() 18 | self.params = params 19 | self.IDF = nn.Parameter(torch.randn(params["vocab_size"])) 20 | self.encoder_net = nn.Sequential( 21 | nn.Linear(params["vocab_size"], params["vocab_size"] // 2), 22 | nn.ReLU(), 23 | nn.Linear(params["vocab_size"] // 2, params["vocab_size"] // 4), 24 | nn.ReLU(), 25 | nn.Linear(params["vocab_size"] // 4, params["hidden_size"]), 26 | nn.ReLU(), 27 | nn.Linear(params["hidden_size"], params["hidden_size"]), 28 | ) 29 | 30 | def forward(self, batch): 31 | """Forward pass through the encoder. 32 | 33 | Args: 34 | batch: Dict of batch variables. 35 | 36 | Returns: 37 | encoder_outputs: Dict of outputs from the forward pass. 
38 | """ 39 | encoder_embed = self.encoder_net(batch["user_tf_idf"] * self.IDF) 40 | batch_size, num_rounds, feat_size = encoder_embed.shape 41 | encoder_embed = encoder_embed.view(1, -1, feat_size) 42 | return {"hidden_state": (encoder_embed, encoder_embed)} 43 | -------------------------------------------------------------------------------- /mm_action_prediction/models/fashion_model_metainfo.json: -------------------------------------------------------------------------------- 1 | { 2 | "actions": [ 3 | { 4 | "name": "SearchMemory", 5 | "id": 0, 6 | "attributes": ["attributes"] 7 | }, 8 | { 9 | "name": "SearchDatabase", 10 | "id": 1, 11 | "attributes": ["attributes"] 12 | }, 13 | { 14 | "name": "SpecifyInfo", 15 | "id": 2, 16 | "attributes": ["attributes"] 17 | }, 18 | { 19 | "name": "AddToCart", 20 | "id": 3, 21 | "attributes": [] 22 | }, 23 | { 24 | "name": "None", 25 | "id": 4, 26 | "attributes": [] 27 | } 28 | ] 29 | } 30 | -------------------------------------------------------------------------------- /mm_action_prediction/models/furniture_model_metainfo.json: -------------------------------------------------------------------------------- 1 | { 2 | "actions": [ 3 | { 4 | "id": 0, 5 | "name": "SearchFurniture", 6 | "attributes": ["color", "furnitureType"] 7 | }, 8 | { 9 | "id": 1, 10 | "name": "SpecifyInfo", 11 | "attributes": ["matches"] 12 | }, 13 | { 14 | "id": 2, 15 | "name": "FocusOnFurniture", 16 | "attributes": ["position"] 17 | }, 18 | { 19 | "id": 3, 20 | "name": "Rotate", 21 | "attributes": ["direction"] 22 | }, 23 | { 24 | "id": 4, 25 | "name": "NavigateCarousel", 26 | "attributes": ["navigate_direction"] 27 | }, 28 | { 29 | "id": 5, 30 | "name": "AddToCart", 31 | "attributes": [] 32 | }, 33 | { 34 | "id": 6, 35 | "name": "None", 36 | "attributes": [] 37 | } 38 | ] 39 | } 40 | -------------------------------------------------------------------------------- /mm_action_prediction/models/positional_encoding.py: -------------------------------------------------------------------------------- 1 | """Positional Encoding class for Transfomers. 2 | 3 | Author(s): Satwik Kottur 4 | Adapted from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html 5 | """ 6 | 7 | from __future__ import absolute_import, division, print_function, unicode_literals 8 | 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | 13 | 14 | class PositionalEncoding(nn.Module): 15 | def __init__(self, d_model, dropout=0.1, max_len=100): 16 | super(PositionalEncoding, self).__init__() 17 | self.dropout = nn.Dropout(p=dropout) 18 | 19 | pe = torch.zeros(max_len, d_model) 20 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 21 | div_term = torch.exp( 22 | torch.arange(0, d_model, 2).float() * (-(math.log(10000.0)) / d_model) 23 | ) 24 | pe[:, 0::2] = torch.sin(position * div_term) 25 | pe[:, 1::2] = torch.cos(position * div_term) 26 | pe = pe.unsqueeze(0) 27 | self.register_buffer("pe", pe) 28 | 29 | def forward(self, x): 30 | """Adds positional encoding to the input. 31 | 32 | Args: 33 | x: Input of size Batch_shape x N_steps x Embed_size 34 | """ 35 | x = x + self.pe[:, : x.size(1), :] 36 | return self.dropout(x) 37 | -------------------------------------------------------------------------------- /mm_action_prediction/models/self_attention.py: -------------------------------------------------------------------------------- 1 | """Self attention network block. 
2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | 12 | class SelfAttention(nn.Module): 13 | def __init__(self, feature_size): 14 | super(SelfAttention, self).__init__() 15 | self.feature_size = feature_size 16 | self.att_wt = nn.Parameter(torch.randn(1, 1, feature_size)) 17 | 18 | def forward(self, feature_block, mask=False): 19 | """Self attends a feature block. 20 | 21 | Args: 22 | feature_block: Input of size Batch_shape x N_steps x Embed_size 23 | mask: Boolean mask to ignore the feature_block (B x N) 24 | 25 | Returns: 26 | att_features: Self attended features from the feature block (B X E) 27 | """ 28 | # Compute attention scores. 29 | batch_size = feature_block.shape[0] 30 | new_size = (batch_size, 1, self.feature_size) 31 | att_logits = torch.bmm( 32 | feature_block, self.att_wt.expand(new_size).transpose(1, 2) 33 | ) 34 | if mask is not None: 35 | att_logits.masked_fill_(mask.unsqueeze(-1), float("-inf")) 36 | att_wts = nn.functional.softmax(att_logits, dim=1) 37 | att_features = (att_wts * feature_block).sum(1) 38 | return att_features 39 | -------------------------------------------------------------------------------- /mm_action_prediction/models/user_memory_embedder.py: -------------------------------------------------------------------------------- 1 | """Embedding user memory for action prediction for fashion. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import torch 9 | import torch.nn as nn 10 | from tools import torch_support as support 11 | 12 | 13 | class UserMemoryEmbedder(nn.Module): 14 | def __init__(self, params): 15 | super(UserMemoryEmbedder, self).__init__() 16 | self.params = params 17 | self.host = torch.cuda if params["use_gpu"] else torch 18 | self.categories = ["focus", "database", "memory"] 19 | self.category_embeds = {} 20 | for position in self.categories: 21 | pos_parameter = torch.randn(params["word_embed_size"]) 22 | if params["use_gpu"]: 23 | pos_parameter = pos_parameter.cuda() 24 | pos_parameter = nn.Parameter(pos_parameter) 25 | self.category_embeds[position] = pos_parameter 26 | # Register the parameter for training/saving. 27 | self.register_parameter(position, pos_parameter) 28 | self.category_state = None 29 | # Project multimodal embedding to same size as encoder. 30 | input_size = params["asset_feature_size"] + params["word_embed_size"] 31 | if params["text_encoder"] == "lstm": 32 | output_size = params["hidden_size"] 33 | else: 34 | output_size = params["word_embed_size"] 35 | self.multimodal_embed_net = nn.Linear(input_size, output_size) 36 | self.multimodal_attend = nn.MultiheadAttention(output_size, 1) 37 | 38 | def forward(self, multimodal_state, encoder_state, encoder_size): 39 | """Multimodal Embedding. 40 | 41 | Args: 42 | multimodal_state: Dict with memory, database, and focus images 43 | encoder_state: State of the encoder 44 | encoder_size: (batch_size, num_rounds) 45 | 46 | Returns: 47 | multimodal_encode: Encoder state with multimodal information 48 | """ 49 | # Setup category states if None. 50 | if self.category_state is None: 51 | self._setup_category_states() 52 | # Attend to multimodal memory using encoder states. 
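# The focus image and the memory images are each concatenated with a learned category embedding, projected by multimodal_embed_net, and used as keys/values of the multi-head attention, with the flattened encoder state as the query; the attended vector is concatenated back onto the encoder state.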
53 | batch_size, num_rounds = encoder_size 54 | memory_images = multimodal_state["memory_images"] 55 | memory_images = memory_images.unsqueeze(1).expand(-1, num_rounds, -1, -1) 56 | focus_images = multimodal_state["focus_images"][:, :num_rounds, :] 57 | focus_images = focus_images.unsqueeze(2) 58 | all_images = torch.cat([focus_images, memory_images], dim=2) 59 | all_images_flat = support.flatten(all_images, batch_size, num_rounds) 60 | category_state = self.category_state.expand(batch_size * num_rounds, -1, -1) 61 | cat_images = torch.cat([all_images_flat, category_state], dim=-1) 62 | multimodal_memory = self.multimodal_embed_net(cat_images) 63 | # Key (L, N, E), value (L, N, E), query (S, N, E) 64 | multimodal_memory = multimodal_memory.transpose(0, 1) 65 | query = encoder_state.unsqueeze(0) 66 | attended_query, attented_wts = self.multimodal_attend( 67 | query, multimodal_memory, multimodal_memory 68 | ) 69 | multimodal_encode = torch.cat( 70 | [attended_query.squeeze(0), encoder_state], dim=-1 71 | ) 72 | return multimodal_encode 73 | 74 | def _setup_category_states(self): 75 | """Setup category states (focus + memory images). 76 | """ 77 | # NOTE: Assumes three memory images; make it adaptive later. 78 | self.category_state = torch.stack( 79 | [ 80 | self.category_embeds["focus"], 81 | self.category_embeds["memory"], 82 | self.category_embeds["memory"], 83 | self.category_embeds["memory"], 84 | ], 85 | dim=0, 86 | ).unsqueeze(0) 87 | -------------------------------------------------------------------------------- /mm_action_prediction/options.py: -------------------------------------------------------------------------------- 1 | """Script to read command line flags using ArgParser. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | 7 | from __future__ import absolute_import, division, print_function, unicode_literals 8 | 9 | import argparse 10 | import torch 11 | from tools import support 12 | 13 | 14 | def read_command_line(): 15 | """Read and parse commandline arguments to run the program. 16 | 17 | Returns: 18 | parsed_args: Dictionary of parsed arguments. 19 | """ 20 | title = "Train assistant model for furniture genie" 21 | parser = argparse.ArgumentParser(description=title) 22 | 23 | # Data input settings. 24 | parser.add_argument( 25 | "--train_data_path", required=True, help="Path to compiled training data" 26 | ) 27 | parser.add_argument( 28 | "--eval_data_path", default=None, help="Path to compiled evaluation data" 29 | ) 30 | parser.add_argument( 31 | "--snapshot_path", default="checkpoints/", help="Path to save checkpoints" 32 | ) 33 | parser.add_argument( 34 | "--metainfo_path", 35 | default="data/furniture_metainfo.json", 36 | help="Path to file containing metainfo", 37 | ) 38 | parser.add_argument( 39 | "--attr_vocab_path", 40 | default="data/attr_vocab_file.json", 41 | help="Path to attribute vocabulary file", 42 | ) 43 | parser.add_argument( 44 | "--domain", 45 | required=True, 46 | choices=["furniture", "fashion"], 47 | help="Domain to train the model on", 48 | ) 49 | # Asset embedding. 50 | parser.add_argument( 51 | "--asset_embed_path", 52 | default="data/furniture_asset_path.npy", 53 | help="Path to asset embeddings", 54 | ) 55 | # Specify encoder/decoder flags. 56 | # Model hyperparameters. 
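# --encoder picks the dialog-level encoder registered in models/encoders (e.g., history_agnostic, hierarchical_recurrent, memory_network, tf_idf), while --text_encoder picks the utterance-level unit (lstm or transformer) used inside it.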
57 | parser.add_argument( 58 | "--encoder", 59 | required=True, 60 | choices=[ 61 | "history_agnostic", 62 | "history_aware", 63 | "pretrained_transformer", 64 | "hierarchical_recurrent", 65 | "memory_network", 66 | "tf_idf", 67 | ], 68 | help="Encoder type to use for text", 69 | ) 70 | parser.add_argument( 71 | "--text_encoder", 72 | required=True, 73 | choices=["lstm", "transformer"], 74 | help="Encoder type to use for text", 75 | ) 76 | parser.add_argument( 77 | "--word_embed_size", default=128, type=int, help="size of embedding for text" 78 | ) 79 | parser.add_argument( 80 | "--hidden_size", 81 | default=128, 82 | type=int, 83 | help=( 84 | "Size of hidden state in LSTM/transformer." 85 | "Must be same as word_embed_size for transformer" 86 | ), 87 | ) 88 | # Parameters for transformer text encoder. 89 | parser.add_argument( 90 | "--num_heads_transformer", 91 | default=-1, 92 | type=int, 93 | help="Number of heads in the transformer", 94 | ) 95 | parser.add_argument( 96 | "--num_layers_transformer", 97 | default=-1, 98 | type=int, 99 | help="Number of layers in the transformer", 100 | ) 101 | parser.add_argument( 102 | "--hidden_size_transformer", 103 | default=2048, 104 | type=int, 105 | help="Hidden Size within transformer", 106 | ) 107 | parser.add_argument( 108 | "--num_layers", default=1, type=int, help="Number of layers in LSTM" 109 | ) 110 | parser.add_argument( 111 | "--use_action_attention", 112 | dest="use_action_attention", 113 | action="store_true", 114 | default=False, 115 | help="Use attention over all encoder statesfor action", 116 | ) 117 | parser.add_argument( 118 | "--use_action_output", 119 | dest="use_action_output", 120 | action="store_true", 121 | default=False, 122 | help="Model output of actions as decoder memory elements", 123 | ) 124 | parser.add_argument( 125 | "--use_multimodal_state", 126 | dest="use_multimodal_state", 127 | action="store_true", 128 | default=False, 129 | help="Use multimodal state for action prediction (fashion)", 130 | ) 131 | parser.add_argument( 132 | "--use_bahdanau_attention", 133 | dest="use_bahdanau_attention", 134 | action="store_true", 135 | default=False, 136 | help="Use bahdanau attention for decoder LSTM", 137 | ) 138 | parser.add_argument( 139 | "--skip_retrieval_evaluation", 140 | dest="retrieval_evaluation", 141 | action="store_false", 142 | default=True, 143 | help="Evaluation response generation through retrieval" 144 | ) 145 | parser.add_argument( 146 | "--skip_bleu_evaluation", 147 | dest="bleu_evaluation", 148 | action="store_false", 149 | default=True, 150 | help="Use beamsearch to evaluate BLEU score" 151 | ) 152 | parser.add_argument( 153 | "--max_encoder_len", 154 | default=24, 155 | type=int, 156 | help="Maximum encoding length for sentences", 157 | ) 158 | parser.add_argument( 159 | "--max_history_len", 160 | default=100, 161 | type=int, 162 | help="Maximum encoding length for history encoding", 163 | ) 164 | parser.add_argument( 165 | "--max_decoder_len", 166 | default=26, 167 | type=int, 168 | help="Maximum decoding length for sentences", 169 | ) 170 | parser.add_argument( 171 | "--max_rounds", 172 | default=30, 173 | type=int, 174 | help="Maximum number of rounds for the dialog", 175 | ) 176 | parser.add_argument( 177 | "--share_embeddings", 178 | dest="share_embeddings", 179 | action="store_true", 180 | default=True, 181 | help="Encoder/decoder share emebddings", 182 | ) 183 | 184 | # Optimization hyperparameters. 
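# Note: the consistency checks after parsing require hidden_size == word_embed_size for transformer text encoders and a non-empty --eval_data_path when --save_prudently is set.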
185 | parser.add_argument( 186 | "--batch_size", 187 | default=30, 188 | type=int, 189 | help="Training batch size (adjust based on GPU memory)", 190 | ) 191 | parser.add_argument( 192 | "--learning_rate", default=1e-3, type=float, help="Learning rate for training" 193 | ) 194 | parser.add_argument("--dropout", default=0.2, type=float, help="Dropout") 195 | parser.add_argument( 196 | "--num_epochs", 197 | default=20, 198 | type=int, 199 | help="Maximum number of epochs to run training", 200 | ) 201 | parser.add_argument( 202 | "--eval_every_epoch", 203 | default=1, 204 | type=int, 205 | help="Number of epochs to evaluate every", 206 | ) 207 | parser.add_argument( 208 | "--save_every_epoch", 209 | default=-1, 210 | type=int, 211 | help="Epochs to save the model every, -1 does not save", 212 | ) 213 | parser.add_argument( 214 | "--save_prudently", 215 | dest="save_prudently", 216 | action="store_true", 217 | default=False, 218 | help="Save checkpoints prudently (only best models)", 219 | ) 220 | parser.add_argument( 221 | "--gpu_id", type=int, default=-1, help="GPU id to use, -1 for CPU" 222 | ) 223 | try: 224 | parsed_args = vars(parser.parse_args()) 225 | except (IOError) as msg: 226 | parser.error(str(msg)) 227 | 228 | # For transformers, hidden size must be same as word_embed_size. 229 | if parsed_args["text_encoder"] == "transformer": 230 | assert ( 231 | parsed_args["word_embed_size"] == parsed_args["hidden_size"] 232 | ), "hidden_size should be same as word_embed_size for transformer" 233 | if not parsed_args["use_bahdanau_attention"]: 234 | print("Bahdanau attention must be off!") 235 | parsed_args["use_bahdanau_attention"] = False 236 | 237 | # If action output is to be used for LSTM, bahdahnau attention must be on. 238 | if parsed_args["use_action_output"] and parsed_args["text_encoder"] == "lstm": 239 | assert parsed_args["use_bahdanau_attention"], ( 240 | "Bahdanau attention " "must be on for action output to be used!" 241 | ) 242 | # For tf_idf, ignore the action_output flag. 243 | if parsed_args["encoder"] == "tf_idf": 244 | parsed_args["use_action_output"] = False 245 | # Prudent save is not possible without evaluation. 246 | if parsed_args["save_prudently"]: 247 | assert parsed_args[ 248 | "eval_data_path" 249 | ], "Prudent save needs a non-empty eval_data_path" 250 | 251 | # Set the cuda environment variable for the gpu to use and get context. 252 | parsed_args["use_gpu"] = support.setup_cuda_environment(parsed_args["gpu_id"]) 253 | # Force cuda initialization 254 | # (otherwise results in weird race conditions in PyTorch 1.4). 255 | if parsed_args["use_gpu"]: 256 | _ = torch.Tensor([1.0]).cuda() 257 | support.pretty_print_dict(parsed_args) 258 | return parsed_args 259 | -------------------------------------------------------------------------------- /mm_action_prediction/scripts/preprocess_simmc.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DOMAIN="furniture" 4 | # DOMAIN="fashion" 5 | ROOT="../data/simmc_${DOMAIN}/" 6 | 7 | # Input files. 8 | TRAIN_JSON_FILE="${ROOT}${DOMAIN}_train_dials.json" 9 | DEV_JSON_FILE="${ROOT}${DOMAIN}_dev_dials.json" 10 | DEVTEST_JSON_FILE="${ROOT}${DOMAIN}_devtest_dials.json" 11 | 12 | 13 | if [ "$DOMAIN" == "furniture" ]; then 14 | METADATA_FILE="${ROOT}furniture_metadata.csv" 15 | elif [ "$DOMAIN" == "fashion" ]; then 16 | METADATA_FILE="${ROOT}fashion_metadata.json" 17 | else 18 | echo "Invalid domain!" 19 | exit 0 20 | fi 21 | 22 | 23 | # Output files. 
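# Per domain, the steps below produce: a word vocabulary JSON, asset embeddings (.npy), an attribute vocabulary JSON, per-split *_api_calls.json files (Step 1), and per-split *_mm_inputs.npy files consumed by the dataloader (Step 4).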
24 | VOCAB_FILE="${ROOT}${DOMAIN}_vocabulary.json" 25 | METADATA_EMBEDS="${ROOT}${DOMAIN}_asset_embeds.npy" 26 | ATTR_VOCAB_FILE="${ROOT}${DOMAIN}_attribute_vocabulary.json" 27 | 28 | 29 | # Step 1: Extract assistant API. 30 | INPUT_FILES="${TRAIN_JSON_FILE} ${DEV_JSON_FILE} ${DEVTEST_JSON_FILE}" 31 | # If statement. 32 | if [ "$DOMAIN" == "furniture" ]; then 33 | python tools/extract_actions.py \ 34 | --json_path="${INPUT_FILES}" \ 35 | --save_root="${ROOT}" \ 36 | --metadata_path="${METADATA_FILE}" 37 | elif [ "$DOMAIN" == "fashion" ]; then 38 | python tools/extract_actions_fashion.py \ 39 | --json_path="${INPUT_FILES}" \ 40 | --save_root="${ROOT}" \ 41 | --metadata_path="${METADATA_FILE}" 42 | else 43 | echo "Invalid domain!" 44 | exit 0 45 | fi 46 | 47 | 48 | # Step 2: Extract vocabulary from train. 49 | python tools/extract_vocabulary.py \ 50 | --train_json_path="${TRAIN_JSON_FILE}" \ 51 | --vocab_save_path="${VOCAB_FILE}" \ 52 | --threshold_count=5 53 | 54 | 55 | # Step 3: Read and embed shopping assets. 56 | if [ "$DOMAIN" == "furniture" ]; then 57 | python tools/embed_furniture_assets.py \ 58 | --input_csv_file="${METADATA_FILE}" \ 59 | --embed_path="${METADATA_EMBEDS}" 60 | elif [ "$DOMAIN" == "fashion" ]; then 61 | python tools/embed_fashion_assets.py \ 62 | --input_asset_file="${METADATA_FILE}" \ 63 | --embed_path="${METADATA_EMBEDS}" 64 | else 65 | echo "Invalid domain!" 66 | exit 0 67 | fi 68 | 69 | 70 | # Step 4: Convert all the splits into npy files for dataloader. 71 | SPLIT_JSON_FILES=("${TRAIN_JSON_FILE}" "${DEV_JSON_FILE}" "${DEVTEST_JSON_FILE}") 72 | for SPLIT_JSON_FILE in "${SPLIT_JSON_FILES[@]}" ; do 73 | python tools/build_multimodal_inputs.py \ 74 | --json_path="${SPLIT_JSON_FILE}" \ 75 | --vocab_file="${VOCAB_FILE}" \ 76 | --save_path="$ROOT" \ 77 | --action_json_path="${SPLIT_JSON_FILE/.json/_api_calls.json}" \ 78 | --retrieval_candidate_file="${SPLIT_JSON_FILE/.json/_retrieval_candidates.json}" \ 79 | --domain="${DOMAIN}" 80 | done 81 | 82 | 83 | # Step 5: Extract vocabulary for attributes from train npy file. 84 | python tools/extract_attribute_vocabulary.py \ 85 | --train_npy_path="${TRAIN_JSON_FILE/.json/_mm_inputs.npy}" \ 86 | --vocab_save_path="${ATTR_VOCAB_FILE}" \ 87 | --domain="${DOMAIN}" 88 | -------------------------------------------------------------------------------- /mm_action_prediction/scripts/train_all_simmc_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPU_ID=0 3 | # DOMAIN="furniture" 4 | DOMAIN="fashion" 5 | ROOT="../data/simmc_${DOMAIN}/" 6 | 7 | 8 | # Input files. 9 | TRAIN_JSON_FILE="${ROOT}${DOMAIN}_train_dials.json" 10 | DEV_JSON_FILE="${ROOT}${DOMAIN}_dev_dials.json" 11 | # Output files. 
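The preprocessing steps above derive companion file names from each split's JSON path with bash substitutions such as `${SPLIT_JSON_FILE/.json/_mm_inputs.npy}`. A small Python equivalent, useful when driving the same tools programmatically (hypothetical helper, not part of the repo):

```python
def derive_path(json_path: str, suffix: str) -> str:
    """Replace the trailing '.json' with a derived suffix (hypothetical helper)."""
    assert json_path.endswith(".json"), "expected a .json input path"
    return json_path[: -len(".json")] + suffix


train_json = "../data/simmc_furniture/furniture_train_dials.json"
print(derive_path(train_json, "_api_calls.json"))
print(derive_path(train_json, "_retrieval_candidates.json"))
print(derive_path(train_json, "_mm_inputs.npy"))
```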
12 | METADATA_EMBEDS="${ROOT}${DOMAIN}_asset_embeds.npy" 13 | ATTR_VOCAB_FILE="${ROOT}${DOMAIN}_attribute_vocabulary.json" 14 | MODEL_METAINFO="models/${DOMAIN}_model_metainfo.json" 15 | CHECKPOINT_PATH="checkpoints" 16 | LOG_PATH="logs/" 17 | 18 | 19 | COMMON_FLAGS=" 20 | --train_data_path=${TRAIN_JSON_FILE/.json/_mm_inputs.npy} \ 21 | --eval_data_path=${DEV_JSON_FILE/.json/_mm_inputs.npy} \ 22 | --asset_embed_path=${METADATA_EMBEDS} \ 23 | --metainfo_path=${MODEL_METAINFO} \ 24 | --attr_vocab_path=${ATTR_VOCAB_FILE} \ 25 | --learning_rate=0.0001 --gpu_id=$GPU_ID --use_action_attention \ 26 | --num_epochs=100 --eval_every_epoch=5 --batch_size=20 \ 27 | --save_every_epoch=5 --word_embed_size=256 --num_layers=2 \ 28 | --hidden_size=512 \ 29 | --use_multimodal_state --use_action_output --use_bahdanau_attention \ 30 | --skip_bleu_evaluation --domain=${DOMAIN}" 31 | 32 | 33 | # History-agnostic model. 34 | function history_agnostic () { 35 | python -u train_simmc_agent.py $COMMON_FLAGS \ 36 | --encoder="history_agnostic" --text_encoder="lstm" \ 37 | --snapshot_path="${CHECKPOINT_PATH}/$1/hae/" &> "${LOG_PATH}/$1/hae.log" & 38 | } 39 | # Hierarchical recurrent encoder model. 40 | function hierarchical_recurrent () { 41 | python -u train_simmc_agent.py $COMMON_FLAGS \ 42 | --encoder="hierarchical_recurrent" --text_encoder="lstm" \ 43 | --snapshot_path="${CHECKPOINT_PATH}/$1/hre/" &> "${LOG_PATH}/$1/hre.log" & 44 | } 45 | # Memory encoder model. 46 | function memory_network () { 47 | python -u train_simmc_agent.py $COMMON_FLAGS \ 48 | --encoder="memory_network" --text_encoder="lstm" \ 49 | --snapshot_path="${CHECKPOINT_PATH}/$1/mn/" &> "${LOG_PATH}/$1/mn.log" & 50 | } 51 | # TF-IDF model. 52 | function tf_idf () { 53 | python -u train_simmc_agent.py $COMMON_FLAGS \ 54 | --encoder="tf_idf" --text_encoder="lstm" \ 55 | --snapshot_path="${CHECKPOINT_PATH}/$1/tf_idf/" &> "${LOG_PATH}/$1/tf_idf.log" & 56 | } 57 | # Transformer model. 58 | function transformer () { 59 | python -u train_simmc_agent.py $COMMON_FLAGS \ 60 | --encoder="history_agnostic" \ 61 | --text_encoder="transformer" \ 62 | --num_heads_transformer=4 --num_layers_transformer=4 \ 63 | --hidden_size_transformer=2048 --hidden_size=256\ 64 | --snapshot_path="${CHECKPOINT_PATH}/$1/transf/" &> "${LOG_PATH}/$1/transf.log" & 65 | } 66 | 67 | 68 | # Train all models on a domain Save checkpoints and logs with unique label. 69 | UNIQ_LABEL="${DOMAIN}_dstc_split" 70 | CUR_TIME=$(date +"_%m_%d_%Y_%H_%M_%S") 71 | UNIQ_LABEL+=$CUR_TIME 72 | mkdir "${LOG_PATH}${UNIQ_LABEL}" 73 | 74 | history_agnostic "$UNIQ_LABEL" 75 | hierarchical_recurrent "$UNIQ_LABEL" 76 | memory_network "$UNIQ_LABEL" 77 | tf_idf "$UNIQ_LABEL" 78 | transformer "$UNIQ_LABEL" 79 | -------------------------------------------------------------------------------- /mm_action_prediction/scripts/train_simmc_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPU_ID=0 4 | DOMAIN="furniture" 5 | # DOMAIN="fashion" 6 | ROOT="../data/simmc_${DOMAIN}/" 7 | 8 | 9 | # Input files. 10 | TRAIN_JSON_FILE="${ROOT}${DOMAIN}_train_dials.json" 11 | DEV_JSON_FILE="${ROOT}${DOMAIN}_dev_dials.json" 12 | DEVTEST_JSON_FILE="${ROOT}${DOMAIN}_devtest_dials.json" 13 | 14 | 15 | # Output files. 
16 | METADATA_EMBEDS="${ROOT}${DOMAIN}_asset_embeds.npy" 17 | ATTR_VOCAB_FILE="${ROOT}${DOMAIN}_attribute_vocabulary.json" 18 | MODEL_METAINFO="models/${DOMAIN}_model_metainfo.json" 19 | 20 | 21 | COMMON_FLAGS=" 22 | --train_data_path=${TRAIN_JSON_FILE/.json/_mm_inputs.npy} \ 23 | --eval_data_path=${DEV_JSON_FILE/.json/_mm_inputs.npy} \ 24 | --asset_embed_path=${METADATA_EMBEDS} \ 25 | --metainfo_path=${MODEL_METAINFO} \ 26 | --attr_vocab_path=${ATTR_VOCAB_FILE} \ 27 | --learning_rate=0.0001 --gpu_id=$GPU_ID --use_action_attention \ 28 | --num_epochs=100 --eval_every_epoch=5 --batch_size=20 \ 29 | --save_every_epoch=5 --word_embed_size=256 --num_layers=2 \ 30 | --hidden_size=512 \ 31 | --use_multimodal_state --use_action_output --use_bahdanau_attention \ 32 | --skip_bleu_evaluation --domain=${DOMAIN}" 33 | 34 | 35 | # Train history-agnostic model. 36 | # For other models, please look at scripts/train_all_simmc_models.sh 37 | python -u train_simmc_agent.py $COMMON_FLAGS \ 38 | --encoder="history_agnostic" \ 39 | --text_encoder="lstm" 40 | 41 | 42 | # Evaluate a trained model checkpoint. 43 | CHECKPOINT_PATH="${CHECKPOINT_ROOT}/hae/epoch_20.tar" 44 | python -u eval_simmc_agent.py \ 45 | --eval_data_path=${DEVTEST_JSON_FILE/.json/_mm_inputs.npy} \ 46 | --checkpoint="$CHECKPOINT_PATH" --gpu_id=${GPU_ID} --batch_size=50 \ 47 | --domain="$DOMAIN" 48 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/action_evaluation.py: -------------------------------------------------------------------------------- 1 | """Script evaluates action prediction along with attributes. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import argparse 9 | import collections 10 | import json 11 | 12 | import numpy as np 13 | 14 | 15 | IGNORE_ATTRIBUTES = [ 16 | "minPrice", 17 | "maxPrice", 18 | "furniture_id", 19 | "material", 20 | "decorStyle", 21 | "intendedRoom", 22 | "raw_matches", 23 | "focus", # fashion 24 | ] 25 | 26 | 27 | def evaluate_action_prediction( 28 | gt_actions, 29 | model_actions, 30 | single_round_eval=False, 31 | compute_std_err=False, 32 | record_instance_results=None, 33 | ): 34 | """Evaluates action prediction using the raw data and model predictions. 35 | 36 | Args: 37 | gt_actions: Ground truth actions + action attributes 38 | model_actions: Actions + attributes predicted by the model 39 | single_round_eval: Evaluate only for the last turn 40 | compute_std_err: Computes standard error for the metrics 41 | record_instance_results: Record the result per instance 42 | """ 43 | gt_actions_pool = {ii["dialog_id"]: ii for ii in gt_actions} 44 | matches = {"action": [], "attributes": [], "perplexity": []} 45 | confusion_dict = collections.defaultdict(list) 46 | for model_datum in model_actions: 47 | dialog_id = model_datum["dialog_id"] 48 | num_gt_rounds = len(gt_actions_pool[dialog_id]["actions"]) 49 | for round_datum in model_datum["predictions"]: 50 | round_id = round_datum["turn_id"] 51 | # Skip if single_round_eval and this is not the last round. 52 | if single_round_eval and round_id != num_gt_rounds - 1: 53 | continue 54 | 55 | gt_datum = gt_actions_pool[dialog_id]["actions"][round_id] 56 | action_match = gt_datum["action"] == round_datum["action"] 57 | # Record matches and confusion. 
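The action evaluator above reads two JSON files and matches them by `dialog_id`. The toy literals below (synthetic ids and values, limited to the keys the evaluator actually reads) illustrate the expected shapes of the ground-truth actions and the model predictions:

```python
# Synthetic ids and values; only the keys read by evaluate_action_prediction are shown.
gt_actions = [{
    "dialog_id": 100,
    "actions": [
        {"action": "SearchDatabase",
         "action_supervision": {"args": {"color": ["blue"]}}},
    ],
}]
model_actions = [{
    "dialog_id": 100,
    "predictions": [
        {"turn_id": 0,
         "action": "SearchDatabase",
         # Log-probability for every candidate action, used for action perplexity.
         "action_log_prob": {"SearchDatabase": -0.2, "SpecifyInfo": -1.8},
         "attributes": {"color": ["blue", "red"]}},
    ],
}]

gt = gt_actions[0]["actions"][0]
pred = model_actions[0]["predictions"][0]
print("action match:", gt["action"] == pred["action"])  # True
```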
58 | matches["action"].append(action_match) 59 | matches["perplexity"].append( 60 | round_datum["action_log_prob"][gt_datum["action"]] 61 | ) 62 | confusion_dict[gt_datum["action"]].append(round_datum["action"]) 63 | 64 | # Add the result to datum and save it back. 65 | if record_instance_results: 66 | round_datum["action_result"] = action_match 67 | round_datum["gt_action"] = gt_datum["action"] 68 | 69 | # Get supervision for action attributes. 70 | supervision = gt_datum["action_supervision"] 71 | if supervision is not None and "args" in supervision: 72 | supervision = supervision["args"] 73 | if supervision is None: 74 | continue 75 | # Case 1: Action mismatch -- record False for all attributes. 76 | if not action_match: 77 | for key in supervision.keys(): 78 | if key in IGNORE_ATTRIBUTES: 79 | continue 80 | matches["attributes"].append(False) 81 | # Case 2: Action matches -- use model predictions for attributes. 82 | else: 83 | for key in supervision.keys(): 84 | if key in IGNORE_ATTRIBUTES: 85 | continue 86 | gt_key_vals = supervision[key] 87 | model_key_vals = round_datum["attributes"][key] 88 | if not len(gt_key_vals): 89 | continue 90 | # For fashion, this is a list -- multi label prediction. 91 | if isinstance(gt_key_vals, list): 92 | assert isinstance( 93 | model_key_vals, list 94 | ), "Model should also predict a list for attributes" 95 | recall = np.mean([(ii in model_key_vals) for ii in gt_key_vals]) 96 | if len(model_key_vals): 97 | precision = np.mean( 98 | [(ii in gt_key_vals) for ii in model_key_vals] 99 | ) 100 | else: 101 | precision = 0.0 102 | f1_score = (2 * recall * precision) / ( 103 | recall + precision + 1e-5 104 | ) 105 | matches["attributes"].append(f1_score) 106 | else: 107 | # For furniture, this is a string -- single label prediction. 108 | matches["attributes"].append(gt_key_vals == model_key_vals) 109 | 110 | print("#Instances evaluated API: {}".format(len(matches["action"]))) 111 | # Record and save per instance results. 112 | if record_instance_results: 113 | print("Saving per instance result: {}".format(record_instance_results)) 114 | with open(record_instance_results, "w") as file_id: 115 | json.dump(model_actions, file_id) 116 | 117 | # Compute the confusion matrix. 
118 | all_actions = sorted( 119 | set(confusion_dict.keys()).union( 120 | {jj for ii in confusion_dict.values() for jj in ii} 121 | ) 122 | ) 123 | matrix = np.zeros((len(all_actions), len(all_actions))) 124 | for index, action in enumerate(all_actions): 125 | labels, counts = np.unique(confusion_dict[action], return_counts=True) 126 | for label, count in zip(labels, counts): 127 | matrix[all_actions.index(label), index] += count 128 | 129 | metrics = { 130 | "action_accuracy": np.mean(matches["action"]), 131 | "action_perplexity": np.exp(-1 * np.mean(matches["perplexity"])), 132 | "attribute_accuracy": np.mean(matches["attributes"]), 133 | "confusion_matrix": matrix, 134 | } 135 | if compute_std_err: 136 | metrics_std_err = { 137 | "action_accuracy": ( 138 | np.std(matches["action"]) / np.sqrt(len(matches["action"])) 139 | ), 140 | "action_perplexity": ( 141 | ( 142 | np.exp(-1 * np.std(matches["perplexity"])) 143 | / np.sqrt(len(matches["perplexity"])) 144 | ) 145 | ), 146 | "attribute_accuracy": ( 147 | np.std(matches["attributes"]) / np.sqrt(len(matches["attributes"])) 148 | ), 149 | } 150 | return metrics, metrics_std_err 151 | else: 152 | return metrics 153 | 154 | 155 | def main(args): 156 | print("Reading: {}".format(args["action_json_path"])) 157 | with open(args["action_json_path"], "r") as file_id: 158 | gt_actions = json.load(file_id) 159 | print("Reading: {}".format(args["model_output_path"])) 160 | with open(args["model_output_path"], "r") as file_id: 161 | model_actions = json.load(file_id) 162 | 163 | if args["record_instance_results"]: 164 | instance_results_path = args["model_output_path"].replace( 165 | ".json", "_results.json" 166 | ) 167 | else: 168 | instance_results_path = None 169 | 170 | action_metrics = evaluate_action_prediction( 171 | gt_actions, 172 | model_actions, 173 | args["single_round_evaluation"], 174 | args["compute_std_err"], 175 | instance_results_path, 176 | ) 177 | 178 | if args["compute_std_err"]: 179 | action_std_err = action_metrics[1] 180 | action_metrics = action_metrics[0] 181 | print("\nStandard error:") 182 | print(action_std_err) 183 | print(action_metrics) 184 | 185 | 186 | if __name__ == "__main__": 187 | parser = argparse.ArgumentParser(description="API Call Action Evaluation") 188 | parser.add_argument( 189 | "--action_json_path", 190 | default="data/furniture_api_calls.json", 191 | help="Ground truth API calls", 192 | ) 193 | parser.add_argument( 194 | "--model_output_path", default=None, help="API calls generated by the model" 195 | ) 196 | parser.add_argument( 197 | "--single_round_evaluation", 198 | dest="single_round_evaluation", 199 | action="store_true", 200 | default=False, 201 | help="Single round evaluation for hidden split", 202 | ) 203 | parser.add_argument( 204 | "--compute_std_err", 205 | dest="compute_std_err", 206 | action="store_true", 207 | default=False, 208 | help="Computes standard error for the metrics", 209 | ) 210 | parser.add_argument( 211 | "--record_instance_results", 212 | dest="record_instance_results", 213 | action="store_true", 214 | default=False, 215 | help="Records per instance results and save it back", 216 | ) 217 | try: 218 | parsed_args = vars(parser.parse_args()) 219 | except (IOError) as msg: 220 | parser.error(str(msg)) 221 | main(parsed_args) 222 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/embed_fashion_assets.py: -------------------------------------------------------------------------------- 1 | """Create fashion assest 
embeddings by concatenating attribute Glove embeddings. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import argparse 9 | import ast 10 | import json 11 | 12 | import numpy as np 13 | import spacy 14 | 15 | 16 | # Attributes to encode. 17 | EMBED_ATTRIBUTES = ["type", "color", "embellishments", "pattern"] 18 | 19 | 20 | def main(args): 21 | with open(args["input_asset_file"], "r") as file_id: 22 | assets = json.load(file_id) 23 | 24 | # Select and embed only the top attributes. 25 | cleaned_assets = [] 26 | for image_id, asset in assets.items(): 27 | clean_asset = {} 28 | asset_info = asset["metadata"] 29 | for key in EMBED_ATTRIBUTES: 30 | if key in asset_info: 31 | val = asset_info[key] 32 | # val = correction.get(val, val).lower() 33 | val = ast.literal_eval(val) if "[" in val else val 34 | clean_asset[key] = val if isinstance(val, list) else [val] 35 | clean_asset["id"] = int(image_id) 36 | cleaned_assets.append(clean_asset) 37 | # Vocabulary for each field. 38 | vocabulary = {key: {} for key in EMBED_ATTRIBUTES} 39 | for asset in cleaned_assets: 40 | for attr in EMBED_ATTRIBUTES: 41 | attr_val = asset.get(attr, []) 42 | for val in attr_val: 43 | vocabulary[attr][val] = vocabulary[attr].get(val, 0) + 1 44 | 45 | # Embedding for each item. 46 | nlp = spacy.load(args["spacy_model"]) 47 | sample_feature = nlp("apple").vector 48 | feature_size = sample_feature.size 49 | zero_features = np.zeros(feature_size) 50 | embeddings = [] 51 | id_list = [] 52 | for asset in cleaned_assets: 53 | embed_vector = [] 54 | for attr in EMBED_ATTRIBUTES: 55 | if attr in asset and len(asset[attr]) > 0: 56 | attr_val = asset[attr] 57 | feature_vector = np.stack( 58 | [nlp(val).vector for val in attr_val] 59 | ).mean(0) 60 | else: 61 | feature_vector = zero_features 62 | embed_vector.append(feature_vector) 63 | embeddings.append(np.concatenate(embed_vector)) 64 | id_list.append(asset["id"]) 65 | embeddings = np.stack(embeddings) 66 | print("Saving embeddings: {}".format(args["embed_path"])) 67 | np.save( 68 | args["embed_path"], 69 | { 70 | "asset_id": id_list, 71 | "embedding": embeddings, 72 | "asset_feature_size": embeddings.shape[1], 73 | }, 74 | ) 75 | 76 | 77 | if __name__ == "__main__": 78 | parser = argparse.ArgumentParser(description="Embed Fashion assets") 79 | parser.add_argument( 80 | "--input_asset_file", 81 | default="data/fashion_assets.json", 82 | help="Fashion metadata file for assets", 83 | ) 84 | parser.add_argument( 85 | "--embed_path", 86 | default="data/fashion_metadata_embed.npy", 87 | help="Embeddings for Fashion assets", 88 | ) 89 | parser.add_argument( 90 | "--spacy_model", 91 | default="en_vectors_web_lg", 92 | help="Spacy model to use for language model", 93 | ) 94 | try: 95 | parsed_args = vars(parser.parse_args()) 96 | except (IOError) as msg: 97 | parser.error(str(msg)) 98 | main(parsed_args) 99 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/embed_furniture_assets.py: -------------------------------------------------------------------------------- 1 | """Create furniture assest embeddings by concatenating attribute Glove embeddings. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import argparse 9 | import ast 10 | 11 | import numpy as np 12 | from tools import data_support 13 | import spacy 14 | 15 | 16 | # Attributes to encode. 
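Both asset-embedding scripts persist a plain Python dict with `np.save`; reading it back requires `allow_pickle=True` plus the `[()]` indexing used elsewhere in this codebase (for example in `extract_attribute_vocabulary.py`). A toy round-trip with a throwaway `/tmp` path:

```python
import numpy as np

# Toy round-trip; the real files are written by the embedding scripts in this directory.
np.save(
    "/tmp/toy_asset_embeds.npy",
    {"asset_id": [11, 42], "embedding": np.zeros((2, 300)), "asset_feature_size": 300},
)
assets = np.load("/tmp/toy_asset_embeds.npy", allow_pickle=True)[()]
print(assets["asset_feature_size"], assets["embedding"].shape)  # 300 (2, 300)
```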
17 | EMBED_ATTRIBUTES = [ 18 | "class_name", "color", "decor_style", "intended_room", "material" 19 | ] 20 | 21 | 22 | def main(args): 23 | assets = data_support.read_furniture_metadata(args["input_csv_file"]) 24 | cleaned_assets = [] 25 | # Quick fix dictionary. 26 | correction = { 27 | "['Traditional', 'Modern'']": "['Traditional', 'Modern']", 28 | "[Brown']": "['Brown']", 29 | } 30 | for _, asset in assets.items(): 31 | clean_asset = {} 32 | for key in EMBED_ATTRIBUTES: 33 | val = asset[key] 34 | val = correction.get(val, val).lower() 35 | val = ast.literal_eval(val) if "[" in val else val 36 | clean_asset[key] = val if isinstance(val, list) else [val] 37 | clean_asset["id"] = int(asset["obj"].split("/")[-1].strip(".zip")) 38 | cleaned_assets.append(clean_asset) 39 | 40 | # Vocabulary for each field. 41 | vocabulary = {key: {} for key in EMBED_ATTRIBUTES} 42 | for asset in cleaned_assets: 43 | for attr in EMBED_ATTRIBUTES: 44 | attr_val = asset[attr] 45 | for val in attr_val: 46 | vocabulary[attr][val] = vocabulary[attr].get(val, 0) + 1 47 | 48 | # Embedding for each item. 49 | nlp = spacy.load(args["spacy_model"]) 50 | embeddings = [] 51 | id_list = [] 52 | for asset in cleaned_assets: 53 | embed_vector = [] 54 | for attr in EMBED_ATTRIBUTES: 55 | attr_val = asset[attr] 56 | feature_vector = np.stack([nlp(val).vector for val in attr_val]) 57 | embed_vector.append(feature_vector.mean(0)) 58 | embeddings.append(np.concatenate(embed_vector)) 59 | id_list.append(asset["id"]) 60 | embeddings = np.stack(embeddings) 61 | print("Saving embeddings: {}".format(args["embed_path"])) 62 | feature_size = embeddings.shape[1] 63 | np.save( 64 | args["embed_path"], 65 | { 66 | "asset_id": id_list, 67 | "embedding": embeddings, 68 | "asset_feature_size": feature_size, 69 | }, 70 | ) 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser(description="Embed furniture assets") 75 | parser.add_argument( 76 | "--input_csv_file", 77 | default="data/furniture_metadata.csv", 78 | help="Furniture metadata file for assets", 79 | ) 80 | parser.add_argument( 81 | "--embed_path", 82 | default="data/furniture_metadata_embed.npy", 83 | help="Embeddings for furniture assets", 84 | ) 85 | parser.add_argument( 86 | "--spacy_model", 87 | default="en_vectors_web_lg", 88 | help="Spacy model to use for language model", 89 | ) 90 | try: 91 | parsed_args = vars(parser.parse_args()) 92 | except (IOError) as msg: 93 | parser.error(str(msg)) 94 | main(parsed_args) 95 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/extract_actions_fashion.py: -------------------------------------------------------------------------------- 1 | """Extract action API supervision for the SIMMC Fashion dataset. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | from absl import flags 9 | from absl import app 10 | import ast 11 | import json 12 | import os 13 | 14 | 15 | FLAGS = flags.FLAGS 16 | flags.DEFINE_spaceseplist( 17 | "json_path", "data/furniture_pilot_oct24.json", "JSON containing the dataset" 18 | ) 19 | flags.DEFINE_string( 20 | "save_root", "data/", "Folder path to save extraced api annotations" 21 | ) 22 | flags.DEFINE_string( 23 | "metadata_path", "data/fashion_metadata.json", "Path to fashion metadata" 24 | ) 25 | 26 | 27 | def extract_actions(input_json_file): 28 | """Extract action API for SIMMC fashion. 
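The furniture metadata stores several attributes as stringified lists, so the clean-up above applies a small correction table and then `ast.literal_eval`. A self-contained sketch of that parsing step:

```python
import ast

# Small correction table for malformed entries, as in the clean-up above.
correction = {"[Brown']": "['Brown']"}


def parse_attribute(raw_value):
    """Normalize a metadata attribute into a lowercase list of values."""
    value = correction.get(raw_value, raw_value).lower()
    value = ast.literal_eval(value) if "[" in value else value
    return value if isinstance(value, list) else [value]


print(parse_attribute("['Traditional', 'Modern']"))  # ['traditional', 'modern']
print(parse_attribute("[Brown']"))                   # ['brown']
print(parse_attribute("Brown"))                      # ['brown']
```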
29 | 30 | Args: 31 | input_json_file: JSON data file to extraction actions 32 | """ 33 | print("Reading: {}".format(input_json_file)) 34 | with open(input_json_file, "r") as file_id: 35 | raw_data = json.load(file_id) 36 | 37 | task_mapping = {ii["task_id"]: ii for ii in raw_data["task_mapping"]} 38 | dialogs = [] 39 | for dialog_datum in raw_data["dialogue_data"]: 40 | dialog_id = dialog_datum["dialogue_idx"] 41 | # If task id is missing for the dialog, assign a random task. 42 | # Could lead to problems but it is for < 0.1% of the data 43 | if "dialogue_task_id" not in dialog_datum: 44 | # Assign a random task for missing ids. 45 | print("Dialogue task Id missing: {}".format(dialog_id)) 46 | mm_state = task_mapping[1874] 47 | else: 48 | mm_state = task_mapping[dialog_datum["dialogue_task_id"]] 49 | focus_image = mm_state["focus_image"] 50 | focus_images = [] 51 | roundwise_actions = [] 52 | 53 | for round_datum in dialog_datum["dialogue"]: 54 | focus_images.append(focus_image) 55 | # Default None action. 56 | insert_item = { 57 | "turn_idx": round_datum["turn_idx"], 58 | "action": "None", 59 | "action_supervision": None 60 | } 61 | keystrokes = round_datum.get("raw_assistant_keystrokes", []) 62 | # Get information attributes given the asset id. 63 | attributes = extract_info_attributes(round_datum) 64 | if keystrokes: 65 | focus_image = int(keystrokes[0]["image_id"]) 66 | # Change of focus image -> Search in dataset or memory. 67 | if focus_image in mm_state["memory_images"]: 68 | insert_item["action"] = "SearchMemory" 69 | insert_item["action_supervision"] = { 70 | "focus": focus_image, 71 | "attributes": attributes, 72 | } 73 | elif focus_image in mm_state["database_images"]: 74 | insert_item["action"] = "SearchDatabase" 75 | insert_item["action_supervision"] = { 76 | "focus": focus_image, 77 | "attributes": attributes, 78 | } 79 | else: 80 | print("Undefined action; using None instead") 81 | roundwise_actions.append(insert_item) 82 | else: 83 | # Check for SpecifyInfo action. 84 | # Get information attributes given the asset id. 85 | attributes = extract_info_attributes(round_datum) 86 | if len(attributes): 87 | insert_item["action"] = "SpecifyInfo" 88 | insert_item["action_supervision"] = { 89 | "attributes": attributes 90 | } 91 | else: 92 | # AddToCart action. 93 | for intent_info in ast.literal_eval( 94 | round_datum["transcript_annotated"] 95 | ): 96 | if "DA:REQUEST:ADD_TO_CART" in intent_info["intent"]: 97 | insert_item["action"] = "AddToCart" 98 | insert_item["action_supervision"] = None 99 | roundwise_actions.append(insert_item) 100 | 101 | dialogs.append( 102 | { 103 | "dialog_id": dialog_id, 104 | "actions": roundwise_actions, 105 | "focus_images": focus_images, 106 | } 107 | ) 108 | 109 | # Save extracted API calls. 110 | save_path = input_json_file.split("/")[-1].replace(".json", "_api_calls.json") 111 | save_path = os.path.join(FLAGS.save_root, save_path) 112 | print("Saving: {}".format(save_path)) 113 | with open(save_path, "w") as f: 114 | json.dump(dialogs, f) 115 | 116 | 117 | def extract_info_attributes(round_datum): 118 | """Extract information attributes for current round using NLU annotations. 
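The turn-level action assignment above can be hard to follow inline. The sketch below condenses the same decision order (raw keystrokes first, then informational attributes), with the `AddToCart` intent branch omitted; all ids are synthetic:

```python
def choose_action(keystrokes, attributes, memory_images, database_images):
    """Condensed sketch of the action assignment above (AddToCart branch omitted)."""
    if keystrokes:
        new_focus = int(keystrokes[0]["image_id"])
        if new_focus in memory_images:
            return "SearchMemory", {"focus": new_focus, "attributes": attributes}
        if new_focus in database_images:
            return "SearchDatabase", {"focus": new_focus, "attributes": attributes}
        return "None", None  # undefined focus image
    if attributes:
        return "SpecifyInfo", {"attributes": attributes}
    return "None", None


print(choose_action([{"image_id": "42"}], ["price"],
                    memory_images={42}, database_images={7, 8}))
# ('SearchMemory', {'focus': 42, 'attributes': ['price']})
```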
119 | 120 | Args: 121 | round_datum: Current round information 122 | 123 | Returns: 124 | get_attribute_matches: Information attributes 125 | """ 126 | user_annotation = ast.literal_eval(round_datum["transcript_annotated"]) 127 | # assistant_annotation = ast.literal_eval( 128 | # round_datum["system_transcript_annotated"] 129 | # ) 130 | # annotation = user_annotation + assistant_annotation 131 | annotation = user_annotation 132 | all_intents = [ii["intent"] for ii in annotation] 133 | get_attribute_matches = [] 134 | for index, intent in enumerate(all_intents): 135 | if any( 136 | ii in intent 137 | for ii in ("DA:ASK:GET", "DA:ASK:CHECK", "DA:INFORM:GET") 138 | ): 139 | # If there is no attribute added, default to info. 140 | if "." not in intent: 141 | get_attribute_matches.append("info") 142 | continue 143 | 144 | attribute = intent.split(".")[-1] 145 | if attribute == "info": 146 | new_matches = [ 147 | ii["id"].split(".")[-1] 148 | for ii in annotation[index]["slots"] 149 | if "INFO" in ii["id"] 150 | ] 151 | if len(new_matches): 152 | get_attribute_matches.extend(new_matches) 153 | else: 154 | get_attribute_matches.append("info") 155 | elif attribute != "": 156 | get_attribute_matches.append(attribute) 157 | return sorted(set(get_attribute_matches)) 158 | 159 | 160 | def main(_): 161 | for input_json_file in FLAGS.json_path: 162 | extract_actions(input_json_file) 163 | 164 | 165 | if __name__ == "__main__": 166 | app.run(main) 167 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/extract_attribute_vocabulary.py: -------------------------------------------------------------------------------- 1 | """Extracts attribute vocabulary for SIMMC Furniture. 2 | 3 | Author: Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import collections 9 | import json 10 | import argparse 11 | import numpy as np 12 | 13 | 14 | # Furniture. 15 | EXCLUDE_KEYS_FURNITURE = [ 16 | "minPrice", 17 | "maxPrice", 18 | "furniture_id", 19 | "material", 20 | "decorStyle", 21 | "intendedRoom", 22 | "raw_matches", 23 | ] 24 | # Furniture 25 | EXCLUDE_KEYS_FASHION = ["focus", "memory"] 26 | INCLUDE_ATTRIBUTES_FASHION = [ 27 | "availableSizes", "price", "brand", "customerRating", "info", "color" 28 | ] 29 | 30 | 31 | # Key aliases. 32 | DOMAIN = "domain" 33 | FURNITURE = "furniture" 34 | FASHION = "fashion" 35 | 36 | 37 | def extract_action_attributes(args): 38 | """Read training multimodal input, extract attribute vocabulary (furniture) 39 | """ 40 | # Read the data, parse the datapoints. 41 | data = np.load(args["train_npy_path"], allow_pickle=True)[()] 42 | actions = data["action"] 43 | num_instances, num_rounds = actions.shape 44 | 45 | # Get action attributes. 
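The helper above maps NLU intents such as `DA:ASK:GET` to attribute names by splitting on the final `.`. A simplified standalone sketch (slot handling omitted, intent strings are synthetic examples):

```python
def extract_attributes(annotation):
    """Simplified sketch of the attribute extraction above (slot handling omitted)."""
    matches = []
    for turn in annotation:
        intent = turn["intent"]
        if not any(tag in intent for tag in ("DA:ASK:GET", "DA:ASK:CHECK", "DA:INFORM:GET")):
            continue
        # "DA:ASK:GET:clothing.price" -> "price"; intents without a '.' default to "info".
        matches.append(intent.split(".")[-1] if "." in intent else "info")
    return sorted(set(matches))


print(extract_attributes([
    {"intent": "DA:ASK:GET:clothing.price"},
    {"intent": "DA:INFORM:GET:clothing.availableSizes"},
    {"intent": "DA:GREET"},
]))  # ['availableSizes', 'price']
```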
46 | attr_vocab = {} 47 | for ii in range(num_instances): 48 | for jj in range(num_rounds): 49 | cur_action = actions[ii, jj] 50 | if cur_action == "None": 51 | continue 52 | if cur_action not in attr_vocab: 53 | if args[DOMAIN] == FURNITURE: 54 | attr_vocab[cur_action] = collections.defaultdict(dict) 55 | elif args[DOMAIN] == FASHION: 56 | attr_vocab[cur_action] = collections.defaultdict( 57 | lambda: collections.defaultdict(lambda: 0) 58 | ) 59 | 60 | cur_super = data["action_supervision"][ii][jj] 61 | if cur_super is None: 62 | continue 63 | for key, val in cur_super.items(): 64 | if args[DOMAIN] == FURNITURE: 65 | if key in EXCLUDE_KEYS_FURNITURE: 66 | continue 67 | if isinstance(val, list): 68 | val = tuple(val) 69 | new_count = attr_vocab[cur_action][key].get(val, 0) + 1 70 | attr_vocab[cur_action][key][val] = new_count 71 | 72 | elif args[DOMAIN] == FASHION: 73 | if key in EXCLUDE_KEYS_FASHION: 74 | continue 75 | if isinstance(val, list): 76 | val = tuple(val) 77 | for vv in val: 78 | # If vv not in INCLUDE_ATTRIBUTES_FASHION, 79 | # assign it to "other." 80 | if vv not in INCLUDE_ATTRIBUTES_FASHION: 81 | vv = "other" 82 | attr_vocab[cur_action][key][vv] += 1 83 | else: 84 | # If val not in INCLUDE_ATTRIBUTES_FASHION, 85 | # assign it to other. 86 | if val not in INCLUDE_ATTRIBUTES_FASHION: 87 | val = "other" 88 | attr_vocab[cur_action][key][val] += 1 89 | 90 | attr_vocab = { 91 | key: sorted(val) 92 | for attr_values in attr_vocab.values() 93 | for key, val in attr_values.items() 94 | } 95 | print(attr_vocab) 96 | print("Saving attribute dictionary: {}".format(args["vocab_save_path"])) 97 | with open(args["vocab_save_path"], "w") as file_id: 98 | json.dump(attr_vocab, file_id) 99 | 100 | 101 | def print_fashion_attributes(attribute_vocabulary): 102 | """Prints fashion attributes (for visualization). 103 | 104 | Args: 105 | attribute_vocabulary: Extracted attribute vocabulary count dict. 106 | """ 107 | for key, val in attribute_vocabulary.items(): 108 | print(key) 109 | print(val.keys()) 110 | for attr, attr_val_dict in val.items(): 111 | print('Name: {}'.format(attr)) 112 | for ii in sorted( 113 | attr_val_dict.items(), key=lambda x: x[1], reverse=True 114 | ): 115 | print("\t{}: {}".format(*ii)) 116 | print("-" * 50) 117 | 118 | 119 | if __name__ == "__main__": 120 | # Read the commandline arguments. 121 | parser = argparse.ArgumentParser(description="Extract vocabulary") 122 | parser.add_argument( 123 | "--train_json_path", 124 | default=None, 125 | help="Path to read the vocabulary (train) JSON", 126 | ) 127 | parser.add_argument( 128 | "--train_npy_path", 129 | default=None, 130 | help="Path to read the vocabulary (train) Numpy file", 131 | ) 132 | parser.add_argument( 133 | "--vocab_save_path", 134 | default="data/vocabulary_genie.json", 135 | help="Path to read the vocabulary (train) JSON", 136 | ) 137 | parser.add_argument( 138 | "--domain", 139 | default="furniture", 140 | choices=["furniture", "fashion"], 141 | help="Domain to extract attribute vocabulary" 142 | ) 143 | try: 144 | parsed_args = vars(parser.parse_args()) 145 | except (IOError) as msg: 146 | parser.error(str(msg)) 147 | 148 | # Extract action API attributes using training file. 149 | extract_action_attributes(parsed_args) 150 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/extract_vocabulary.py: -------------------------------------------------------------------------------- 1 | """Extracts vocabulary for SIMMC dataset. 
2 | 3 | Author: Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import json 9 | import argparse 10 | from nltk.tokenize import word_tokenize 11 | 12 | 13 | def main(args): 14 | # Read the data, parse the datapoints. 15 | print("Reading: {}".format(args["train_json_path"])) 16 | with open(args["train_json_path"], "r") as file_id: 17 | train_data = json.load(file_id) 18 | dialog_data = train_data["dialogue_data"] 19 | 20 | counts = {} 21 | for datum in dialog_data: 22 | dialog_utterances = [ 23 | ii[key] for ii in datum["dialogue"] 24 | for key in ("transcript", "system_transcript") 25 | ] 26 | dialog_tokens = [ 27 | word_tokenize(ii.lower()) for ii in dialog_utterances 28 | ] 29 | for turn in dialog_tokens: 30 | for word in turn: 31 | counts[word] = counts.get(word, 0) + 1 32 | 33 | # Add <pad>, <unk>, <start>, <end>. 34 | counts["<pad>"] = args["threshold_count"] + 1 35 | counts["<unk>"] = args["threshold_count"] + 1 36 | counts["<start>"] = args["threshold_count"] + 1 37 | counts["<end>"] = args["threshold_count"] + 1 38 | 39 | word_counts = sorted(counts.items(), key=lambda x: x[1], reverse=True) 40 | words = [ii[0] for ii in word_counts if ii[1] >= args["threshold_count"]] 41 | vocabulary = {"word": words} 42 | # Save answers and vocabularies. 43 | print("Identified {} words..".format(len(words))) 44 | print("Saving dictionary: {}".format(args["vocab_save_path"])) 45 | with open(args["vocab_save_path"], "w") as file_id: 46 | json.dump(vocabulary, file_id) 47 | 48 | 49 | if __name__ == "__main__": 50 | # Read the commandline arguments. 51 | parser = argparse.ArgumentParser(description="Extract vocabulary") 52 | parser.add_argument( 53 | "--train_json_path", 54 | default="data/furniture_data.json", 55 | help="Path to read the vocabulary (train) JSON", 56 | ) 57 | parser.add_argument( 58 | "--vocab_save_path", 59 | default="data/furniture_vocabulary.json", 60 | help="Path to save the vocabulary JSON", 61 | ) 62 | parser.add_argument( 63 | "--threshold_count", 64 | default=0, 65 | type=int, 66 | help="Words are included if beyond this threshold", 67 | ) 68 | try: 69 | parsed_args = vars(parser.parse_args()) 70 | except (IOError) as msg: 71 | parser.error(str(msg)) 72 | main(parsed_args) 73 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/response_evaluation.py: -------------------------------------------------------------------------------- 1 | """Script evaluates response generation using GT responses. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import argparse 9 | import json 10 | 11 | import nltk 12 | import numpy as np 13 | 14 | 15 | def normalize_sentence(sentence): 16 | """Normalize the sentences and tokenize.""" 17 | return nltk.tokenize.word_tokenize(sentence.lower()) 18 | 19 | 20 | def evaluate_response_generation( 21 | gt_responses, 22 | model_responses, 23 | single_round_eval=False, 24 | compute_std_err=False, 25 | record_instance_results=None, 26 | ): 27 | """Evaluates response generation using the raw data and model predictions. 28 | 29 | Args: 30 | gt_responses: Ground truth responses. 31 | model_responses: Generated responses.
32 | single_round_eval: Evaluate only for the last turn 33 | compute_std_err: Computes standard error for the metrics 34 | record_instance_results: Path to record per instance results 35 | """ 36 | gt_responses_pool = {ii["dialogue_idx"]: ii for ii in gt_responses["dialogue_data"]} 37 | bleu_scores = [] 38 | # Smoothing function. 39 | chencherry = nltk.translate.bleu_score.SmoothingFunction() 40 | num_evaluations = 0 41 | for model_datum in model_responses: 42 | dialog_id = model_datum["dialog_id"] 43 | num_gt_rounds = len(gt_responses_pool[dialog_id]["dialogue"]) 44 | for round_datum in model_datum["predictions"]: 45 | round_id = round_datum["turn_id"] 46 | # Skip if single_round_eval and this is not the last round. 47 | if single_round_eval and round_id != num_gt_rounds - 1: 48 | continue 49 | 50 | response = round_datum["response"] 51 | gt_datum = gt_responses_pool[dialog_id]["dialogue"][round_id] 52 | gt_response = gt_datum["system_transcript"] 53 | 54 | bleu_score = nltk.translate.bleu_score.sentence_bleu( 55 | [normalize_sentence(gt_response)], 56 | normalize_sentence(response), 57 | smoothing_function=chencherry.method1, 58 | ) 59 | bleu_scores.append(bleu_score) 60 | 61 | # Add the result to datum and save it back. 62 | if record_instance_results: 63 | round_datum["bleu"] = bleu_score 64 | round_datum["response_len"] = len(normalize_sentence(gt_response)) 65 | print("#Instances evaluated BLEU: {}".format(len(bleu_scores))) 66 | # Record and save per instance results. 67 | if record_instance_results: 68 | print("Saving per instance result: {}".format(record_instance_results)) 69 | with open(record_instance_results, "w") as file_id: 70 | json.dump(model_responses, file_id) 71 | 72 | bleu_std_err = np.std(bleu_scores) / np.sqrt(len(bleu_scores)) 73 | if compute_std_err: 74 | return np.mean(bleu_scores), bleu_std_err 75 | else: 76 | return np.mean(bleu_scores) 77 | 78 | 79 | def main(args): 80 | print("Reading: {}".format(args["data_json_path"])) 81 | with open(args["data_json_path"], "r") as file_id: 82 | gt_responses = json.load(file_id) 83 | print("Reading: {}".format(args["model_response_path"])) 84 | with open(args["model_response_path"], "r") as file_id: 85 | model_responses = json.load(file_id) 86 | 87 | if args["record_instance_results"]: 88 | instance_results_path = args["model_response_path"].replace( 89 | ".json", "_results.json" 90 | ) 91 | else: 92 | instance_results_path = None 93 | 94 | bleu_score = evaluate_response_generation( 95 | gt_responses, 96 | model_responses, 97 | args["single_round_evaluation"], 98 | args["compute_std_err"], 99 | instance_results_path, 100 | ) 101 | 102 | if args["compute_std_err"]: 103 | bleu_std_err = bleu_score[1] 104 | bleu_score = bleu_score[0] 105 | else: 106 | bleu_std_err = 0.0 107 | print("BLEU Score: {:.4f} +- {:.4f}".format(bleu_score, bleu_std_err)) 108 | 109 | 110 | if __name__ == "__main__": 111 | parser = argparse.ArgumentParser(description="Response Generation Evaluation") 112 | parser.add_argument( 113 | "--data_json_path", 114 | default="data/furniture_train.json", 115 | help="Data with gold responses", 116 | ) 117 | parser.add_argument( 118 | "--model_response_path", default=None, help="Responses generated by the model" 119 | ) 120 | parser.add_argument( 121 | "--single_round_evaluation", 122 | dest="single_round_evaluation", 123 | action="store_true", 124 | default=False, 125 | help="Single round evaluation for hidden split", 126 | ) 127 | parser.add_argument( 128 | "--compute_std_err", 129 | dest="compute_std_err", 130 | 
action="store_true", 131 | default=False, 132 | help="Computes standard error for the metrics", 133 | ) 134 | parser.add_argument( 135 | "--record_instance_results", 136 | dest="record_instance_results", 137 | action="store_true", 138 | default=False, 139 | help="Records per instance results and save it back", 140 | ) 141 | try: 142 | parsed_args = vars(parser.parse_args()) 143 | except (IOError) as msg: 144 | parser.error(str(msg)) 145 | main(parsed_args) 146 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/retrieval_evaluation.py: -------------------------------------------------------------------------------- 1 | """Script evaluates response retrieval using GT responses. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | 7 | import argparse 8 | import json 9 | 10 | import numpy as np 11 | 12 | 13 | def evaluate_response_retrieval( 14 | gt_responses, model_scores, single_round_eval=False, compute_std_err=False 15 | ): 16 | """Evaluates response retrieval using the raw data and model predictions. 17 | 18 | Args: 19 | gt_responses: Ground truth responses. 20 | model_scores: Scores assigned by the model to the candidates 21 | single_round_eval: Evaluate only for the last turn 22 | compute_std_err: Computes standard error for the metrics 23 | 24 | If in single round evaluation model (mostly for hidden test-std split), 25 | use hidden gt_index field. Else, 0th element is the ground truth for other 26 | splits. 27 | """ 28 | gt_index_pool = { 29 | ii["dialogue_idx"]: ii for ii in gt_responses["retrieval_candidates"] 30 | } 31 | gt_ranks = [] 32 | for model_datum in model_scores: 33 | dialog_id = model_datum["dialog_id"] 34 | gt_datum = gt_index_pool[dialog_id]["retrieval_candidates"] 35 | num_gt_rounds = len(gt_datum) 36 | for round_id, round_datum in enumerate(model_datum["candidate_scores"]): 37 | round_id = round_datum["turn_id"] 38 | # Skip if single_round_eval and this is not the last round. 
39 | if single_round_eval and round_id != num_gt_rounds - 1: 40 | continue 41 | 42 | gt_index = gt_datum[round_id]["gt_index"] 43 | current_turn = round_datum["turn_id"] 44 | round_scores = round_datum["scores"] 45 | gt_score = round_scores[gt_index] 46 | gt_ranks.append(np.sum(np.array(round_scores) > gt_score) + 1) 47 | gt_ranks = np.array(gt_ranks) 48 | print("#Instances evaluated retrieval: {}".format(gt_ranks.size)) 49 | 50 | num_instances = gt_ranks.size 51 | metrics = { 52 | "r1": np.mean(gt_ranks <= 1), 53 | "r5": np.mean(gt_ranks <= 5), 54 | "r10": np.mean(gt_ranks <= 10), 55 | "mean": np.mean(gt_ranks), 56 | "mrr": np.mean(1 / gt_ranks), 57 | } 58 | if compute_std_err: 59 | metrics_std_err = { 60 | "r1": np.std(gt_ranks <= 1) / np.sqrt(num_instances), 61 | "r5": np.std(gt_ranks <= 5) / np.sqrt(num_instances), 62 | "r10": np.std(gt_ranks <= 10) / np.sqrt(num_instances), 63 | "mean": np.std(gt_ranks) / np.sqrt(num_instances), 64 | "mrr": np.std(1 / gt_ranks) / np.sqrt(num_instances), 65 | } 66 | return metrics, metrics_std_err 67 | else: 68 | return metrics 69 | 70 | 71 | def main(args): 72 | print("Reading: {}".format(args["retrieval_json_path"])) 73 | with open(args["retrieval_json_path"], "r") as file_id: 74 | gt_responses = json.load(file_id) 75 | print("Reading: {}".format(args["model_score_path"])) 76 | with open(args["model_score_path"], "r") as file_id: 77 | model_scores = json.load(file_id) 78 | retrieval_metrics = evaluate_response_retrieval( 79 | gt_responses, 80 | model_scores, 81 | args["single_round_evaluation"], 82 | args["compute_std_err"], 83 | ) 84 | if args["compute_std_err"]: 85 | retrieval_std_err = retrieval_metrics[1] 86 | retrieval_metrics = retrieval_metrics[0] 87 | print("\nStandard error:") 88 | print(retrieval_std_err) 89 | print(retrieval_metrics) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser(description="Response Retrieval Evaluation") 94 | parser.add_argument( 95 | "--retrieval_json_path", 96 | default="data/furniture_train_retrieval_candidates.json", 97 | help="Data with retrieval candidates, gt", 98 | ) 99 | parser.add_argument( 100 | "--model_score_path", 101 | default=None, 102 | help="Candidate scores generated by the model", 103 | ) 104 | parser.add_argument( 105 | "--single_round_evaluation", 106 | dest="single_round_evaluation", 107 | action="store_true", 108 | default=False, 109 | help="Single round evaluation for hidden split", 110 | ) 111 | parser.add_argument( 112 | "--compute_std_err", 113 | dest="compute_std_err", 114 | action="store_true", 115 | default=False, 116 | help="Computes standard error for the metrics", 117 | ) 118 | try: 119 | parsed_args = vars(parser.parse_args()) 120 | except (IOError) as msg: 121 | parser.error(str(msg)) 122 | main(parsed_args) 123 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/rnn_support.py: -------------------------------------------------------------------------------- 1 | """Utilities to run dynamic RNN. 2 | 3 | Adapted from: VisDial-RL PyTorch Codebase 4 | Author: Satwik Kottur 5 | """ 6 | 7 | import enum 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | # Enums to denote the output type. 14 | class OutputForm(enum.Enum): 15 | ALL = 1 16 | ALL_CONCISE = 2 17 | LAST = 3 18 | PACKED = 4 19 | NONE = 5 20 | 21 | 22 | def get_sorted_order(lengths): 23 | """Sorts based on the lengths. 24 | 25 | Args: 26 | lengths: Lengths to perform sorting on. 
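The retrieval metrics above reduce to the rank of the ground-truth candidate: one plus the number of candidates scored strictly higher. A toy computation of recall@k, mean rank, and MRR, assuming the ground truth sits at index 0 as in the non-hidden splits:

```python
import numpy as np

# Candidate scores for three toy turns; the ground truth is at index 0 here.
candidate_scores = [
    [0.9, 0.1, 0.3, 0.2],  # ground truth ranked 1st
    [0.2, 0.8, 0.1, 0.4],  # ground truth ranked 3rd
    [0.5, 0.6, 0.4, 0.1],  # ground truth ranked 2nd
]
gt_ranks = np.array(
    [np.sum(np.array(scores) > scores[0]) + 1 for scores in candidate_scores]
)
print("r1:   {:.2f}".format(np.mean(gt_ranks <= 1)))
print("r5:   {:.2f}".format(np.mean(gt_ranks <= 5)))
print("mean: {:.2f}".format(np.mean(gt_ranks)))
print("mrr:  {:.2f}".format(np.mean(1 / gt_ranks)))
```

Ties are counted in the model's favor here, since only strictly higher scores push the ground truth down the ranking.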
27 | 28 | Returns: 29 | sorted_len: Lengths sorted according to descending order. 30 | fwd_order: Forward order of the sorting. 31 | bwd_order: Backward order of the sorting. 32 | """ 33 | sorted_len, fwd_order = torch.sort(lengths, dim=0, descending=True) 34 | _, bwd_order = torch.sort(fwd_order) 35 | return sorted_len, fwd_order, bwd_order 36 | 37 | 38 | def rearrange(rearrange_order, dim=0, *inputs): 39 | """Rearrages input tensors based on an order and along a given dimension. 40 | 41 | Args: 42 | rearrange_order: Order to use while rearranging. 43 | dim: Dimension along which to rearrange. 44 | *inputs: List of input tensors. 45 | """ 46 | rearranged_inputs = [] 47 | for input_tensor in inputs: 48 | assert ( 49 | input_tensor.shape[dim] == rearrange_order.shape[0] 50 | ), "Rearrange " "along dim {0} is incompatible!".format(dim) 51 | rearranged_inputs.append(input_tensor.index_select(dim, rearrange_order)) 52 | return tuple(rearranged_inputs) 53 | 54 | 55 | def dynamic_rnn( 56 | rnn_model, 57 | seq_input, 58 | seq_len, 59 | init_state=None, 60 | return_states=False, 61 | return_output=OutputForm.ALL, 62 | ): 63 | """ 64 | Inputs: 65 | rnnModel : Any torch.nn RNN model 66 | seqInput : (batchSize, maxSequenceLength, embedSize) 67 | Input sequence tensor (padded) for RNN model 68 | seqLens : batchSize length torch.LongTensor or numpy array 69 | initialState : Initial (hidden, cell) states of RNN 70 | return_output: LAST time step or ALL time steps (ALL has ALL_CONCISE). 71 | 72 | Output: 73 | A single tensor of shape (batchSize, rnnHiddenSize) corresponding 74 | to the outputs of the RNN model at the last time step of each input 75 | sequence. If returnStates is True, also return a tuple of hidden 76 | and cell states at every layer of size (num_layers, batchSize, 77 | rnnHiddenSize) 78 | """ 79 | 80 | # Perform the sorting operation to ensure sequences are in decreasing order 81 | # of lengths, as required by PyTorch packed sequence. 82 | sorted_len, fwd_order, bwd_order = get_sorted_order(seq_len) 83 | sorted_seq_input = seq_input.index_select(0, fwd_order) 84 | packed_seq_in = nn.utils.rnn.pack_padded_sequence( 85 | sorted_seq_input, lengths=sorted_len, batch_first=True 86 | ) 87 | 88 | # If initial state is given, re-arrange according to the initial sorting. 89 | if init_state is not None: 90 | sorted_init_state = [ii.index_select(1, fwd_order) for ii in init_state] 91 | # Check for number of layers match. 92 | assert ( 93 | sorted_init_state[0].size(0) == rnn_model.num_layers 94 | ), "Number of hidden layers do not match in dynamic rnn!" 95 | else: 96 | sorted_init_state = None 97 | output, (h_n, c_n) = rnn_model(packed_seq_in, sorted_init_state) 98 | 99 | # Undo the sorting operation. 
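The sorting and packing performed by `dynamic_rnn` above follows the standard PyTorch recipe for variable-length batches: sort by length, pack, run the RNN, then undo the sort. A minimal end-to-end example with a small LSTM:

```python
import torch
import torch.nn as nn

# Three padded sequences of lengths 4, 2, 3 with embedding size 5.
seq_input = torch.randn(3, 4, 5)
seq_len = torch.tensor([4, 2, 3])

# Sort by length, pack, run the LSTM, then restore the original batch order.
sorted_len, fwd_order = torch.sort(seq_len, dim=0, descending=True)
_, bwd_order = torch.sort(fwd_order)
packed = nn.utils.rnn.pack_padded_sequence(
    seq_input.index_select(0, fwd_order), lengths=sorted_len, batch_first=True
)
lstm = nn.LSTM(input_size=5, hidden_size=8, num_layers=1, batch_first=True)
output, (h_n, c_n) = lstm(packed)
last_states = h_n[-1].index_select(0, bwd_order)  # final hidden state per sequence
print(last_states.shape)  # torch.Size([3, 8])
```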
100 | if return_output == OutputForm.ALL: 101 | max_seq_len = seq_input.shape[1] 102 | output, _ = nn.utils.rnn.pad_packed_sequence( 103 | output, batch_first=True, total_length=max_seq_len 104 | ) 105 | rnn_output = output.index_select(0, bwd_order) 106 | elif return_output == OutputForm.ALL_CONCISE: 107 | output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True) 108 | rnn_output = output.index_select(0, bwd_order) 109 | elif return_output == OutputForm.LAST: 110 | rnn_output = h_n[-1].index_select(0, bwd_order) 111 | elif return_output == OutputForm.NONE: 112 | rnn_output = None 113 | elif return_output == OutputForm.PACKED: 114 | raise NotImplementedError 115 | else: 116 | raise TypeError("Only LAST and ALL are supported in dynamic_rnn!") 117 | 118 | if return_states: 119 | h_n = h_n.index_select(1, bwd_order) 120 | c_n = c_n.index_select(1, bwd_order) 121 | return rnn_output, (h_n, c_n) 122 | else: 123 | return rnn_output 124 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/support.py: -------------------------------------------------------------------------------- 1 | """Collection of support tools. 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import functools 9 | import os 10 | import numpy as np 11 | 12 | 13 | class ExponentialSmoothing: 14 | """Exponentially smooth and track losses. 15 | """ 16 | 17 | def __init__(self): 18 | self.value = None 19 | self.blur = 0.95 20 | self.op = lambda x, y: self.blur * x + (1 - self.blur) * y 21 | 22 | def report(self, new_val): 23 | """Add a new score. 24 | 25 | Args: 26 | new_val: New value to record. 27 | """ 28 | if self.value is None: 29 | self.value = new_val 30 | else: 31 | self.value = { 32 | key: self.op(value, new_val[key]) for key, value in self.value.items() 33 | } 34 | return self.value 35 | 36 | 37 | def setup_cuda_environment(gpu_id): 38 | """Setup the GPU/CPU configuration for PyTorch. 39 | """ 40 | if gpu_id < 0: 41 | print("Running on CPU...") 42 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 43 | return False 44 | else: 45 | print("Running on GPU {0}...".format(gpu_id)) 46 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 47 | return True 48 | 49 | 50 | def pretty_print_dict(parsed): 51 | """Pretty print a parsed dictionary. 52 | """ 53 | max_len = max(len(ii) for ii in parsed.keys()) 54 | format_str = "\t{{:<{width}}}: {{}}".format(width=max_len) 55 | print("Arguments:") 56 | # Sort in alphabetical order and print. 57 | for key in sorted(parsed.keys()): 58 | print(format_str.format(key, parsed[key])) 59 | print("") 60 | 61 | 62 | def print_distribution(counts, label=None): 63 | """Prints distribution for a given histogram of counts. 64 | 65 | Args: 66 | counts: Dictionary of count histograms 67 | """ 68 | total_items = sum(counts.values()) 69 | max_length = max(len(str(ii)) for ii in counts.keys()) 70 | if label is not None: 71 | print(label) 72 | format_str = "\t{{:<{width}}} [{{:.0f}}%]: {{}}".format(width=max_length) 73 | sorted_counts = sorted(counts.items(), key=lambda x: x[1], reverse=True) 74 | for key, val in sorted_counts: 75 | print(format_str.format(key, 100 * float(val) / total_items, val)) 76 | 77 | 78 | def sort_eval_metrics(eval_metrics): 79 | """Sort a dictionary of evaluation metrics. 80 | 81 | Args: 82 | eval_metrics: Dict of evaluation metrics. 83 | 84 | Returns: 85 | sorted_evals: Sorted evaluated metrics, best first. 
86 | """ 87 | # Sort based on 'perplexity' (lower is better). 88 | # sorted_evals = sorted(eval_metrics.items(), key=lambda x: x[1]['perplexity']) 89 | # return sorted_evals 90 | 91 | # Sort based on average %increase across all metrics (higher is better). 92 | def mean_relative_increase(arg1, arg2): 93 | _, metric1 = arg1 94 | _, metric2 = arg2 95 | rel_gain = [] 96 | # higher_better is +1 if true and -1 if false. 97 | for higher_better, key in [ 98 | (-1, "perplexity"), 99 | (1, "action_accuracy"), 100 | (1, "action_attribute"), 101 | ]: 102 | rel_gain.append( 103 | higher_better 104 | * (metric1[key] - metric2[key]) 105 | / (metric1[key] + metric2[key] + 1e-5) 106 | ) 107 | return np.mean(rel_gain) 108 | 109 | sorted_evals = sorted( 110 | eval_metrics.items(), 111 | key=functools.cmp_to_key(mean_relative_increase), 112 | reverse=True, 113 | ) 114 | return sorted_evals 115 | 116 | 117 | def extract_split_from_filename(file_name): 118 | """Extract the split from the filename. 119 | 120 | Args: 121 | file_name: JSON path to the split 122 | Return: 123 | split: Name of the split (train | dev | devtest | test) 124 | """ 125 | for split in ("train", "devtest", "dev", "test"): 126 | if split in file_name.split('/')[-1]: 127 | return split 128 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/torch_support.py: -------------------------------------------------------------------------------- 1 | """Additional utilities for torch tensors. 2 | 3 | Author: Satwik Kottur 4 | """ 5 | 6 | 7 | from __future__ import absolute_import, division, print_function, unicode_literals 8 | 9 | import torch 10 | 11 | 12 | def flatten(tensor, batch_size, num_rounds): 13 | """Flattens a tensor based on batch_size(B) and num_rounds(N). 14 | 15 | Args: 16 | tensor: Size [B, N, D1, D2, D3, ...] 17 | batch_size: B 18 | num_rounds = N 19 | 20 | Returns: 21 | flat_tensor: Size [B * N, D1, D2, ...] 22 | """ 23 | old_size = tensor.shape 24 | assert old_size[0] == batch_size, "Expected dim 0 as {}".format(batch_size) 25 | assert old_size[1] == num_rounds, "Expected dim 1 as {}".format(num_rounds) 26 | new_size = (-1,) + old_size[2:] 27 | flat_tensor = tensor.reshape(new_size) 28 | return flat_tensor 29 | 30 | 31 | def unflatten(tensor, batch_size, num_rounds): 32 | """Unflatten a tensor based on batch_size(B) and num_rounds(N). 33 | 34 | Args: 35 | tensor: Size [B*N, D1, D2, D3, ...] 36 | batch_size: B 37 | num_rounds = N 38 | 39 | Returns: 40 | unflat_tensor: Size [B, N, D1, D2, ...] 41 | """ 42 | old_size = tensor.shape 43 | expected_first_dim = batch_size * num_rounds 44 | assert old_size[0] == expected_first_dim, "Expected dim 0 as " "{}".format( 45 | expected_first_dim 46 | ) 47 | new_size = (batch_size, num_rounds) + old_size[1:] 48 | unflat_tensor = tensor.reshape(new_size) 49 | return unflat_tensor 50 | 51 | 52 | def gather_states(all_states, indices): 53 | """Gathers states from relevant indices given all states. 
54 | 55 | Args: 56 | all_states: States for all indices (N x T x d) 57 | indices: Indices to extract states from (N) 58 | 59 | Returns: 60 | gathered_states: States gathers at given indices (N x d) 61 | """ 62 | return torch.cat([ss[ii].unsqueeze(0) for ss, ii in zip(all_states, indices)]) 63 | -------------------------------------------------------------------------------- /mm_action_prediction/tools/weight_init.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | 7 | 8 | def weight_init(m): 9 | """ 10 | Usage: 11 | model = Model() 12 | model.apply(weight_init) 13 | """ 14 | if isinstance(m, nn.Conv1d): 15 | init.normal_(m.weight.data) 16 | if m.bias is not None: 17 | init.normal_(m.bias.data) 18 | elif isinstance(m, nn.Conv2d): 19 | init.xavier_normal_(m.weight.data) 20 | if m.bias is not None: 21 | init.normal_(m.bias.data) 22 | elif isinstance(m, nn.Conv3d): 23 | init.xavier_normal_(m.weight.data) 24 | if m.bias is not None: 25 | init.normal_(m.bias.data) 26 | elif isinstance(m, nn.ConvTranspose1d): 27 | init.normal_(m.weight.data) 28 | if m.bias is not None: 29 | init.normal_(m.bias.data) 30 | elif isinstance(m, nn.ConvTranspose2d): 31 | init.xavier_normal_(m.weight.data) 32 | if m.bias is not None: 33 | init.normal_(m.bias.data) 34 | elif isinstance(m, nn.ConvTranspose3d): 35 | init.xavier_normal_(m.weight.data) 36 | if m.bias is not None: 37 | init.normal_(m.bias.data) 38 | elif isinstance(m, nn.BatchNorm1d): 39 | init.normal_(m.weight.data, mean=1, std=0.02) 40 | init.constant_(m.bias.data, 0) 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.normal_(m.weight.data, mean=1, std=0.02) 43 | init.constant_(m.bias.data, 0) 44 | elif isinstance(m, nn.BatchNorm3d): 45 | init.normal_(m.weight.data, mean=1, std=0.02) 46 | init.constant_(m.bias.data, 0) 47 | elif isinstance(m, nn.Linear): 48 | init.xavier_normal_(m.weight.data) 49 | init.normal_(m.bias.data) 50 | elif isinstance(m, nn.LSTM): 51 | for param in m.parameters(): 52 | if len(param.shape) >= 2: 53 | init.orthogonal_(param.data) 54 | else: 55 | init.normal_(param.data) 56 | elif isinstance(m, nn.LSTMCell): 57 | for param in m.parameters(): 58 | if len(param.shape) >= 2: 59 | init.orthogonal_(param.data) 60 | else: 61 | init.normal_(param.data) 62 | elif isinstance(m, nn.GRU): 63 | for param in m.parameters(): 64 | if len(param.shape) >= 2: 65 | init.orthogonal_(param.data) 66 | else: 67 | init.normal_(param.data) 68 | elif isinstance(m, nn.GRUCell): 69 | for param in m.parameters(): 70 | if len(param.shape) >= 2: 71 | init.orthogonal_(param.data) 72 | else: 73 | init.normal_(param.data) 74 | elif isinstance(m, nn.Embedding): 75 | init.xavier_normal_(m.weight.data) 76 | # Go recursively if nn.Module if parameters exists. 77 | elif isinstance(m, nn.Module) and len(list(m.parameters())) > 0: 78 | for child in m.children(): 79 | weight_init(child) 80 | else: 81 | pass 82 | # # Uncomment this to figure out if there are uninitialized components. 
83 | # if type(m) not in (nn.CrossEntropyLoss, nn.ReLU): 84 | # print('Warning: No initialization found for {0}!'.format(m)) 85 | 86 | 87 | if __name__ == "__main__": 88 | pass 89 | -------------------------------------------------------------------------------- /mm_action_prediction/train_simmc_agent.py: -------------------------------------------------------------------------------- 1 | """Train baselines for SIMMC dataset (furniture and fashion). 2 | 3 | Author(s): Satwik Kottur 4 | """ 5 | 6 | from __future__ import absolute_import, division, print_function, unicode_literals 7 | 8 | import json 9 | import math 10 | import time 11 | import os 12 | import torch 13 | 14 | import loaders 15 | import models 16 | import options 17 | import eval_simmc_agent as evaluation 18 | from tools import support 19 | 20 | 21 | # Arguments. 22 | args = options.read_command_line() 23 | 24 | # Dataloader. 25 | dataloader_args = { 26 | "single_pass": False, 27 | "shuffle": True, 28 | "data_read_path": args["train_data_path"], 29 | "get_retrieval_candidates": False 30 | } 31 | dataloader_args.update(args) 32 | train_loader = loaders.DataloaderSIMMC(dataloader_args) 33 | args.update(train_loader.get_data_related_arguments()) 34 | # Initiate the loader for val (DEV) data split. 35 | if args["eval_data_path"]: 36 | dataloader_args = { 37 | "single_pass": True, 38 | "shuffle": False, 39 | "data_read_path": args["eval_data_path"], 40 | "get_retrieval_candidates": args["retrieval_evaluation"] 41 | } 42 | dataloader_args.update(args) 43 | val_loader = loaders.DataloaderSIMMC(dataloader_args) 44 | else: 45 | val_loader = None 46 | 47 | # Model. 48 | wizard = models.Assistant(args) 49 | wizard.train() 50 | if args["encoder"] == "tf_idf": 51 | wizard.encoder.IDF.data = train_loader._ship_helper(train_loader.IDF) 52 | 53 | # Optimizer. 54 | optimizer = torch.optim.Adam(wizard.parameters(), args["learning_rate"]) 55 | 56 | # Training iterations. 57 | smoother = support.ExponentialSmoothing() 58 | num_iters_per_epoch = train_loader.num_instances / args["batch_size"] 59 | print("Number of iterations per epoch: {:.2f}".format(num_iters_per_epoch)) 60 | eval_dict = {} 61 | best_epoch = -1 62 | 63 | # first_batch = None 64 | for iter_ind, batch in enumerate(train_loader.get_batch()): 65 | epoch = iter_ind / num_iters_per_epoch 66 | batch_loss = wizard(batch) 67 | batch_loss_items = {key: val.item() for key, val in batch_loss.items()} 68 | losses = smoother.report(batch_loss_items) 69 | 70 | # Optimization steps. 71 | optimizer.zero_grad() 72 | batch_loss["total"].backward() 73 | torch.nn.utils.clip_grad_value_(wizard.parameters(), 1.0) 74 | optimizer.step() 75 | 76 | if iter_ind % 50 == 0: 77 | cur_time = time.strftime("%a %d%b%y %X", time.gmtime()) 78 | print_str = ( 79 | "[{}][Ep: {:.2f}][It: {:d}][A: {:.2f}][Aa: {:.2f}]" "[L: {:.2f}][T: {:.2f}]" 80 | ) 81 | print_args = ( 82 | cur_time, 83 | epoch, 84 | iter_ind, 85 | losses["action"], 86 | losses["action_attr"], 87 | losses["token"], 88 | losses["total"], 89 | ) 90 | print(print_str.format(*print_args)) 91 | 92 | # Perform evaluation, every X number of epochs. 93 | if ( 94 | val_loader 95 | and int(epoch) % args["eval_every_epoch"] == 0 96 | and (iter_ind == math.ceil(int(epoch) * num_iters_per_epoch)) 97 | ): 98 | eval_dict[int(epoch)], eval_outputs = evaluation.evaluate_agent( 99 | wizard, val_loader, args 100 | ) 101 | # Print the best epoch so far. 
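The training loop above derives a fractional epoch from the iteration index and triggers evaluation or checkpointing only on the first iteration of the matching epoch. The toy loop below reproduces that arithmetic with made-up sizes:

```python
import math

# Toy values; in the training script num_iters_per_epoch = num_instances / batch_size.
num_iters_per_epoch = 7.5
eval_every_epoch = 2

for iter_ind in range(31):
    epoch = iter_ind / num_iters_per_epoch
    if int(epoch) % eval_every_epoch == 0 and iter_ind == math.ceil(
        int(epoch) * num_iters_per_epoch
    ):
        print("evaluate at iteration {} (epoch {:.2f})".format(iter_ind, epoch))
# -> iterations 0, 15, 30 (epochs 0, 2, 4)
```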
102 | best_epoch, best_epoch_dict = support.sort_eval_metrics(eval_dict)[0] 103 | print("\nBest Val Performance: Ep {}".format(best_epoch)) 104 | for item in best_epoch_dict.items(): 105 | print("\t{}: {:.2f}".format(*item)) 106 | 107 | # Save the model every epoch. 108 | if ( 109 | args["save_every_epoch"] > 0 110 | and int(epoch) % args["save_every_epoch"] == 0 111 | and (iter_ind == math.ceil(int(epoch) * num_iters_per_epoch)) 112 | ): 113 | # Create the folder if it does not exist. 114 | os.makedirs(args["snapshot_path"], exist_ok=True) 115 | # If saving prudently, save only the best model so far. 116 | checkpoint_dict = { 117 | "model_state": wizard.state_dict(), 118 | "args": args, 119 | "epoch": best_epoch, 120 | } 121 | if args["save_prudently"]: 122 | if best_epoch == int(epoch): 123 | save_path = os.path.join(args["snapshot_path"], "epoch_best.tar") 124 | print("Saving the model: {}".format(save_path)) 125 | torch.save(checkpoint_dict, save_path) 126 | else: 127 | save_path = os.path.join( 128 | args["snapshot_path"], "epoch_{}.tar".format(int(epoch)) 129 | ) 130 | print("Saving the model: {}".format(save_path)) 131 | torch.save(checkpoint_dict, save_path) 132 | # Save the file with evaluation metrics. 133 | eval_file = os.path.join(args["snapshot_path"], "eval_metrics.json") 134 | with open(eval_file, "w") as file_id: 135 | json.dump(eval_dict, file_id) 136 | # Exit if the number of epochs is exceeded. 137 | if epoch > args["num_epochs"]: 138 | break 139 | -------------------------------------------------------------------------------- /mm_dst/README.md: -------------------------------------------------------------------------------- 1 | # DSTC Track 4: SIMMC | Sub-Task #3: Multimodal Dialog State Tracking (MM-DST) 2 | 3 | This directory contains the code and the scripts for running the baseline models for Sub-Task #3: Multimodal DST. 4 | 5 | The Multimodal Dialog State Tracking (MM-DST) task involves systematically tracking dialog act labels and their associated attributes, accumulated across multiple turns. 6 | Multimodal belief states at each turn should encode sufficient information for handling user utterances in the downstream dialog components (e.g. Dialog Policy). 7 | 8 | Please check the [task input](./TASK_INPUTS.md) file for a full description of inputs 9 | for each subtask. 10 | 11 | For more details on the task definition and the baseline models we provide, please refer to our SIMMC paper: 12 | ``` 13 | @article{moon2020situated, 14 | title={Situated and Interactive Multimodal Conversations}, 15 | author={Moon, Seungwhan and Kottur, Satwik and Crook, Paul A and De, Ankita and Poddar, Shivani and Levin, Theodore and Whitney, David and Difranco, Daniel and Beirami, Ahmad and Cho, Eunjoon and Subba, Rajen and Geramifard, Alborz}, 16 | journal={arXiv preprint arXiv:2006.01460}, 17 | year={2020} 18 | } 19 | ``` 20 | **NOTE**: The [paper][simmc_arxiv] reports the results from an earlier version of the dataset and with different train-dev-test splits, hence the baseline performances on the challenge resources will be slightly different.
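To make the tracking target described above concrete, the belief state at a given turn can be viewed as a list of frames, each carrying a dialog act label and its slot-value pairs; this is the structure that `utils/evaluate_dst.py` consumes. The snippet below is only an illustrative sketch, reusing the act and slot pair from the example output later in this README:

```python
# One turn of a multimodal belief state: a list of frames, where each frame
# holds a dialog act label and its [SLOT_NAME, SLOT_VALUE] pairs.
# The values below are illustrative only.
turn_belief_state = [
    {
        "act": "DA:REQUEST:ADD_TO_CART:JACKET",
        "slots": [["fashion-O_2", "obj"]],
    },
    # ... additional frames for the same turn, if any.
]
```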
21 | 22 | 23 | ## Installation (Same across all sub-tasks) 24 | 25 | * Git clone the repository: 26 | ``` 27 | $ git lfs install 28 | $ git clone https://github.com/facebookresearch/simmc.git 29 | ``` 30 | 31 | * Install the required Python packages: 32 | * [Python 3.6+](https://www.python.org/downloads/) 33 | * [PyTorch 1.5+](https://pytorch.org/get-started/locally/#start-locally) 34 | * [Transformers](https://huggingface.co/transformers/installation.html) 35 | 36 | **NOTE**: We recommend installation in a virtual environment ([user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)). Create a new virtual environment and activate it prior to installing the packages. 37 | 38 | 39 | ## Run Baselines 40 | 41 | ### Baseline: GPT-2 Based DST 42 | 43 | 1. **Preprocess** the datasets to reformat the data for GPT-2 input. 44 | 45 | ``` 46 | $ cd mm_dst 47 | $ ./run_preprocess_gpt2.sh 48 | ``` 49 | 50 | The shell script above repeats the following for all {train|dev|devtest} splits and both {furniture|fashion} domains. 51 | 52 | ``` 53 | $ python -m gpt2_dst.scripts.preprocess_input \ 54 | --input_path_json={path_dir}/data/simmc_fashion/fashion_train_dials.json \ 55 | --output_path_predict={path_dir}/mm_dst/gpt2_dst/data/fashion/fashion_train_dials_predict.txt \ 56 | --output_path_target={path_dir}/mm_dst/gpt2_dst/data/fashion/fashion_train_dials_target.txt \ 57 | --output_path_special_tokens={path_dir}/mm_dst/gpt2_dst/data/fashion/special_tokens.json \ 58 | --len_context=2 \ 59 | --use_multimodal_contexts=1 60 | ``` 61 | 62 | 2. **Train** the baseline model 63 | 64 | ``` 65 | $ ./run_train_gpt2.sh 66 | ``` 67 | 68 | The shell script above repeats the following for both {furniture|fashion} domains. 69 | 70 | ``` 71 | $ python -m gpt2_dst.scripts.run_language_modeling \ 72 | --output_dir={path_dir}/mm_dst/gpt2_dst/save/fashion \ 73 | --model_type=gpt2 \ 74 | --model_name_or_path=gpt2 \ 75 | --line_by_line \ 76 | --add_special_tokens={path_dir}/mm_dst/gpt2_dst/data/fashion/special_tokens.json \ 77 | --do_train \ 78 | --train_data_file={path_dir}/mm_dst/gpt2_dst/data/fashion/fashion_train_dials_target.txt \ 79 | --do_eval \ 80 | --eval_data_file={path_dir}/mm_dst/gpt2_dst/data/fashion/fashion_dev_dials_target.txt \ 81 | --num_train_epochs=1 \ 82 | --overwrite_output_dir \ 83 | --per_gpu_train_batch_size=4 \ 84 | --per_gpu_eval_batch_size=4 \ 85 | #--no_cuda 86 | 87 | ``` 88 | 89 | 3. **Generate** predictions for `devtest` data 90 | 91 | ``` 92 | $ ./run_generate_gpt2.sh 93 | ``` 94 | 95 | The shell script above repeats the following for both {furniture|fashion} domains. 96 | ``` 97 | $ python -m gpt2_dst.scripts.run_generation \ 98 | --model_type=gpt2 \ 99 | --model_name_or_path={path_dir}/mm_dst/gpt2_dst/save/furniture/ \ 100 | --num_return_sequences=1 \ 101 | --length=100 \ 102 | --stop_token='' \ 103 | --prompts_from_file={path_dir}/mm_dst/gpt2_dst/data/furniture/furniture_devtest_dials_predict.txt \ 104 | --path_output={path_dir}/mm_dst/gpt2_dst/results/furniture/furniture_devtest_dials_predicted.txt 105 | ``` 106 | 107 | Here is an example output: 108 | ``` 109 | System : Yes, here's another one you might like. User : Oh yeah I think my niece would really like that. Does it come in any other colors? System : I'm sorry I don't have that information. User : Ah well. I like this color. I'd like to go ahead and buy it. Can you add it to my cart please? 
=> Belief State : 110 | DA:INFORM:PREFER:JACKET [ fashion-O_2 = obj ] DA:REQUEST:ADD_TO_CART:JACKET [ fashion-O_2 = obj ] Of course, you now have this 111 | ``` 112 | 113 | The generation results are saved in the `mm_dst/gpt2_dst/results` folder. Change `path_output` to a different path if desired. 114 | 115 | 116 | 4. **Evaluate** predictions for `devtest` data 117 | 118 | ``` 119 | $ ./run_evaluate_gpt2.sh 120 | ``` 121 | 122 | The shell script above repeats the following for both {furniture|fashion} domains. 123 | ``` 124 | python -m gpt2_dst.scripts.evaluate \ 125 | --input_path_target={path_dir}/mm_dst/gpt2_dst/data/furniture/furniture_devtest_dials_target.txt \ 126 | --input_path_predicted={path_dir}/mm_dst/gpt2_dst/results/furniture/furniture_devtest_dials_predicted.txt \ 127 | --output_path_report={path_dir}/mm_dst/gpt2_dst/results/furniture/furniture_devtest_dials_report.json 128 | 129 | ``` 130 | 131 | Evaluation reports are saved in the `mm_dst/gpt2_dst/results` folder as JSON files. 132 | 133 | Please note that GPT-2 fine-tuning is highly sensitive to the effective batch size (which the `n_gpu` of your machine affects), so some hyperparameter tuning may be needed to obtain the best results and to avoid over- or under-fitting. Feel free to change the provided default hyperparameters and compare results. 134 | 135 | Alternatively, we also provide an evaluation script that takes as input a JSON file in the same structure as the original data JSON files (useful if your model outputs predictions per dialog rather than per turn). For example, the input `pred_dials.json` file should be formatted: 136 | ``` 137 | { 138 | "dialogue_data": [ 139 | { 140 | "dialogue": [ 141 | { 142 | "belief_state": [ 143 | [ 144 | { 145 | 'act': , 146 | 'slots': [ 147 | [ 148 | SLOT_NAME, SLOT_VALUE 149 | ], ... 150 | ] 151 | }, 152 | [End of a frame] 153 | ... 154 | ], 155 | ] 156 | } 157 | [End of a turn] 158 | ... 159 | ], 160 | }, 161 | [End of a dialogue] 162 | ... 163 | ] 164 | } 165 | ``` 166 | 167 | To run this evaluation script: 168 | ``` 169 | $ ./run_evaluate.sh 170 | ``` 171 | 172 | The shell script above repeats the following for both {furniture|fashion} domains. 173 | ``` 174 | python -m utils.evaluate_dst \ 175 | --input_path_target="${PATH_DATA_DIR}"/simmc_fashion/fashion_devtest_dials.json \ 176 | --input_path_predicted="${PATH_DIR}"/fashion_devtest_pred_dials.json \ 177 | --output_path_report="${PATH_DIR}"/fashion_report.json 178 | ``` 179 | 180 | Below is the summary of the [published models](https://github.com/facebookresearch/simmc/releases/download/1.0/mm_dst_gpt2_baselines.tar.gz) we provide: 181 | 182 | | Baseline | Dialog Act F1 | Slot F1 | 183 | |--------|-------|-------| 184 | | GPT2 - Furniture (text-only) | 69.9 | 52.5 | 185 | | GPT2 - Furniture (multimodal) | 69.5 | 63.9 | 186 | | GPT2 - Fashion (text-only) | 61.2 | 52.1 | 187 | | GPT2 - Fashion (multimodal) | 61.1 | 60.6 | 188 | 189 | **Note:** The DSTC9 SIMMC Challenge was conducted on SIMMC v1.0, so all results and baseline performances above are on SIMMC v1.0. 190 | 191 | ## Rules for Sub-task #3 Submissions 192 | * Disallowed input for each turn: `belief_state`, `system_transcript`, `system_transcript_annotated`, `state_graph_1`, `state_graph_2`, and anything from future turns. 193 | * If you would like to use any other external resources, please consult with the track organizers (simmc@fb.com). Generally, we allow the use of publicly available pre-trained language models, such as BERT, GPT-2, etc.
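If you prefer to drive the Sub-Task #3 evaluation from Python instead of the shell scripts, the following minimal sketch mirrors `gpt2_dst/scripts/evaluate.py`; the file paths are placeholders, and the snippet assumes it is run from the `mm_dst` directory so the module imports resolve:

```python
import json

from gpt2_dst.utils.convert import parse_flattened_results_from_file
from utils.evaluate_dst import evaluate_from_flat_list

# Placeholder paths; point these at your own target and prediction files.
target_path = "gpt2_dst/data/furniture/furniture_devtest_dials_target.txt"
predicted_path = "gpt2_dst/results/furniture/furniture_devtest_dials_predicted.txt"

# Parse the line-by-line flattened belief states into structured frames.
list_target = parse_flattened_results_from_file(target_path)
list_predicted = parse_flattened_results_from_file(predicted_path)

# Compute joint accuracy, dialog act F1, and slot F1.
report = evaluate_from_flat_list(list_target, list_predicted)
print(json.dumps(report, indent=2))
```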
194 | 195 | [dstc9]:https://sites.google.com/dstc.community/dstc9/home 196 | [simmc_arxiv]:https://arxiv.org/abs/2006.01460 197 | -------------------------------------------------------------------------------- /mm_dst/gpt2_dst/scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Scripts for evaluating the GPT-2 DST model predictions. 4 | 5 | First, we parse the line-by-line stringified format into 6 | the structured DST output. 7 | 8 | We then run the main DST Evaluation script to get results. 9 | """ 10 | import argparse 11 | import json 12 | from gpt2_dst.utils.convert import parse_flattened_results_from_file 13 | from utils.evaluate_dst import evaluate_from_flat_list 14 | 15 | 16 | if __name__ == '__main__': 17 | # Parse input args 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--input_path_target', 20 | help='path for target, line-separated format (.txt)') 21 | parser.add_argument('--input_path_predicted', 22 | help='path for model prediction output, line-separated format (.txt)') 23 | parser.add_argument('--output_path_report', 24 | help='path for saving evaluation summary (.json)') 25 | 26 | args = parser.parse_args() 27 | input_path_target = args.input_path_target 28 | input_path_predicted = args.input_path_predicted 29 | output_path_report = args.output_path_report 30 | 31 | # Convert the data from the GPT-2 friendly format to JSON 32 | list_target = parse_flattened_results_from_file(input_path_target) 33 | list_predicted = parse_flattened_results_from_file(input_path_predicted) 34 | 35 | # Evaluate 36 | report = evaluate_from_flat_list(list_target, list_predicted) 37 | 38 | # Save report 39 | with open(output_path_report, 'w') as f_out: 40 | json.dump(report, f_out) 41 | -------------------------------------------------------------------------------- /mm_dst/gpt2_dst/scripts/preprocess_input.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Scripts for converting the main SIMMC datasets (.JSON format) 4 | into the line-by-line stringified format (and back). 5 | 6 | The reformatted data is used as input for the GPT-2 based 7 | DST model baseline. 8 | """ 9 | from gpt2_dst.utils.convert import convert_json_to_flattened 10 | import argparse 11 | 12 | if __name__ == '__main__': 13 | # Parse input args 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--input_path_json', 16 | help='input path to the original dialog data') 17 | parser.add_argument('--output_path_predict', 18 | help='output path for model input') 19 | parser.add_argument('--output_path_target', 20 | help='output path for full target') 21 | parser.add_argument('--input_path_special_tokens', 22 | help='input path for special tokens. blank if not provided', 23 | default='') 24 | parser.add_argument('--output_path_special_tokens', 25 | help='output path for special tokens. 
blank if not saving', 26 | default='') 27 | parser.add_argument('--len_context', 28 | help='# of turns to include as dialog context', 29 | type=int, default=2) 30 | parser.add_argument('--use_multimodal_contexts', 31 | help='determine whether to use the multimodal contexts each turn', 32 | type=int, default=1) 33 | 34 | args = parser.parse_args() 35 | input_path_json = args.input_path_json 36 | output_path_predict = args.output_path_predict 37 | output_path_target = args.output_path_target 38 | input_path_special_tokens = args.input_path_special_tokens 39 | output_path_special_tokens = args.output_path_special_tokens 40 | len_context = args.len_context 41 | use_multimodal_contexts = bool(args.use_multimodal_contexts) 42 | 43 | # Convert the data into GPT-2 friendly format 44 | convert_json_to_flattened( 45 | input_path_json, 46 | output_path_predict, 47 | output_path_target, 48 | input_path_special_tokens=input_path_special_tokens, 49 | output_path_special_tokens=output_path_special_tokens, 50 | len_context=len_context, 51 | use_multimodal_contexts=use_multimodal_contexts) 52 | -------------------------------------------------------------------------------- /mm_dst/run_evaluate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -lt 1 ]] 3 | then 4 | PATH_DIR=$(realpath .) 5 | PATH_DATA_DIR=$(realpath ../data) 6 | else 7 | PATH_DIR=$(realpath "$1") 8 | PATH_DATA_DIR=$(realpath "$2") 9 | fi 10 | 11 | # Evaluate (Example) 12 | python -m utils.evaluate_dst \ 13 | --input_path_target="${PATH_DATA_DIR}"/simmc_fashion/fashion_devtest_dials.json \ 14 | --input_path_predicted="${PATH_DIR}"/fashion_devtest_pred_dials.json \ 15 | --output_path_report="${PATH_DIR}"/fashion_report.json 16 | 17 | python -m utils.evaluate_dst \ 18 | --input_path_target="${PATH_DATA_DIR}"/simmc_furniture/furniture_devtest_dials.json \ 19 | --input_path_predicted="${PATH_DIR}"/furniture_devtest_pred_dials.json \ 20 | --output_path_report="${PATH_DIR}"/furniture_report.json -------------------------------------------------------------------------------- /mm_dst/run_evaluate_gpt2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -lt 1 ]] 3 | then 4 | PATH_DIR=$(realpath .) 
5 | else 6 | PATH_DIR=$(realpath "$1") 7 | fi 8 | 9 | # Evaluate (furniture, non-multimodal) 10 | python -m gpt2_dst.scripts.evaluate \ 11 | --input_path_target="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_devtest_dials_target.txt \ 12 | --input_path_predicted="${PATH_DIR}"/gpt2_dst/results/furniture_to/furniture_devtest_dials_predicted.txt \ 13 | --output_path_report="${PATH_DIR}"/gpt2_dst/results/furniture_to/furniture_devtest_dials_report.json 14 | 15 | # Evaluate (furniture, multi-modal) 16 | python -m gpt2_dst.scripts.evaluate \ 17 | --input_path_target="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_devtest_dials_target.txt \ 18 | --input_path_predicted="${PATH_DIR}"/gpt2_dst/results/furniture/furniture_devtest_dials_predicted.txt \ 19 | --output_path_report="${PATH_DIR}"/gpt2_dst/results/furniture/furniture_devtest_dials_report.json 20 | 21 | # Evaluate (Fashion, non-multimodal) 22 | python -m gpt2_dst.scripts.evaluate \ 23 | --input_path_target="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_devtest_dials_target.txt \ 24 | --input_path_predicted="${PATH_DIR}"/gpt2_dst/results/fashion_to/fashion_devtest_dials_predicted.txt \ 25 | --output_path_report="${PATH_DIR}"/gpt2_dst/results/fashion_to/fashion_devtest_dials_report.json 26 | 27 | # Evaluate (Fashion, multi-modal) 28 | python -m gpt2_dst.scripts.evaluate \ 29 | --input_path_target="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_devtest_dials_target.txt \ 30 | --input_path_predicted="${PATH_DIR}"/gpt2_dst/results/fashion/fashion_devtest_dials_predicted.txt \ 31 | --output_path_report="${PATH_DIR}"/gpt2_dst/results/fashion/fashion_devtest_dials_report.json 32 | -------------------------------------------------------------------------------- /mm_dst/run_generate_gpt2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -lt 1 ]] 3 | then 4 | PATH_DIR=$(realpath .) 
5 | else 6 | PATH_DIR=$(realpath "$1") 7 | fi 8 | 9 | # Generate sentences (Furniture, text-only) 10 | python -m gpt2_dst.scripts.run_generation \ 11 | --model_type=gpt2 \ 12 | --model_name_or_path="${PATH_DIR}"/gpt2_dst/save/furniture_to/ \ 13 | --num_return_sequences=1 \ 14 | --length=100 \ 15 | --stop_token='' \ 16 | --prompts_from_file="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_devtest_dials_predict.txt \ 17 | --path_output="${PATH_DIR}"/gpt2_dst/results/furniture_to/furniture_devtest_dials_predicted.txt 18 | 19 | # Generate sentences (Furniture, multi-modal) 20 | python -m gpt2_dst.scripts.run_generation \ 21 | --model_type=gpt2 \ 22 | --model_name_or_path="${PATH_DIR}"/gpt2_dst/save/furniture/ \ 23 | --num_return_sequences=1 \ 24 | --length=100 \ 25 | --stop_token='' \ 26 | --prompts_from_file="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_devtest_dials_predict.txt \ 27 | --path_output="${PATH_DIR}"/gpt2_dst/results/furniture/furniture_devtest_dials_predicted.txt 28 | 29 | # Generate sentences (Fashion, text-only) 30 | python -m gpt2_dst.scripts.run_generation \ 31 | --model_type=gpt2 \ 32 | --model_name_or_path="${PATH_DIR}"/gpt2_dst/save/fashion_to/ \ 33 | --num_return_sequences=1 \ 34 | --length=100 \ 35 | --stop_token='' \ 36 | --prompts_from_file="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_devtest_dials_predict.txt \ 37 | --path_output="${PATH_DIR}"/gpt2_dst/results/fashion_to/fashion_devtest_dials_predicted.txt 38 | 39 | # Generate sentences (Fashion, multi-modal) 40 | python -m gpt2_dst.scripts.run_generation \ 41 | --model_type=gpt2 \ 42 | --model_name_or_path="${PATH_DIR}"/gpt2_dst/save/fashion/ \ 43 | --num_return_sequences=1 \ 44 | --length=100 \ 45 | --stop_token='' \ 46 | --prompts_from_file="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_devtest_dials_predict.txt \ 47 | --path_output="${PATH_DIR}"/gpt2_dst/results/fashion/fashion_devtest_dials_predicted.txt 48 | -------------------------------------------------------------------------------- /mm_dst/run_preprocess_gpt2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -lt 1 ]] 3 | then 4 | PATH_DIR=$(realpath .) 
5 | PATH_DATA_DIR=$(realpath ../data) 6 | else 7 | PATH_DIR=$(realpath "$1") 8 | PATH_DATA_DIR=$(realpath "$2") 9 | fi 10 | 11 | # Fashion 12 | # Multimodal Data 13 | # Train split 14 | python -m gpt2_dst.scripts.preprocess_input \ 15 | --input_path_json="${PATH_DATA_DIR}"/simmc_fashion/fashion_train_dials.json \ 16 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_train_dials_predict.txt \ 17 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_train_dials_target.txt \ 18 | --len_context=2 \ 19 | --use_multimodal_contexts=1 \ 20 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion/special_tokens.json 21 | 22 | # Dev split 23 | python -m gpt2_dst.scripts.preprocess_input \ 24 | --input_path_json="${PATH_DATA_DIR}"/simmc_fashion/fashion_dev_dials.json \ 25 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_dev_dials_predict.txt \ 26 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_dev_dials_target.txt \ 27 | --len_context=2 \ 28 | --use_multimodal_contexts=1 \ 29 | --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion/special_tokens.json \ 30 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion/special_tokens.json \ 31 | 32 | # Devtest split 33 | python -m gpt2_dst.scripts.preprocess_input \ 34 | --input_path_json="${PATH_DATA_DIR}"/simmc_fashion/fashion_devtest_dials.json \ 35 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_devtest_dials_predict.txt \ 36 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_devtest_dials_target.txt \ 37 | --len_context=2 \ 38 | --use_multimodal_contexts=1 \ 39 | --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion/special_tokens.json \ 40 | 41 | # Test split 42 | # python -m gpt2_dst.scripts.preprocess_input \ 43 | # --input_path_json="${PATH_DATA_DIR}"/simmc_fashion/fashion_test_dials.json \ 44 | # --output_path_predict="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_test_dials_predict.txt \ 45 | # --output_path_target="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_test_dials_target.txt \ 46 | # --len_context=2 \ 47 | # --use_multimodal_contexts=1 \ 48 | # --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion/special_tokens.json \ 49 | 50 | # Fashion 51 | # Non-multimodal Data 52 | # Train split 53 | python -m gpt2_dst.scripts.preprocess_input \ 54 | --input_path_json="${PATH_DATA_DIR}"/simmc_fashion/fashion_train_dials.json \ 55 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_train_dials_predict.txt \ 56 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_train_dials_target.txt \ 57 | --len_context=2 \ 58 | --use_multimodal_contexts=0 \ 59 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion_to/special_tokens.json 60 | 61 | # Dev split 62 | python -m gpt2_dst.scripts.preprocess_input \ 63 | --input_path_json="${PATH_DATA_DIR}"/simmc_fashion/fashion_dev_dials.json \ 64 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_dev_dials_predict.txt \ 65 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_dev_dials_target.txt \ 66 | --len_context=2 \ 67 | --use_multimodal_contexts=0 \ 68 | --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion_to/special_tokens.json \ 69 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion_to/special_tokens.json \ 70 | 71 | # Devtest split 72 | python -m gpt2_dst.scripts.preprocess_input \ 73 | 
--input_path_json="${PATH_DATA_DIR}"/simmc_fashion/fashion_devtest_dials.json \ 74 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_devtest_dials_predict.txt \ 75 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_devtest_dials_target.txt \ 76 | --len_context=2 \ 77 | --use_multimodal_contexts=0 \ 78 | --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion_to/special_tokens.json \ 79 | 80 | # Test split 81 | # python -m gpt2_dst.scripts.preprocess_input \ 82 | # --input_path_json="${PATH_DATA_DIR}"/simmc_fashion/fashion_test_dials.json \ 83 | # --output_path_predict="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_test_dials_predict.txt \ 84 | # --output_path_target="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_test_dials_target.txt \ 85 | # --len_context=2 \ 86 | # --use_multimodal_contexts=0 \ 87 | # --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion_to/special_tokens.json \ 88 | 89 | # Furniture 90 | # Multimodal Data 91 | # Train split 92 | python -m gpt2_dst.scripts.preprocess_input \ 93 | --input_path_json="${PATH_DATA_DIR}"/simmc_furniture/furniture_train_dials.json \ 94 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_train_dials_predict.txt \ 95 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_train_dials_target.txt \ 96 | --len_context=2 \ 97 | --use_multimodal_contexts=1 \ 98 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture/special_tokens.json 99 | 100 | # Dev split 101 | python -m gpt2_dst.scripts.preprocess_input \ 102 | --input_path_json="${PATH_DATA_DIR}"/simmc_furniture/furniture_dev_dials.json \ 103 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_dev_dials_predict.txt \ 104 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_dev_dials_target.txt \ 105 | --len_context=2 \ 106 | --use_multimodal_contexts=1 \ 107 | --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture/special_tokens.json \ 108 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture/special_tokens.json \ 109 | 110 | # Devtest split 111 | python -m gpt2_dst.scripts.preprocess_input \ 112 | --input_path_json="${PATH_DATA_DIR}"/simmc_furniture/furniture_devtest_dials.json \ 113 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_devtest_dials_predict.txt \ 114 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_devtest_dials_target.txt \ 115 | --len_context=2 \ 116 | --use_multimodal_contexts=1 \ 117 | --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture/special_tokens.json \ 118 | 119 | # Test split 120 | #python -m gpt2_dst.scripts.preprocess_input \ 121 | # --input_path_json="${PATH_DATA_DIR}"/simmc_furniture/furniture_test_dials.json \ 122 | # --output_path_predict="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_test_dials_predict.txt \ 123 | # --output_path_target="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_test_dials_target.txt \ 124 | # --len_context=2 \ 125 | # --use_multimodal_contexts=1 \ 126 | # --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture/special_tokens.json \ 127 | 128 | # Furniture 129 | # Non-multimodal Data 130 | # Train split 131 | python -m gpt2_dst.scripts.preprocess_input \ 132 | --input_path_json="${PATH_DATA_DIR}"/simmc_furniture/furniture_train_dials.json \ 133 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_train_dials_predict.txt \ 134 | 
--output_path_target="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_train_dials_target.txt \ 135 | --len_context=2 \ 136 | --use_multimodal_contexts=0 \ 137 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture_to/special_tokens.json 138 | 139 | # Dev split 140 | python -m gpt2_dst.scripts.preprocess_input \ 141 | --input_path_json="${PATH_DATA_DIR}"/simmc_furniture/furniture_dev_dials.json \ 142 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_dev_dials_predict.txt \ 143 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_dev_dials_target.txt \ 144 | --len_context=2 \ 145 | --use_multimodal_contexts=0 \ 146 | --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture_to/special_tokens.json \ 147 | --output_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture_to/special_tokens.json \ 148 | 149 | # Devtest split 150 | python -m gpt2_dst.scripts.preprocess_input \ 151 | --input_path_json="${PATH_DATA_DIR}"/simmc_furniture/furniture_devtest_dials.json \ 152 | --output_path_predict="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_devtest_dials_predict.txt \ 153 | --output_path_target="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_devtest_dials_target.txt \ 154 | --len_context=2 \ 155 | --use_multimodal_contexts=0 \ 156 | --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture_to/special_tokens.json \ 157 | 158 | # Test split 159 | # python -m gpt2_dst.scripts.preprocess_input \ 160 | # --input_path_json="${PATH_DATA_DIR}"/simmc_furniture/furniture_test_dials.json \ 161 | # --output_path_predict="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_test_dials_predict.txt \ 162 | # --output_path_target="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_test_dials_target.txt \ 163 | # --len_context=2 \ 164 | # --use_multimodal_contexts=0 \ 165 | # --input_path_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture_to/special_tokens.json \ 166 | -------------------------------------------------------------------------------- /mm_dst/run_train_gpt2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | if [[ $# -lt 1 ]] 3 | then 4 | PATH_DIR=$(realpath .) 
5 | else 6 | PATH_DIR=$(realpath "$1") 7 | fi 8 | 9 | # Train (furniture, text-only) 10 | python -m gpt2_dst.scripts.run_language_modeling \ 11 | --output_dir="${PATH_DIR}"/gpt2_dst/save/furniture_to \ 12 | --model_type=gpt2 \ 13 | --model_name_or_path=gpt2 \ 14 | --line_by_line \ 15 | --add_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture_to/special_tokens.json \ 16 | --do_train \ 17 | --train_data_file="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_train_dials_target.txt \ 18 | --do_eval \ 19 | --eval_data_file="${PATH_DIR}"/gpt2_dst/data/furniture_to/furniture_dev_dials_target.txt \ 20 | --num_train_epochs=1 \ 21 | --overwrite_output_dir \ 22 | --per_gpu_train_batch_size=4 \ 23 | --per_gpu_eval_batch_size=4 24 | 25 | # Train (furniture, multi-modal) 26 | python -m gpt2_dst.scripts.run_language_modeling \ 27 | --output_dir="${PATH_DIR}"/gpt2_dst/save/furniture \ 28 | --model_type=gpt2 \ 29 | --model_name_or_path=gpt2 \ 30 | --line_by_line \ 31 | --add_special_tokens="${PATH_DIR}"/gpt2_dst/data/furniture/special_tokens.json \ 32 | --do_train \ 33 | --train_data_file="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_train_dials_target.txt \ 34 | --do_eval \ 35 | --eval_data_file="${PATH_DIR}"/gpt2_dst/data/furniture/furniture_dev_dials_target.txt \ 36 | --num_train_epochs=1 \ 37 | --overwrite_output_dir \ 38 | --per_gpu_train_batch_size=4 \ 39 | --per_gpu_eval_batch_size=4 40 | 41 | # Train (Fashion, text-only) 42 | python -m gpt2_dst.scripts.run_language_modeling \ 43 | --output_dir="${PATH_DIR}"/gpt2_dst/save/fashion_to \ 44 | --model_type=gpt2 \ 45 | --model_name_or_path=gpt2 \ 46 | --line_by_line \ 47 | --add_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion_to/special_tokens.json \ 48 | --do_train \ 49 | --train_data_file="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_train_dials_target.txt \ 50 | --do_eval \ 51 | --eval_data_file="${PATH_DIR}"/gpt2_dst/data/fashion_to/fashion_dev_dials_target.txt \ 52 | --num_train_epochs=1 \ 53 | --overwrite_output_dir \ 54 | --per_gpu_train_batch_size=4 \ 55 | --per_gpu_eval_batch_size=4 56 | 57 | # Train (Fashion, multi-modal) 58 | python -m gpt2_dst.scripts.run_language_modeling \ 59 | --output_dir="${PATH_DIR}"/gpt2_dst/save/fashion \ 60 | --model_type=gpt2 \ 61 | --model_name_or_path=gpt2 \ 62 | --line_by_line \ 63 | --add_special_tokens="${PATH_DIR}"/gpt2_dst/data/fashion/special_tokens.json \ 64 | --do_train \ 65 | --train_data_file="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_train_dials_target.txt \ 66 | --do_eval \ 67 | --eval_data_file="${PATH_DIR}"/gpt2_dst/data/fashion/fashion_dev_dials_target.txt \ 68 | --num_train_epochs=1 \ 69 | --overwrite_output_dir \ 70 | --per_gpu_train_batch_size=4 \ 71 | --per_gpu_eval_batch_size=4 72 | -------------------------------------------------------------------------------- /mm_dst/utils/evaluate_dst.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import numpy as np 3 | 4 | 5 | def evaluate_from_flat_list(d_true, d_pred): 6 | """ 7 | d_true and d_pred are in the following format: 8 | (Each element represents a single turn, with (multiple) frames) 9 | [ 10 | [ 11 | { 12 | 'act': , 13 | 'slots': [ 14 | [ 15 | SLOT_NAME, SLOT_VALUE 16 | ], ... 17 | ] 18 | }, 19 | [End of a frame] 20 | ... 21 | ], 22 | [End of a turn] 23 | ... 
24 | ] 25 | """ 26 | c = { 27 | 'n_frames': 0.0, 28 | 'n_true_acts': 0.0, 29 | 'n_pred_acts': 0.0, 30 | 'n_correct_acts': 0.0, 31 | 'n_true_slots': 0.0, 32 | 'n_pred_slots': 0.0, 33 | 'n_correct_slots': 0.0, 34 | 'n_correct_beliefs': 0.0, 35 | } 36 | 37 | # Count # corrects & # wrongs 38 | for turn_idx in range(len(d_true)): 39 | true_turn = d_true[turn_idx] 40 | pred_turn = d_pred[turn_idx] 41 | 42 | c = add_dicts( 43 | c, 44 | evaluate_turn(true_turn, pred_turn)) 45 | 46 | # Calculate metrics 47 | joint_accuracy = c['n_correct_beliefs'] / c['n_frames'] 48 | act_rec = c['n_correct_acts'] / c['n_true_acts'] 49 | act_prec = c['n_correct_acts'] / c['n_pred_acts'] 50 | act_f1 = \ 51 | 2 * act_prec * act_rec / (act_prec + act_rec) \ 52 | if (act_prec + act_rec) != 0 else 0 53 | 54 | slot_rec = c['n_correct_slots'] / c['n_true_slots'] 55 | slot_prec = c['n_correct_slots'] / c['n_pred_slots'] 56 | slot_f1 = \ 57 | 2 * slot_prec * slot_rec / (slot_prec + slot_rec) \ 58 | if (slot_prec + slot_rec) != 0 else 0 59 | 60 | # Calculate std err 61 | act_f1_stderr = \ 62 | d_f1(c['n_true_acts'], c['n_pred_acts'], c['n_correct_acts']) 63 | slot_f1_stderr = \ 64 | d_f1(c['n_true_slots'], c['n_pred_slots'], c['n_correct_slots']) 65 | 66 | return { 67 | 'joint_accuracy': joint_accuracy, 68 | 'act_rec': act_rec, 69 | 'act_prec': act_prec, 70 | 'act_f1': act_f1, 71 | 'act_f1_stderr': act_f1_stderr, 72 | 'slot_rec': slot_rec, 73 | 'slot_prec': slot_prec, 74 | 'slot_f1': slot_f1, 75 | 'slot_f1_stderr': slot_f1_stderr, 76 | } 77 | 78 | 79 | def evaluate_turn(true_turn, pred_turn): 80 | 81 | count_dict = { 82 | 'n_frames': 0, 83 | 'n_true_acts': 0, 84 | 'n_pred_acts': 0, 85 | 'n_correct_acts': 0, 86 | 'n_true_slots': 0, 87 | 'n_pred_slots': 0, 88 | 'n_correct_slots': 0, 89 | 'n_correct_beliefs': 0, 90 | } 91 | 92 | # Must preserve order in which frames appear. 93 | for frame_idx in range(len(true_turn)): 94 | # For each frame 95 | true_frame = true_turn[frame_idx] 96 | if frame_idx >= len(pred_turn): 97 | pred_frame = {} 98 | else: 99 | pred_frame = pred_turn[frame_idx] 100 | 101 | count_dict = add_dicts( 102 | count_dict, 103 | evaluate_frame(true_frame, pred_frame, strict=False)) 104 | 105 | return count_dict 106 | 107 | 108 | def evaluate_frame(true_frame, pred_frame, strict=True): 109 | """ 110 | If strict=True, 111 | For each dialog_act (frame), set(slot values) must match. 112 | If dialog_act is incorrect, its set(slot values) is considered wrong. 
113 | """ 114 | count_dict = { 115 | 'n_frames': 1, 116 | 'n_true_acts': 0, 117 | 'n_pred_acts': 0, 118 | 'n_correct_acts': 0, 119 | 'n_true_slots': 0, 120 | 'n_pred_slots': 0, 121 | 'n_correct_slots': 0, 122 | 'n_correct_beliefs': 0, 123 | } 124 | 125 | # Compare Dialog Acts 126 | true_act = true_frame['act'] if 'act' in true_frame else None 127 | pred_act = pred_frame['act'] if 'act' in pred_frame else None 128 | b_correct_act = true_act == pred_act 129 | count_dict['n_correct_acts'] += b_correct_act 130 | count_dict['n_true_acts'] += 'act' in true_frame 131 | count_dict['n_pred_acts'] += 'act' in pred_frame 132 | 133 | # Compare Slots 134 | true_frame_slot_values = \ 135 | set(f'{k}={v}' for k, v in true_frame.get('slots', [])) 136 | 137 | pred_frame_slot_values = \ 138 | set(f'{k}={v}' for k, v in pred_frame.get('slots', [])) 139 | 140 | count_dict['n_true_slots'] += len(true_frame_slot_values) 141 | count_dict['n_pred_slots'] += len(pred_frame_slot_values) 142 | 143 | if strict and not b_correct_act: 144 | pass 145 | else: 146 | count_dict['n_correct_slots'] += \ 147 | len(true_frame_slot_values.intersection(pred_frame_slot_values)) 148 | 149 | count_dict['n_correct_beliefs'] += \ 150 | (b_correct_act and true_frame_slot_values == pred_frame_slot_values) 151 | 152 | return count_dict 153 | 154 | 155 | def add_dicts(d1, d2): 156 | return {k: d1[k] + d2[k] for k in d1} 157 | 158 | 159 | def d_f1(n_true, n_pred, n_correct): 160 | # Standard error of F1, propagated from recall (r) and precision (p) errors: 161 | # 1/r + 1/p = 2/F1 162 | # => dr / r^2 + dp / p^2 = 2 dF1 / F1^2 163 | # => dF1 = 0.5 * F1^2 * (dr / r^2 + dp / p^2) 164 | dr = b_stderr(n_true, n_correct) 165 | dp = b_stderr(n_pred, n_correct) 166 | 167 | r = n_correct / n_true 168 | p = n_correct / n_pred 169 | f1 = 2 * p * r / (p + r) 170 | 171 | d_f1 = 0.5 * f1**2 * (dr / r**2 + dp / p**2) 172 | return d_f1 173 | 174 | 175 | def b_stderr(n_total, n_pos): 176 | return np.std(b_arr(n_total, n_pos)) / np.sqrt(n_total) 177 | 178 | 179 | def b_arr(n_total, n_pos): 180 | out = np.zeros(int(n_total)) 181 | out[:int(n_pos)] = 1.0 182 | return out 183 | -------------------------------------------------------------------------------- /mm_response_generation/README.md: -------------------------------------------------------------------------------- 1 | # DSTC Track 4: SIMMC | Sub-Task #2: Multimodal Assistant Response Generation 2 | 3 | This directory contains the code and the scripts for running the baseline models for Sub-Task #2: Multimodal Assistant Response Generation. 4 | 5 | This subtask measures the generation (or retrieval) of the assistant response given the dialog history, multimodal context, ground truth assistant API call, and the current utterance. 6 | 7 | Please check the [task input](./TASK_INPUTS.md) file for a full description of inputs 8 | for each subtask. 9 | 10 | ## Evaluation 11 | For generation, we use BLEU-4 score and for retrieval, we use recall@1, recall@5, recall@10, mean reciprocal rank (MRR), and mean rank. 12 | 13 | The code to evaluate Sub-Task #2 is given in `mm_action_prediction/tools/response_evaluation.py` and 14 | `mm_action_prediction/tools/retrieval_evaluation.py`. 15 | The model outputs are expected in the following format: 16 | 17 | **Response Generation Evaluation** 18 | 19 | ``` 20 | [ 21 | { 22 | "dialog_id": batch["dialog_id"][ii].item(), 23 | "predictions": [ 24 | { 25 | "turn_id": .. 26 | "response": ... 27 | } 28 | ... 29 | ] 30 | } 31 | ... 
32 | ] 33 | ``` 34 | 35 | **Retrieval Evaluation** 36 | 37 | ``` 38 | [ 39 | { 40 | "dialog_id": , 41 | "candidate_scores": [ 42 | { 43 | "scores": 44 | "turn_id": .. 45 | }, 46 | ... 47 | ] 48 | } 49 | ... 50 | ] 51 | ``` 52 | 53 | 54 | For more details on the task definition and the baseline models we provide, please refer to our SIMMC paper: 55 | 56 | ``` 57 | @article{moon2020situated, 58 | title={Situated and Interactive Multimodal Conversations}, 59 | author={Moon, Seungwhan and Kottur, Satwik and Crook, Paul A and De, Ankita and Poddar, Shivani and Levin, Theodore and Whitney, David and Difranco, Daniel and Beirami, Ahmad and Cho, Eunjoon and Subba, Rajen and Geramifard, Alborz}, 60 | journal={arXiv preprint arXiv:2006.01460}, 61 | year={2020} 62 | } 63 | ``` 64 | **NOTE**: The [paper][simmc_arxiv] reports the results from an earlier version of the dataset and with different train-dev-test splits, hence the baseline performances on the challenge resources will be slightly different. 65 | 66 | ## Installation (Same across all sub-tasks) 67 | 68 | * Git clone the repository: 69 | ``` 70 | $ git lfs install 71 | $ git clone https://github.com/facebookresearch/simmc.git 72 | ``` 73 | 74 | * Install the required Python packages: 75 | * [Python 3.6+](https://www.python.org/downloads/) 76 | * [PyTorch 1.5+](https://pytorch.org/get-started/locally/#start-locally) 77 | * [Transformers](https://huggingface.co/transformers/installation.html) 78 | 79 | **NOTE**: We recommend installation in a virtual environment ([user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)). Create a new virtual environment and activate it prior to installing the packages. 80 | 81 | ## Run Baselines 82 | 83 | Baselines for Sub-Task #2 jointly train for both Sub-Task #2 and Sub-Task #1. 84 | Please see Sub-Task #1 for instructions to run the baselines. 85 | 86 | ### Results 87 | The baselines trained through the code obtain the following results for Sub-Task #2. 88 | 89 | **SIMMC-Furniture** 90 | 91 | | Model | BLEU-4 | R@1 | R@5 | R@10 | Mean Rank | MRR | 92 | |----------| :-------------: | :------: | :------: | :------: | :------: |:------: | 93 | | LSTM | 0.022 | 4.1 | 11.1 | 17.3 | 46.4 | 0.094 | 94 | | HAE | 0.075 | 12.9 | 28.9 | 38.4 | 31.0 | 0.218 | 95 | | HRE | 0.075 | 13.8 | 30.5 | 40.2 | 30.0 | 0.229 | 96 | | MN | 0.084 | 15.3 | 31.8 | 42.2 | 29.1 | 0.244 | 97 | | T-HAE | 0.044 | 8.5 | 20.3 | 28.9 | 37.9 | 0.156 | 98 | 99 | 100 | **SIMMC-Fashion** 101 | 102 | | Model | BLEU-4 | R@1 | R@5 | R@10 | Mean Rank | MRR | 103 | |----------| :-------------: | :------: | :------: | :------: | :------: |:------: | 104 | | LSTM | 0.022 | 5.3 | 11.4 | 16.5 | 46.9 | 0.102 | 105 | | HAE | 0.059 | 10.5 | 25.3 | 34.1 | 33.5 | 0.190 | 106 | | HRE | 0.079 | 16.3 | 33.1 | 41.7 | 27.4 | 0.253 | 107 | | MN | 0.065 | 16.1 | 31.0 | 39.4 | 29.3 | 0.245 | 108 | | T-HAE | 0.051 | 10.3 | 23.2 | 31.1 | 37.1 | 0.178 | 109 | 110 | MRR = Mean Reciprocal Rank 111 | **Higher is better:** BLEU-4, R@1, R@5, R@10, MRR 112 | **Lower is better:** Mean Rank 113 | 114 | 115 | ## Rules for Sub-task #2 Submissions 116 | * Disallowed Input: `belief_state`, `system_transcript`, `system_transcript_annotated`, `state_graph_1`, `state_graph_2`, and anything from future turns. 117 | * If you would like to use any other external resources, please consult with the track organizers (simmc@fb.com). Generally, we allow the use of publicly available pre-trained language models, such as BERT, GPT-2, etc. 
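For convenience, here is a minimal sketch (with made-up dialog ids, turn ids, and responses) of assembling generation outputs in the "Response Generation Evaluation" format described above, so they can be passed to `mm_action_prediction/tools/response_evaluation.py`:

```python
import json

# Hypothetical model outputs for a single dialog; only the structure matters.
model_responses = [
    {
        "dialog_id": 1234,
        "predictions": [
            {"turn_id": 0, "response": "Here is a similar couch in grey."},
            {"turn_id": 1, "response": "It is about 85 inches wide."},
        ],
    },
    # ... one entry per dialog.
]

# Save the predictions; the output path below is a placeholder.
with open("furniture_devtest_response_predictions.json", "w") as file_id:
    json.dump(model_responses, file_id)
```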
118 | 119 | [simmc_arxiv]:https://arxiv.org/abs/2006.01460 120 | --------------------------------------------------------------------------------