├── .DS_Store ├── README.md ├── data ├── .DS_Store ├── bank77 │ ├── dataset.json │ ├── original │ │ ├── LICENSE │ │ ├── README.md │ │ ├── banking_data │ │ │ ├── categories.json │ │ │ ├── test.csv │ │ │ └── train.csv │ │ ├── polyai-logo.png │ │ └── span_extraction │ │ │ ├── dstc8 │ │ │ ├── Buses_1 │ │ │ │ ├── test.json │ │ │ │ ├── train_0.json │ │ │ │ ├── train_1.json │ │ │ │ ├── train_2.json │ │ │ │ ├── train_3.json │ │ │ │ ├── train_4.json │ │ │ │ └── train_5.json │ │ │ ├── Events_1 │ │ │ │ ├── test.json │ │ │ │ ├── train_0.json │ │ │ │ ├── train_1.json │ │ │ │ ├── train_2.json │ │ │ │ ├── train_3.json │ │ │ │ ├── train_4.json │ │ │ │ └── train_5.json │ │ │ ├── Homes_1 │ │ │ │ ├── test.json │ │ │ │ ├── train_0.json │ │ │ │ ├── train_1.json │ │ │ │ ├── train_2.json │ │ │ │ ├── train_3.json │ │ │ │ ├── train_4.json │ │ │ │ └── train_5.json │ │ │ ├── RentalCars_1 │ │ │ │ ├── test.json │ │ │ │ ├── train_0.json │ │ │ │ ├── train_1.json │ │ │ │ ├── train_2.json │ │ │ │ ├── train_3.json │ │ │ │ ├── train_4.json │ │ │ │ └── train_5.json │ │ │ └── stats.csv │ │ │ └── restaurant8k │ │ │ ├── test.json │ │ │ ├── train_0.json │ │ │ ├── train_1.json │ │ │ ├── train_2.json │ │ │ ├── train_3.json │ │ │ ├── train_4.json │ │ │ ├── train_5.json │ │ │ ├── train_6.json │ │ │ ├── train_7.json │ │ │ └── train_8.json │ └── showDataset.py ├── hint3 │ ├── dataset.json │ ├── original │ │ ├── readme.rd │ │ └── v2 │ │ │ ├── test │ │ │ ├── curekart_test.csv │ │ │ ├── powerplay11_test.csv │ │ │ └── sofmattress_test.csv │ │ │ └── train │ │ │ ├── curekart_train.csv │ │ │ ├── powerplay11_train.csv │ │ │ └── sofmattress_train.csv │ └── showDataset.py ├── hwu64 │ ├── dataset.json │ ├── original │ │ └── NLU-Data-Home-Domain-Annotated-All.csv │ └── showDataset.py ├── mcid │ ├── dataset.json │ ├── original │ │ ├── README │ │ ├── de │ │ │ ├── eval.tsv │ │ │ ├── test.tsv │ │ │ └── train.tsv │ │ ├── en │ │ │ ├── eval.tsv │ │ │ ├── test.tsv │ │ │ └── train.tsv │ │ ├── es │ │ │ ├── eval.tsv │ │ │ ├── test.tsv │ │ │ └── train.tsv │ │ ├── fr │ │ │ ├── eval.tsv │ │ │ ├── test.tsv │ │ │ └── train.tsv │ │ └── spanglish │ │ │ └── test.tsv │ └── showDataset.py └── oos │ ├── dataset.json │ ├── original │ ├── domain_intent.txt │ └── oos-eval-master │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── README.md │ │ ├── clinc_logo.png │ │ ├── data │ │ ├── all_wiki_sents.txt │ │ ├── binary_undersample.json │ │ ├── binary_wiki_aug.json │ │ ├── data_full.json │ │ ├── data_imbalanced.json │ │ ├── data_oos_plus.json │ │ └── data_small.json │ │ ├── hyperparameters.csv │ │ ├── paper.pdf │ │ ├── poster.pdf │ │ └── supplementary.pdf │ └── showDataset.py ├── eval.py ├── images ├── combined.png └── main.png ├── mlm.py ├── scripts ├── eval.sh ├── mlm.sh └── transfer.sh ├── transfer.py └── utils ├── Evaluator.py ├── IntentDataset.py ├── Logger.py ├── TaskSampler.py ├── Trainer.py ├── __init__.py ├── commonVar.py ├── models.py ├── printHelper.py └── tools.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## IntentBERT: Effectiveness of Pre-training for Few-shot Intent Classification 2 | 3 | This repository contains the code and pre-trained models for our paper on *EMNLP-findings*: [Effectiveness of Pre-training for Few-shot Intent 
Classification](https://arxiv.org/abs/2109.05782). This README is adapted from this [repo](https://github.com/princeton-nlp/SimCSE). 4 | 5 | ## Quick Links 6 | 7 | - [Overview](#overview) 8 | - [Train IntentBERT](#train-intentbert) 9 | - [Requirements](#requirements) 10 | - [Evaluation](#evaluation) 11 | - [Dataset](#dataset) 12 | - [Before running](#before-running) 13 | - [Training](#training) 14 | - [Bugs or Questions?](#bugs-or-questions) 15 | - [Citation](#citation) 16 | 17 | ## Overview 18 | 19 | * Is it possible to learn transferable task-specific knowledge that generalizes across different domains for intent detection? 20 | 21 | In this paper, we offer a free-lunch solution for few-shot intent detection by pre-training on a large, publicly available dataset. Our experiments show significant improvements over previous pre-trained models on drastically different target domains, which indicates that IntentBERT possesses high generalizability and is ready to use without further fine-tuning. We also propose a joint pre-training scheme (IntentBERT+MLM) to leverage unlabeled data on the target domain. 22 | 23 |

24 | ![scatter](images/main.png) 25 | 26 | 27 | 28 | ![scatter](images/combined.png) 29 |

30 | 31 | ## Train IntentBERT 32 | 33 | In the following section, we describe how to train an IntentBERT model using our code. 34 | 35 | ### Requirements 36 | 37 | Run the following script to install the dependencies, 38 | 39 | ```bash 40 | pip install -r requirements.txt 41 | ``` 42 | 43 | ### Dataset 44 | 45 | We provide the datasets required for training and evaluation in the `data` folder of this repo. Specifically, "oos" & "hwu64" are used for training or validation, while "bank77", "mcid" & "hint3" are used as target datasets. Each dataset folder contains a `showDataset.py`; you can `cd` into the folder and run it to display the statistics and examples of the dataset. 46 | 47 | ```bash 48 | python showDataset.py 49 | ``` 50 | 51 | ### Before running 52 | 53 | Set up the paths for data and models in `./utils/commonVar.py` as you wish. For example, 54 | 55 | ```python 56 | SAVE_PATH = './saved_models' 57 | DATA_PATH = './data' 58 | ``` 59 | 60 | Download the pre-trained IntentBERT model [here](https://1drv.ms/u/s!AsY5oOBeNeY-hCRMnhQQPojqdK8R?e=Ixz4ke) and save it under `SAVE_PATH`. The scripts for running experiments are kept in `./scripts`. You can run a script with the argument `debug` for debug mode or `normal` for experiment mode. All outputs are saved to a log file under `./log`. 61 | 62 | ### Evaluation 63 | Code for few-shot evaluation is kept in `eval.py` with a corresponding bash script in `./scripts`. 64 | 65 | Run it with the default parameters as follows, 66 | ```bash 67 | ./scripts/eval.sh normal ${cuda_id} 68 | ``` 69 | 70 | Necessary arguments for the evaluation script are as follows, 71 | 72 | * `--dataDir`: Directory for the evaluation data 73 | * `--targetDomain`: Target domain name for evaluation 74 | * `--shot`: Shot number for each class 75 | * `--LMName`: Language model to be evaluated. It can be a model name on the Hugging Face hub or a directory under `SAVE_PATH` 76 | 77 | Change `LMName` to evaluate our provided pre-trained models. We provide four trained models under `./saved_models`: 78 | 1. intent-bert-base-uncased 79 | 2. joint-intent-bert-base-uncased-hint3 80 | 3. joint-intent-bert-base-uncased-bank77 81 | 4. joint-intent-bert-base-uncased-mcid 82 | 83 | They correspond to 'IntentBERT (OOS)', 'IntentBERT (OOS)+MLM' (on hint3), 'IntentBERT (OOS)+MLM' (on bank77) and 'IntentBERT (OOS)+MLM' (on mcid) in the paper. 84 | 85 | 86 | ### Training 87 | 88 | Both supervised pre-training and joint pre-training can be run with `transfer.py` and a corresponding script in `./scripts`. The important arguments are shown here (see the example invocation below), 89 | * `--dataDir`: Directories for the training, validation and test data, joined with "," 90 | * `--sourceDomain`: Source domain names for training, joined with "," 91 | * `--valDomain`: Validation domain names, joined with "," 92 | * `--targetDomain`: Target domain names for evaluation, joined with "," 93 | * `--shot`: Shot number for each class 94 | * `--tensorboard`: Enable tensorboard 95 | * `--saveModel`: Enable saving the model 96 | * `--saveName`: The name to give the saved model, or "none" to use the default name 97 | * `--validation`: Enable validation; it is turned off automatically when using joint pre-training 98 | * `--mlm`: Enable MLM; turn this on when using joint pre-training 99 | * `--LMName`: Language model used as the initialization. It can be a model name on the Hugging Face hub or a directory under `SAVE_PATH` 100 | 101 | Note that the results might differ from the reported ones by 1~3% when training with different seeds.
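For illustration, a supervised pre-training run might look like the sketch below. The domain names and flag style are hypothetical placeholders inferred from the argument list above, not the exact settings from the paper; see `scripts/transfer.sh` for the actual invocation.

```bash
# Hypothetical example only: dataset and domain names are illustrative,
# and the boolean flags are assumed to be simple switches.
python transfer.py \
    --dataDir oos,hwu64,bank77 \
    --sourceDomain "domainA,domainB" \
    --valDomain "domainC" \
    --targetDomain "BANKING" \
    --shot 2 \
    --saveModel \
    --saveName none \
    --validation \
    --LMName bert-base-uncased
```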
102 | 103 | **Supervised Pre-training** 104 | 105 | Turn off `mlm` and turn on `validation`. Change the datasets and domain names for different settings. 106 | 107 | **Joint Pre-training** 108 | 109 | Turn on `mlm`; `validation` will be turned off automatically. Change the datasets and domain names for different settings. 110 | 111 | ## Bugs or questions? 112 | 113 | If you have any questions related to the code or the paper, feel free to email Haode (`haode.zhang@connect.polyu.hk`) or Yuwei (`zhangyuwei.work@gmail.com`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to describe the problem in detail so we can help you better and faster! 114 | 115 | ## Citation 116 | 117 | Please cite our paper if you use IntentBERT in your work: 118 | 119 | ```bibtex 120 | @article{zhang2021effectiveness, 121 | title={Effectiveness of Pre-training for Few-shot Intent Classification}, 122 | author={Haode Zhang and Yuwei Zhang and Li-Ming Zhan and Jiaxin Chen and Guangyuan Shi and Xiao-Ming Wu and Albert Y. S. Lam}, 123 | journal={arXiv preprint arXiv:2109.05782}, 124 | year={2021} 125 | } 126 | ``` 127 | -------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/.DS_Store -------------------------------------------------------------------------------- /data/bank77/original/README.md: -------------------------------------------------------------------------------- 1 | [![PolyAI](polyai-logo.png)](https://poly-ai.com/) 2 | 3 | # task-specific-datasets 4 | 5 | *A collection of NLU datasets in constrained domains.* 6 | 7 | ## Datasets 8 | 9 | ## Banking 10 | 11 | Dataset composed of online banking queries annotated with their corresponding intents. 12 | 13 | | Dataset statistics | | 14 | | --- | --- | 15 | | Train examples | 10003 | 16 | | Test examples | 3080 | 17 | | Number of intents | 77 | 18 | 19 | | Example Query | Intent | 20 | | --- | --- | 21 | | Is there a way to know when my card will arrive?| card_arrival | 22 | | I think my card is broken | card_not_working | 23 | | I made a mistake and need to cancel a transaction | cancel_transfer | 24 | | Is my card usable anywhere? | card_acceptance | 25 | 26 | 27 | ### Citations 28 | 29 | When using the banking dataset in your work, please cite [Efficient Intent Detection with Dual Sentence Encoders](https://arxiv.org/abs/2003.04807). 30 | 31 | ```bibtex 32 | @inproceedings{Casanueva2020, 33 | author = {I{\~{n}}igo Casanueva and Tadas Temcinas and Daniela Gerz and Matthew Henderson and Ivan Vulic}, 34 | title = {Efficient Intent Detection with Dual Sentence Encoders}, 35 | year = {2020}, 36 | month = {mar}, 37 | note = {Data available at https://github.com/PolyAI-LDN/task-specific-datasets}, 38 | url = {https://arxiv.org/abs/2003.04807}, 39 | booktitle = {Proceedings of the 2nd Workshop on NLP for ConvAI - ACL 2020} 40 | } 41 | 42 | ``` 43 | 44 | ## Span Extraction 45 | The directory `span_extraction` contains the data used for the Span-ConveRT paper.
46 | 47 | A training example looks like: 48 | ``` 49 | { 50 | "userInput": { 51 | "text": "I would like a table for one person" 52 | }, 53 | "labels": [ 54 | { 55 | "slot": "people", 56 | "valueSpan": { 57 | "startIndex": 25, 58 | "endIndex": 35 59 | } 60 | } 61 | ] 62 | } 63 | ``` 64 | 65 | In the above example, the span "one person" is the value for the `people` slot. 66 | 67 | The datasets have a structure like this: 68 | ``` 69 | ls span_extraction/restaurant8k 70 | 71 | test.json 72 | train_0.json 73 | train_1.json 74 | train_2.json 75 | ... 76 | ``` 77 | Where: 78 | * `test.json` contains the examples for evaluation 79 | * `train_0.json` contains all of the training examples 80 | * `train_{i}.json` contains `1/(2^i)`th of the training data. 81 | 82 | 83 | #### Exploring the Span Extraction Datasets 84 | Here's a quick command line demo to explore some of the datasets (requires `jq` and `parallel`) 85 | ``` 86 | 87 | # Calculate the number of examples in each json file. 88 | cd span_extraction 89 | 90 | ls -d restaurant8k/*.json | parallel -k 'echo -n "{}," && cat {} | jq length' 91 | 92 | restaurant8k/test.json,3731 93 | restaurant8k/train_0.json,8198 94 | restaurant8k/train_1.json,4099 95 | restaurant8k/train_2.json,2049 96 | restaurant8k/train_3.json,1024 97 | restaurant8k/train_4.json,512 98 | restaurant8k/train_5.json,256 99 | restaurant8k/train_6.json,128 100 | restaurant8k/train_7.json,64 101 | restaurant8k/train_8.json,32 102 | 103 | ls -d dstc8/*/*.json | parallel -k 'echo -n "{}," && cat {} | jq length' 104 | dstc8/Buses_1/test.json,377 105 | dstc8/Buses_1/train_0.json,1133 106 | dstc8/Buses_1/train_1.json,566 107 | dstc8/Buses_1/train_2.json,283 108 | dstc8/Buses_1/train_3.json,141 109 | dstc8/Buses_1/train_4.json,70 110 | dstc8/Events_1/test.json,521 111 | dstc8/Events_1/train_0.json,1498 112 | dstc8/Events_1/train_1.json,749 113 | dstc8/Events_1/train_2.json,374 114 | dstc8/Events_1/train_3.json,187 115 | dstc8/Events_1/train_4.json,93 116 | dstc8/Homes_1/test.json,587 117 | dstc8/Homes_1/train_0.json,2064 118 | dstc8/Homes_1/train_1.json,1032 119 | dstc8/Homes_1/train_2.json,516 120 | dstc8/Homes_1/train_3.json,258 121 | dstc8/Homes_1/train_4.json,129 122 | dstc8/RentalCars_1/test.json,328 123 | dstc8/RentalCars_1/train_0.json,874 124 | dstc8/RentalCars_1/train_1.json,437 125 | dstc8/RentalCars_1/train_2.json,218 126 | dstc8/RentalCars_1/train_3.json,109 127 | dstc8/RentalCars_1/train_4.json,54 128 | ``` 129 | ### Citations 130 | 131 | When using the datasets in your work, please cite [the Span-ConveRT paper](https://arxiv.org/abs/2005.08866). 132 | 133 | ```bibtex 134 | @inproceedings{CoopeFarghly2020, 135 | Author = {Sam Coope and Tyler Farghly and Daniela Gerz and Ivan Vulić and Matthew Henderson}, 136 | Title = {Span-ConveRT: Few-shot Span Extraction for Dialog with Pretrained Conversational Representations}, 137 | Year = {2020}, 138 | url = {https://arxiv.org/abs/2005.08866}, 139 | publisher = {ACL}, 140 | } 141 | 142 | ``` 143 | 144 | 145 | ## License 146 | The datasets shared on this repository are licensed under the license found in the LICENSE file. 
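For a quick look at the span format described above, here is a minimal Python sketch (our addition, not part of the original repository) that loads a split and prints each labeled slot value by slicing the `valueSpan` indices. Note that `startIndex` is omitted in the data when a span starts at position 0, so it should default to 0:

```python
import json

# Load one span-extraction split; the path is illustrative --
# every test.json and train_{i}.json file shares the same schema.
with open("span_extraction/restaurant8k/train_8.json") as f:
    examples = json.load(f)

for ex in examples:
    text = ex["userInput"]["text"]
    # Utterances without slots (e.g., "Sounds good") simply have no "labels" key.
    for label in ex.get("labels", []):
        span = label["valueSpan"]
        start = span.get("startIndex", 0)  # omitted when the span starts at 0
        end = span["endIndex"]
        print(f'{label["slot"]}: "{text[start:end]}"')
```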
147 | -------------------------------------------------------------------------------- /data/bank77/original/banking_data/categories.json: -------------------------------------------------------------------------------- 1 | [ 2 | "card_arrival", 3 | "card_linking", 4 | "exchange_rate", 5 | "card_payment_wrong_exchange_rate", 6 | "extra_charge_on_statement", 7 | "pending_cash_withdrawal", 8 | "fiat_currency_support", 9 | "card_delivery_estimate", 10 | "automatic_top_up", 11 | "card_not_working", 12 | "exchange_via_app", 13 | "lost_or_stolen_card", 14 | "age_limit", 15 | "pin_blocked", 16 | "contactless_not_working", 17 | "top_up_by_bank_transfer_charge", 18 | "pending_top_up", 19 | "cancel_transfer", 20 | "top_up_limits", 21 | "wrong_amount_of_cash_received", 22 | "card_payment_fee_charged", 23 | "transfer_not_received_by_recipient", 24 | "supported_cards_and_currencies", 25 | "getting_virtual_card", 26 | "card_acceptance", 27 | "top_up_reverted", 28 | "balance_not_updated_after_cheque_or_cash_deposit", 29 | "card_payment_not_recognised", 30 | "edit_personal_details", 31 | "why_verify_identity", 32 | "unable_to_verify_identity", 33 | "get_physical_card", 34 | "visa_or_mastercard", 35 | "topping_up_by_card", 36 | "disposable_card_limits", 37 | "compromised_card", 38 | "atm_support", 39 | "direct_debit_payment_not_recognised", 40 | "passcode_forgotten", 41 | "declined_cash_withdrawal", 42 | "pending_card_payment", 43 | "lost_or_stolen_phone", 44 | "request_refund", 45 | "declined_transfer", 46 | "Refund_not_showing_up", 47 | "declined_card_payment", 48 | "pending_transfer", 49 | "terminate_account", 50 | "card_swallowed", 51 | "transaction_charged_twice", 52 | "verify_source_of_funds", 53 | "transfer_timing", 54 | "reverted_card_payment?", 55 | "change_pin", 56 | "beneficiary_not_allowed", 57 | "transfer_fee_charged", 58 | "receiving_money", 59 | "failed_transfer", 60 | "transfer_into_account", 61 | "verify_top_up", 62 | "getting_spare_card", 63 | "top_up_by_cash_or_cheque", 64 | "order_physical_card", 65 | "virtual_card_not_working", 66 | "wrong_exchange_rate_for_cash_withdrawal", 67 | "get_disposable_virtual_card", 68 | "top_up_failed", 69 | "balance_not_updated_after_bank_transfer", 70 | "cash_withdrawal_not_recognised", 71 | "exchange_charge", 72 | "top_up_by_card_charge", 73 | "activate_my_card", 74 | "cash_withdrawal_charge", 75 | "card_about_to_expire", 76 | "apple_pay_or_google_pay", 77 | "verify_my_identity", 78 | "country_support" 79 | ] -------------------------------------------------------------------------------- /data/bank77/original/polyai-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/bank77/original/polyai-logo.png -------------------------------------------------------------------------------- /data/bank77/original/span_extraction/dstc8/Buses_1/train_5.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "userInput": { 4 | "text": "Can you help me out and look for a bus for me?" 5 | }, 6 | "context": {}, 7 | "id": "27_001080", 8 | "splitKey": 1.0 9 | }, 10 | { 11 | "userInput": { 12 | "text": "I want to visit New York." 
13 | }, 14 | "context": { 15 | "requestedSlots": [ 16 | "to_location" 17 | ] 18 | }, 19 | "labels": [ 20 | { 21 | "slot": "to_location", 22 | "valueSpan": { 23 | "startIndex": 16, 24 | "endIndex": 24 25 | } 26 | } 27 | ], 28 | "id": "27_001082", 29 | "splitKey": 1.0 30 | }, 31 | { 32 | "userInput": { 33 | "text": "I am leaving out from Washington on the 11th of March." 34 | }, 35 | "context": { 36 | "requestedSlots": [ 37 | "leaving_date", 38 | "from_location" 39 | ] 40 | }, 41 | "labels": [ 42 | { 43 | "slot": "leaving_date", 44 | "valueSpan": { 45 | "startIndex": 40, 46 | "endIndex": 53 47 | } 48 | }, 49 | { 50 | "slot": "from_location", 51 | "valueSpan": { 52 | "startIndex": 22, 53 | "endIndex": 32 54 | } 55 | } 56 | ], 57 | "id": "27_001084", 58 | "splitKey": 1.0 59 | }, 60 | { 61 | "userInput": { 62 | "text": "What is the bus station that I will leave from and which station will I arrrive at?" 63 | }, 64 | "context": {}, 65 | "id": "27_001086", 66 | "splitKey": 1.0 67 | }, 68 | { 69 | "userInput": { 70 | "text": "Sounds fine to me." 71 | }, 72 | "context": {}, 73 | "id": "27_001088", 74 | "splitKey": 1.0 75 | }, 76 | { 77 | "userInput": { 78 | "text": "Yep, I would like to get tickets." 79 | }, 80 | "context": {}, 81 | "id": "27_0010810", 82 | "splitKey": 1.0 83 | }, 84 | { 85 | "userInput": { 86 | "text": "I need it for two people." 87 | }, 88 | "context": { 89 | "requestedSlots": [ 90 | "travelers" 91 | ] 92 | }, 93 | "id": "27_0010812", 94 | "splitKey": 1.0 95 | }, 96 | { 97 | "userInput": { 98 | "text": "That's not right, it's for 4 people." 99 | }, 100 | "context": {}, 101 | "id": "27_0010814", 102 | "splitKey": 1.0 103 | }, 104 | { 105 | "userInput": { 106 | "text": "Yep, that sounds good to me." 107 | }, 108 | "context": {}, 109 | "id": "27_0010816", 110 | "splitKey": 1.0 111 | }, 112 | { 113 | "userInput": { 114 | "text": "Thanks so much for your help." 115 | }, 116 | "context": {}, 117 | "id": "27_0010818", 118 | "splitKey": 1.0 119 | }, 120 | { 121 | "userInput": { 122 | "text": "Nope, thanks again for helping me." 123 | }, 124 | "context": {}, 125 | "id": "27_0010820", 126 | "splitKey": 1.0 127 | }, 128 | { 129 | "userInput": { 130 | "text": "Is there a bus that leaves from Anaheim to SD?" 131 | }, 132 | "context": {}, 133 | "labels": [ 134 | { 135 | "slot": "to_location", 136 | "valueSpan": { 137 | "startIndex": 43, 138 | "endIndex": 45 139 | } 140 | }, 141 | { 142 | "slot": "from_location", 143 | "valueSpan": { 144 | "startIndex": 32, 145 | "endIndex": 39 146 | } 147 | } 148 | ], 149 | "id": "27_001090", 150 | "splitKey": 1.0 151 | }, 152 | { 153 | "userInput": { 154 | "text": "Thursday next week for four people." 155 | }, 156 | "context": { 157 | "requestedSlots": [ 158 | "leaving_date" 159 | ] 160 | }, 161 | "labels": [ 162 | { 163 | "slot": "leaving_date", 164 | "valueSpan": { 165 | "endIndex": 18 166 | } 167 | } 168 | ], 169 | "id": "27_001092", 170 | "splitKey": 1.0 171 | }, 172 | { 173 | "userInput": { 174 | "text": "Anything else for three people from LAX?" 175 | }, 176 | "context": {}, 177 | "labels": [ 178 | { 179 | "slot": "from_location", 180 | "valueSpan": { 181 | "startIndex": 36, 182 | "endIndex": 39 183 | } 184 | } 185 | ], 186 | "id": "27_001094", 187 | "splitKey": 1.0 188 | }, 189 | { 190 | "userInput": { 191 | "text": "Sounds great. I would like to reserve it." 192 | }, 193 | "context": {}, 194 | "id": "27_001096", 195 | "splitKey": 1.0 196 | }, 197 | { 198 | "userInput": { 199 | "text": "Sounds great." 
200 | }, 201 | "context": {}, 202 | "id": "27_001098", 203 | "splitKey": 1.0 204 | }, 205 | { 206 | "userInput": { 207 | "text": "Thanks, that's all for now." 208 | }, 209 | "context": {}, 210 | "id": "27_0010910", 211 | "splitKey": 1.0 212 | }, 213 | { 214 | "userInput": { 215 | "text": "I'm thinking of going out of town. Could you please help me find a bus?" 216 | }, 217 | "context": {}, 218 | "id": "27_001100", 219 | "splitKey": 1.0 220 | }, 221 | { 222 | "userInput": { 223 | "text": "I'd like to leave on 10th of March." 224 | }, 225 | "context": { 226 | "requestedSlots": [ 227 | "leaving_date" 228 | ] 229 | }, 230 | "labels": [ 231 | { 232 | "slot": "leaving_date", 233 | "valueSpan": { 234 | "startIndex": 21, 235 | "endIndex": 34 236 | } 237 | } 238 | ], 239 | "id": "27_001102", 240 | "splitKey": 1.0 241 | }, 242 | { 243 | "userInput": { 244 | "text": "I want to travel from Washington to Philly." 245 | }, 246 | "context": { 247 | "requestedSlots": [ 248 | "to_location", 249 | "from_location" 250 | ] 251 | }, 252 | "labels": [ 253 | { 254 | "slot": "to_location", 255 | "valueSpan": { 256 | "startIndex": 36, 257 | "endIndex": 42 258 | } 259 | }, 260 | { 261 | "slot": "from_location", 262 | "valueSpan": { 263 | "startIndex": 22, 264 | "endIndex": 32 265 | } 266 | } 267 | ], 268 | "id": "27_001104", 269 | "splitKey": 1.0 270 | }, 271 | { 272 | "userInput": { 273 | "text": "What's my destination station?" 274 | }, 275 | "context": {}, 276 | "id": "27_001106", 277 | "splitKey": 1.0 278 | }, 279 | { 280 | "userInput": { 281 | "text": "Please find out whether there are other available buses, for a group of three." 282 | }, 283 | "context": {}, 284 | "id": "27_001108", 285 | "splitKey": 1.0 286 | }, 287 | { 288 | "userInput": { 289 | "text": "Please let me know my departure and destination stations." 290 | }, 291 | "context": {}, 292 | "id": "27_0011010", 293 | "splitKey": 1.0 294 | }, 295 | { 296 | "userInput": { 297 | "text": "Okay, that sounds perfect. Please make a reservation on the bus." 298 | }, 299 | "context": {}, 300 | "id": "27_0011012", 301 | "splitKey": 1.0 302 | }, 303 | { 304 | "userInput": { 305 | "text": "Yes, sounds good." 306 | }, 307 | "context": {}, 308 | "id": "27_0011014", 309 | "splitKey": 1.0 310 | }, 311 | { 312 | "userInput": { 313 | "text": "I'm really grateful for your help. That will be all." 314 | }, 315 | "context": {}, 316 | "id": "27_0011016", 317 | "splitKey": 1.0 318 | }, 319 | { 320 | "userInput": { 321 | "text": "Can you find me a bus?" 322 | }, 323 | "context": {}, 324 | "id": "27_001110", 325 | "splitKey": 1.0 326 | }, 327 | { 328 | "userInput": { 329 | "text": "I want to go from San Francisco to Los Angeles on the 9th" 330 | }, 331 | "context": { 332 | "requestedSlots": [ 333 | "from_location", 334 | "leaving_date", 335 | "to_location" 336 | ] 337 | }, 338 | "labels": [ 339 | { 340 | "slot": "to_location", 341 | "valueSpan": { 342 | "startIndex": 35, 343 | "endIndex": 46 344 | } 345 | }, 346 | { 347 | "slot": "from_location", 348 | "valueSpan": { 349 | "startIndex": 18, 350 | "endIndex": 31 351 | } 352 | }, 353 | { 354 | "slot": "leaving_date", 355 | "valueSpan": { 356 | "startIndex": 50, 357 | "endIndex": 57 358 | } 359 | } 360 | ], 361 | "id": "27_001112", 362 | "splitKey": 1.0 363 | }, 364 | { 365 | "userInput": { 366 | "text": "Which station am I going from?" 367 | }, 368 | "context": {}, 369 | "id": "27_001114", 370 | "splitKey": 1.0 371 | }, 372 | { 373 | "userInput": { 374 | "text": "That's good thanks." 
375 | }, 376 | "context": {}, 377 | "id": "27_001116", 378 | "splitKey": 1.0 379 | }, 380 | { 381 | "userInput": { 382 | "text": "Yes please reserve them." 383 | }, 384 | "context": {}, 385 | "id": "27_001118", 386 | "splitKey": 1.0 387 | }, 388 | { 389 | "userInput": { 390 | "text": "I need 2 please." 391 | }, 392 | "context": { 393 | "requestedSlots": [ 394 | "travelers" 395 | ] 396 | }, 397 | "id": "27_0011110", 398 | "splitKey": 1.0 399 | }, 400 | { 401 | "userInput": { 402 | "text": "No I need 3 seats." 403 | }, 404 | "context": {}, 405 | "id": "27_0011112", 406 | "splitKey": 1.0 407 | }, 408 | { 409 | "userInput": { 410 | "text": "Yes that's good, which bus station does it go to?" 411 | }, 412 | "context": {}, 413 | "id": "27_0011114", 414 | "splitKey": 1.0 415 | }, 416 | { 417 | "userInput": { 418 | "text": "Thanks a lot" 419 | }, 420 | "context": {}, 421 | "id": "27_0011116", 422 | "splitKey": 1.0 423 | } 424 | ] -------------------------------------------------------------------------------- /data/bank77/original/span_extraction/dstc8/RentalCars_1/train_5.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "userInput": { 4 | "text": "Book me a rental car." 5 | }, 6 | "context": {}, 7 | "id": "23_000170", 8 | "splitKey": 1.0 9 | }, 10 | { 11 | "userInput": { 12 | "text": "I'll pick it up around 10:30 in the morning, keeping it until next Thursday" 13 | }, 14 | "context": { 15 | "requestedSlots": [ 16 | "pickup_time", 17 | "dropoff_date" 18 | ] 19 | }, 20 | "labels": [ 21 | { 22 | "slot": "dropoff_date", 23 | "valueSpan": { 24 | "startIndex": 62, 25 | "endIndex": 75 26 | } 27 | }, 28 | { 29 | "slot": "pickup_time", 30 | "valueSpan": { 31 | "startIndex": 23, 32 | "endIndex": 43 33 | } 34 | } 35 | ], 36 | "id": "23_000172", 37 | "splitKey": 1.0 38 | }, 39 | { 40 | "userInput": { 41 | "text": "I'll pick it up in New York on March 4th" 42 | }, 43 | "context": { 44 | "requestedSlots": [ 45 | "pickup_city", 46 | "pickup_date" 47 | ] 48 | }, 49 | "labels": [ 50 | { 51 | "slot": "pickup_city", 52 | "valueSpan": { 53 | "startIndex": 19, 54 | "endIndex": 27 55 | } 56 | }, 57 | { 58 | "slot": "pickup_date", 59 | "valueSpan": { 60 | "startIndex": 31, 61 | "endIndex": 40 62 | } 63 | } 64 | ], 65 | "id": "23_000174", 66 | "splitKey": 1.0 67 | }, 68 | { 69 | "userInput": { 70 | "text": "Sounds good" 71 | }, 72 | "context": {}, 73 | "id": "23_000176", 74 | "splitKey": 1.0 75 | }, 76 | { 77 | "userInput": { 78 | "text": "No not now" 79 | }, 80 | "context": {}, 81 | "id": "23_000178", 82 | "splitKey": 1.0 83 | }, 84 | { 85 | "userInput": { 86 | "text": "No thanks so much" 87 | }, 88 | "context": {}, 89 | "id": "23_0001710", 90 | "splitKey": 1.0 91 | }, 92 | { 93 | "userInput": { 94 | "text": "I want to find a rental car for March 6th, pick up in New York." 95 | }, 96 | "context": {}, 97 | "labels": [ 98 | { 99 | "slot": "pickup_city", 100 | "valueSpan": { 101 | "startIndex": 54, 102 | "endIndex": 62 103 | } 104 | }, 105 | { 106 | "slot": "pickup_date", 107 | "valueSpan": { 108 | "startIndex": 32, 109 | "endIndex": 41 110 | } 111 | } 112 | ], 113 | "id": "23_000180", 114 | "splitKey": 1.0 115 | }, 116 | { 117 | "userInput": { 118 | "text": "I need a pickup time around 18:30 and I will need the car until the 7th." 
119 | }, 120 | "context": { 121 | "requestedSlots": [ 122 | "dropoff_date" 123 | ] 124 | }, 125 | "labels": [ 126 | { 127 | "slot": "pickup_time", 128 | "valueSpan": { 129 | "startIndex": 28, 130 | "endIndex": 33 131 | } 132 | }, 133 | { 134 | "slot": "dropoff_date", 135 | "valueSpan": { 136 | "startIndex": 64, 137 | "endIndex": 71 138 | } 139 | } 140 | ], 141 | "id": "23_000182", 142 | "splitKey": 1.0 143 | }, 144 | { 145 | "userInput": { 146 | "text": "Can you find anything else?" 147 | }, 148 | "context": {}, 149 | "id": "23_000184", 150 | "splitKey": 1.0 151 | }, 152 | { 153 | "userInput": { 154 | "text": "That would be perfect." 155 | }, 156 | "context": {}, 157 | "id": "23_000186", 158 | "splitKey": 1.0 159 | }, 160 | { 161 | "userInput": { 162 | "text": "No, I don't want to reserve it yet. That's all I need." 163 | }, 164 | "context": {}, 165 | "id": "23_000188", 166 | "splitKey": 1.0 167 | }, 168 | { 169 | "userInput": { 170 | "text": "My usual car is being repaired right now. I'm in need of a rental car for the interim." 171 | }, 172 | "context": {}, 173 | "id": "23_000190", 174 | "splitKey": 1.0 175 | }, 176 | { 177 | "userInput": { 178 | "text": "I'll be needing it from March 9th to March 14th." 179 | }, 180 | "context": { 181 | "requestedSlots": [ 182 | "dropoff_date", 183 | "pickup_date" 184 | ] 185 | }, 186 | "labels": [ 187 | { 188 | "slot": "pickup_date", 189 | "valueSpan": { 190 | "startIndex": 24, 191 | "endIndex": 33 192 | } 193 | }, 194 | { 195 | "slot": "dropoff_date", 196 | "valueSpan": { 197 | "startIndex": 37, 198 | "endIndex": 47 199 | } 200 | } 201 | ], 202 | "id": "23_000192", 203 | "splitKey": 1.0 204 | }, 205 | { 206 | "userInput": { 207 | "text": "I'd like to pick it up from Phoenix." 208 | }, 209 | "context": { 210 | "requestedSlots": [ 211 | "pickup_city" 212 | ] 213 | }, 214 | "labels": [ 215 | { 216 | "slot": "pickup_city", 217 | "valueSpan": { 218 | "startIndex": 28, 219 | "endIndex": 35 220 | } 221 | } 222 | ], 223 | "id": "23_000194", 224 | "splitKey": 1.0 225 | }, 226 | { 227 | "userInput": { 228 | "text": "I'll have to pick it up around 4:30 pm." 229 | }, 230 | "context": { 231 | "requestedSlots": [ 232 | "pickup_time" 233 | ] 234 | }, 235 | "labels": [ 236 | { 237 | "slot": "pickup_time", 238 | "valueSpan": { 239 | "startIndex": 31, 240 | "endIndex": 38 241 | } 242 | } 243 | ], 244 | "id": "23_000196", 245 | "splitKey": 1.0 246 | }, 247 | { 248 | "userInput": { 249 | "text": "Sounds great to me." 250 | }, 251 | "context": {}, 252 | "id": "23_000198", 253 | "splitKey": 1.0 254 | }, 255 | { 256 | "userInput": { 257 | "text": "Yes, I want to rent this Accord." 258 | }, 259 | "context": {}, 260 | "id": "23_0001910", 261 | "splitKey": 1.0 262 | }, 263 | { 264 | "userInput": { 265 | "text": "That's all correct. What will it cost me?" 266 | }, 267 | "context": {}, 268 | "id": "23_0001912", 269 | "splitKey": 1.0 270 | }, 271 | { 272 | "userInput": { 273 | "text": "Thanks for the help." 274 | }, 275 | "context": {}, 276 | "id": "23_0001914", 277 | "splitKey": 1.0 278 | }, 279 | { 280 | "userInput": { 281 | "text": "No, that's all. Thanks for helping." 
282 | }, 283 | "context": {}, 284 | "id": "23_0001916", 285 | "splitKey": 1.0 286 | }, 287 | { 288 | "userInput": { 289 | "text": "I need a car for next Thursday" 290 | }, 291 | "context": {}, 292 | "labels": [ 293 | { 294 | "slot": "dropoff_date", 295 | "valueSpan": { 296 | "startIndex": 17, 297 | "endIndex": 30 298 | } 299 | } 300 | ], 301 | "id": "23_000200", 302 | "splitKey": 1.0 303 | }, 304 | { 305 | "userInput": { 306 | "text": "6 in the evening please." 307 | }, 308 | "context": { 309 | "requestedSlots": [ 310 | "pickup_time" 311 | ] 312 | }, 313 | "labels": [ 314 | { 315 | "slot": "pickup_time", 316 | "valueSpan": { 317 | "endIndex": 16 318 | } 319 | } 320 | ], 321 | "id": "23_000202", 322 | "splitKey": 1.0 323 | }, 324 | { 325 | "userInput": { 326 | "text": "the 5th of March" 327 | }, 328 | "context": { 329 | "requestedSlots": [ 330 | "pickup_date" 331 | ] 332 | }, 333 | "labels": [ 334 | { 335 | "slot": "pickup_date", 336 | "valueSpan": { 337 | "startIndex": 4, 338 | "endIndex": 16 339 | } 340 | } 341 | ], 342 | "id": "23_000204", 343 | "splitKey": 1.0 344 | }, 345 | { 346 | "userInput": { 347 | "text": "Yes, from Sacramento, Ca on the 4th of March" 348 | }, 349 | "context": { 350 | "requestedSlots": [ 351 | "pickup_city" 352 | ] 353 | }, 354 | "labels": [ 355 | { 356 | "slot": "pickup_city", 357 | "valueSpan": { 358 | "startIndex": 10, 359 | "endIndex": 24 360 | } 361 | }, 362 | { 363 | "slot": "pickup_date", 364 | "valueSpan": { 365 | "startIndex": 32, 366 | "endIndex": 44 367 | } 368 | } 369 | ], 370 | "id": "23_000206", 371 | "splitKey": 1.0 372 | }, 373 | { 374 | "userInput": { 375 | "text": "How much?" 376 | }, 377 | "context": {}, 378 | "id": "23_000208", 379 | "splitKey": 1.0 380 | }, 381 | { 382 | "userInput": { 383 | "text": "What else is there, I'd like arrange for it to be picked up in New York." 384 | }, 385 | "context": {}, 386 | "labels": [ 387 | { 388 | "slot": "pickup_city", 389 | "valueSpan": { 390 | "startIndex": 63, 391 | "endIndex": 71 392 | } 393 | } 394 | ], 395 | "id": "23_0002010", 396 | "splitKey": 1.0 397 | }, 398 | { 399 | "userInput": { 400 | "text": "How much?" 401 | }, 402 | "context": {}, 403 | "id": "23_0002012", 404 | "splitKey": 1.0 405 | } 406 | ] -------------------------------------------------------------------------------- /data/bank77/original/span_extraction/dstc8/stats.csv: -------------------------------------------------------------------------------- 1 | Buses_1/test.json,377 2 | Buses_1/train_0.json,1133 3 | Buses_1/train_1.json,566 4 | Buses_1/train_2.json,283 5 | Buses_1/train_3.json,141 6 | Buses_1/train_4.json,70 7 | 8 | Events_1/test.json,521 9 | Events_1/train_0.json,1498 10 | Events_1/train_1.json,749 11 | Events_1/train_2.json,374 12 | Events_1/train_3.json,187 13 | Events_1/train_4.json,93 14 | 15 | Homes_1/test.json,587 16 | Homes_1/train_0.json,2064 17 | Homes_1/train_1.json,1032 18 | Homes_1/train_2.json,516 19 | Homes_1/train_3.json,258 20 | Homes_1/train_4.json,129 21 | 22 | RentalCars_1/test.json,328 23 | RentalCars_1/train_0.json,874 24 | RentalCars_1/train_1.json,437 25 | RentalCars_1/train_2.json,218 26 | RentalCars_1/train_3.json,109 27 | RentalCars_1/train_4.json,54 28 | -------------------------------------------------------------------------------- /data/bank77/original/span_extraction/restaurant8k/train_8.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "userInput": { 4 | "text": "There will be 5 adults and 1 child." 
5 | }, 6 | "context": { 7 | "requestedSlots": [ 8 | "people" 9 | ] 10 | }, 11 | "labels": [ 12 | { 13 | "slot": "people", 14 | "valueSpan": { 15 | "startIndex": 14, 16 | "endIndex": 34 17 | } 18 | } 19 | ] 20 | }, 21 | { 22 | "userInput": { 23 | "text": "We will require and outside table to seat 9 people on August 23rd" 24 | }, 25 | "labels": [ 26 | { 27 | "slot": "people", 28 | "valueSpan": { 29 | "startIndex": 42, 30 | "endIndex": 50 31 | } 32 | }, 33 | { 34 | "slot": "date", 35 | "valueSpan": { 36 | "startIndex": 54, 37 | "endIndex": 65 38 | } 39 | } 40 | ] 41 | }, 42 | { 43 | "userInput": { 44 | "text": "Do you have room for 11 of us?" 45 | }, 46 | "labels": [ 47 | { 48 | "slot": "people", 49 | "valueSpan": { 50 | "startIndex": 21, 51 | "endIndex": 23 52 | } 53 | } 54 | ] 55 | }, 56 | { 57 | "userInput": { 58 | "text": "We are 13 and have a table." 59 | }, 60 | "labels": [ 61 | { 62 | "slot": "people", 63 | "valueSpan": { 64 | "startIndex": 7, 65 | "endIndex": 9 66 | } 67 | } 68 | ] 69 | }, 70 | { 71 | "userInput": { 72 | "text": "6 a.m." 73 | }, 74 | "context": { 75 | "requestedSlots": [ 76 | "time" 77 | ] 78 | }, 79 | "labels": [ 80 | { 81 | "slot": "time", 82 | "valueSpan": { 83 | "endIndex": 6 84 | } 85 | } 86 | ] 87 | }, 88 | { 89 | "userInput": { 90 | "text": "Can i chnage my booking from 18:00 for 3 people to 17:45 for 4 people?" 91 | }, 92 | "labels": [ 93 | { 94 | "slot": "time", 95 | "valueSpan": { 96 | "startIndex": 51, 97 | "endIndex": 56 98 | } 99 | }, 100 | { 101 | "slot": "people", 102 | "valueSpan": { 103 | "startIndex": 61, 104 | "endIndex": 69 105 | } 106 | } 107 | ] 108 | }, 109 | { 110 | "userInput": { 111 | "text": "I do not want a restaurant located in Crystal Palace, but in West Central." 112 | } 113 | }, 114 | { 115 | "userInput": { 116 | "text": "Can I book for in 14 days?" 117 | }, 118 | "labels": [ 119 | { 120 | "slot": "date", 121 | "valueSpan": { 122 | "startIndex": 15, 123 | "endIndex": 25 124 | } 125 | } 126 | ] 127 | }, 128 | { 129 | "userInput": { 130 | "text": "Please book this hotel for 10 people and on Sun, 05 Aug 2018" 131 | }, 132 | "labels": [ 133 | { 134 | "slot": "people", 135 | "valueSpan": { 136 | "startIndex": 27, 137 | "endIndex": 36 138 | } 139 | }, 140 | { 141 | "slot": "date", 142 | "valueSpan": { 143 | "startIndex": 44, 144 | "endIndex": 60 145 | } 146 | } 147 | ] 148 | }, 149 | { 150 | "userInput": { 151 | "text": "30 minutes before 12pm." 152 | }, 153 | "context": { 154 | "requestedSlots": [ 155 | "time" 156 | ] 157 | }, 158 | "labels": [ 159 | { 160 | "slot": "time", 161 | "valueSpan": { 162 | "endIndex": 23 163 | } 164 | } 165 | ] 166 | }, 167 | { 168 | "userInput": { 169 | "text": "At 11:45AM please" 170 | }, 171 | "labels": [ 172 | { 173 | "slot": "time", 174 | "valueSpan": { 175 | "startIndex": 3, 176 | "endIndex": 10 177 | } 178 | } 179 | ] 180 | }, 181 | { 182 | "userInput": { 183 | "text": "I am looking for Polish cuisine." 184 | } 185 | }, 186 | { 187 | "userInput": { 188 | "text": "Yes. Please place your reservation." 189 | } 190 | }, 191 | { 192 | "userInput": { 193 | "text": "I booked a table for this evening at 08:30PM for Vashti Sampaia. We're going to reschedule because a couple of the party have pulled out." 
194 | }, 195 | "context": {}, 196 | "labels": [ 197 | { 198 | "slot": "date", 199 | "valueSpan": { 200 | "startIndex": 21, 201 | "endIndex": 33 202 | } 203 | }, 204 | { 205 | "slot": "time", 206 | "valueSpan": { 207 | "startIndex": 37, 208 | "endIndex": 44 209 | } 210 | }, 211 | { 212 | "slot": "first_name", 213 | "valueSpan": { 214 | "startIndex": 49, 215 | "endIndex": 55 216 | } 217 | }, 218 | { 219 | "slot": "last_name", 220 | "valueSpan": { 221 | "startIndex": 56, 222 | "endIndex": 63 223 | } 224 | } 225 | ] 226 | }, 227 | { 228 | "userInput": { 229 | "text": "8:30 is also fine" 230 | }, 231 | "context": { 232 | "requestedSlots": [ 233 | "time" 234 | ] 235 | }, 236 | "labels": [ 237 | { 238 | "slot": "time", 239 | "valueSpan": { 240 | "endIndex": 4 241 | } 242 | } 243 | ] 244 | }, 245 | { 246 | "userInput": { 247 | "text": "the 31st" 248 | }, 249 | "context": { 250 | "requestedSlots": [ 251 | "date" 252 | ] 253 | }, 254 | "labels": [ 255 | { 256 | "slot": "date", 257 | "valueSpan": { 258 | "startIndex": 4, 259 | "endIndex": 8 260 | } 261 | } 262 | ] 263 | }, 264 | { 265 | "userInput": { 266 | "text": "I am looking for an outside table." 267 | } 268 | }, 269 | { 270 | "userInput": { 271 | "text": "What is the address for the Hardy's in St James?" 272 | } 273 | }, 274 | { 275 | "userInput": { 276 | "text": "I would like to sit outside in 4 days at 10:45AM." 277 | }, 278 | "labels": [ 279 | { 280 | "slot": "date", 281 | "valueSpan": { 282 | "startIndex": 28, 283 | "endIndex": 37 284 | } 285 | }, 286 | { 287 | "slot": "time", 288 | "valueSpan": { 289 | "startIndex": 41, 290 | "endIndex": 49 291 | } 292 | } 293 | ] 294 | }, 295 | { 296 | "userInput": { 297 | "text": "I would like to book for 4 people." 298 | }, 299 | "context": { 300 | "requestedSlots": [ 301 | "people" 302 | ] 303 | }, 304 | "labels": [ 305 | { 306 | "slot": "people", 307 | "valueSpan": { 308 | "startIndex": 25, 309 | "endIndex": 33 310 | } 311 | } 312 | ] 313 | }, 314 | { 315 | "userInput": { 316 | "text": "11 pm" 317 | }, 318 | "context": { 319 | "requestedSlots": [ 320 | "time" 321 | ] 322 | }, 323 | "labels": [ 324 | { 325 | "slot": "time", 326 | "valueSpan": { 327 | "endIndex": 5 328 | } 329 | } 330 | ] 331 | }, 332 | { 333 | "userInput": { 334 | "text": "my friend is lactose intolerant" 335 | } 336 | }, 337 | { 338 | "userInput": { 339 | "text": "Do you have any outside tables?" 340 | } 341 | }, 342 | { 343 | "userInput": { 344 | "text": "19th of November" 345 | }, 346 | "context": { 347 | "requestedSlots": [ 348 | "date" 349 | ] 350 | }, 351 | "labels": [ 352 | { 353 | "slot": "date", 354 | "valueSpan": { 355 | "endIndex": 16 356 | } 357 | } 358 | ] 359 | }, 360 | { 361 | "userInput": { 362 | "text": "Can I have a table inside?" 363 | } 364 | }, 365 | { 366 | "userInput": { 367 | "text": "I would like to find out information about the restaurant called Dulwich Wood House." 368 | } 369 | }, 370 | { 371 | "userInput": { 372 | "text": "I would like to change my booking to include 9 people." 373 | }, 374 | "labels": [ 375 | { 376 | "slot": "people", 377 | "valueSpan": { 378 | "startIndex": 45, 379 | "endIndex": 53 380 | } 381 | } 382 | ] 383 | }, 384 | { 385 | "userInput": { 386 | "text": "I would like information on the Hush Brasserie restaurant please." 387 | } 388 | }, 389 | { 390 | "userInput": { 391 | "text": "Our party changed number; it's 10 now." 
392 | }, 393 | "labels": [ 394 | { 395 | "slot": "people", 396 | "valueSpan": { 397 | "startIndex": 31, 398 | "endIndex": 33 399 | } 400 | } 401 | ] 402 | }, 403 | { 404 | "userInput": { 405 | "text": "2018/04/24." 406 | }, 407 | "context": { 408 | "requestedSlots": [ 409 | "date" 410 | ] 411 | }, 412 | "labels": [ 413 | { 414 | "slot": "date", 415 | "valueSpan": { 416 | "endIndex": 10 417 | } 418 | } 419 | ] 420 | }, 421 | { 422 | "userInput": { 423 | "text": "On 12:30PM" 424 | }, 425 | "context": { 426 | "requestedSlots": [ 427 | "time" 428 | ] 429 | }, 430 | "labels": [ 431 | { 432 | "slot": "time", 433 | "valueSpan": { 434 | "startIndex": 3, 435 | "endIndex": 10 436 | } 437 | } 438 | ] 439 | }, 440 | { 441 | "userInput": { 442 | "text": "I want a table at for 2 at noon." 443 | }, 444 | "labels": [ 445 | { 446 | "slot": "time", 447 | "valueSpan": { 448 | "startIndex": 27, 449 | "endIndex": 31 450 | } 451 | }, 452 | { 453 | "slot": "people", 454 | "valueSpan": { 455 | "startIndex": 22, 456 | "endIndex": 23 457 | } 458 | } 459 | ] 460 | } 461 | ] -------------------------------------------------------------------------------- /data/bank77/showDataset.py: -------------------------------------------------------------------------------- 1 | # this script extracts a sub word embedding from the entire one 2 | from gensim.models.keyedvectors import KeyedVectors 3 | import numpy as np 4 | import scipy.io as sio 5 | import nltk 6 | import pdb 7 | import random 8 | import csv 9 | import string 10 | import contractions 11 | import json 12 | import random 13 | 14 | def getDomainIntent(domainLabFile): 15 | domain2lab = {} 16 | lab2domain = {} 17 | currentDomain = None 18 | with open(domainLabFile,'r') as f: 19 | for line in f: 20 | if ':' in line and currentDomain == None: 21 | currentDomain = cleanUpSentence(line) 22 | domain2lab[currentDomain] = [] 23 | elif line == "\n": 24 | currentDomain = None 25 | else: 26 | intent = cleanUpSentence(line) 27 | domain2lab[currentDomain].append(intent) 28 | 29 | for key in domain2lab: 30 | domain = key 31 | labList = domain2lab[key] 32 | for lab in labList: 33 | lab2domain[lab] = domain 34 | 35 | return domain2lab, lab2domain 36 | 37 | def cleanUpSentence(sentence): 38 | # sentence: a string, like " Hello, do you like apple? I hate it!! 
" 39 | 40 | # strip 41 | sentence = sentence.strip() 42 | 43 | # lower case 44 | sentence = sentence.lower() 45 | 46 | # fix contractions 47 | sentence = contractions.fix(sentence) 48 | 49 | # remove '_' and '-' 50 | sentence = sentence.replace('-',' ') 51 | sentence = sentence.replace('_',' ') 52 | 53 | # remove all punctuations 54 | sentence = ''.join(ch for ch in sentence if ch not in string.punctuation) 55 | 56 | return sentence 57 | 58 | 59 | def check_data_format(file_path): 60 | for line in open(file_path,'rb'): 61 | arr =str(line.strip(),'utf-8') 62 | arr = arr.split('\t') 63 | label = [w for w in arr[0].split(' ')] 64 | question = [w for w in arr[1].split(' ')] 65 | 66 | if len(label) == 0 or len(question) == 0: 67 | print("[ERROR] Find empty data: ", label, question) 68 | return False 69 | 70 | return True 71 | 72 | 73 | def save_data(data, file_path): 74 | # save data to disk 75 | with open(file_path, 'w') as f: 76 | json.dump(data, f) 77 | return 78 | 79 | def save_domain_intent(data, file_path): 80 | domain2intent = {} 81 | for line in data: 82 | domain = line[0] 83 | intent = line[1] 84 | 85 | if not domain in domain2intent: 86 | domain2intent[domain] = set() 87 | 88 | domain2intent[domain].add(intent) 89 | 90 | # save data to disk 91 | print("Saving domain intent out ... format: domain \t intent") 92 | with open(file_path,"w") as f: 93 | for domain in domain2intent: 94 | intentSet = domain2intent[domain] 95 | for intent in intentSet: 96 | f.write("%s\t%s\n" % (domain, intent)) 97 | return 98 | 99 | def display_data(data): 100 | # dataset count 101 | print("[INFO] We have %d dataset."%(len(data))) 102 | 103 | datasetName = 'BANKING77' 104 | data = data[datasetName] 105 | 106 | # domain count 107 | domainName = set() 108 | for domain in data: 109 | domainName.add(domain) 110 | print("[INFO] There are %d domains."%(len(domainName))) 111 | print(domainName) 112 | 113 | # intent count 114 | intentName = set() 115 | for domain in data: 116 | for d in data[domain]: 117 | lab = d[1][0] 118 | intentName.add(lab) 119 | intentName = list(intentName) 120 | intentName.sort() 121 | print("[INFO] There are %d intent."%(len(intentName))) 122 | print(intentName) 123 | 124 | # data count 125 | count = 0 126 | for domain in data: 127 | for d in data[domain]: 128 | count = count+1 129 | print("[INFO] Data count: %d"%(count)) 130 | 131 | # intent for each domain 132 | domain2intentDict = {} 133 | for domain in data: 134 | if not domain in domain2intentDict: 135 | domain2intentDict[domain] = set() 136 | 137 | for d in data[domain]: 138 | lab = d[1][0] 139 | domain2intentDict[domain].add(lab) 140 | print("[INFO] Intent for each domain.") 141 | print(domain2intentDict) 142 | 143 | # data for each intent 144 | intent2count = {} 145 | for domain in data: 146 | for d in data[domain]: 147 | lab = d[1][0] 148 | if not lab in intent2count: 149 | intent2count[lab] = 0 150 | intent2count[lab] = intent2count[lab]+1 151 | print("[INFO] Intent count") 152 | print(intent2count) 153 | 154 | # examples of data 155 | exampleNum = 3 156 | while not exampleNum == 0: 157 | for domain in data: 158 | for d in data[domain]: 159 | lab = d[1] 160 | utt = d[0] 161 | if random.random() < 0.001: 162 | print("[INFO] Example:--%s, %s, %s, %s"%(datasetName, domain, lab, utt)) 163 | exampleNum = exampleNum-1 164 | break 165 | if (exampleNum==0): 166 | break 167 | 168 | return None 169 | 170 | 171 | ## 172 | # @brief clean up data, including intent and utterance 173 | # 174 | # @param data a list of data 175 | # 176 | # @return 
177 | def cleanData(data): 178 | newData = [] 179 | for d in data: 180 | utt = d[0] 181 | lab = d[1] 182 | 183 | uttClr = cleanUpSentence(utt) 184 | labClr = cleanUpSentence(lab) 185 | newData.append([labClr, uttClr]) 186 | 187 | return newData 188 | 189 | def constructData(data, intent2domain): 190 | dataset2domain = {} 191 | datasetName = 'CLINC150' 192 | dataset2domain[datasetName] = {} 193 | for d in data: 194 | lab = d[0] 195 | utt = d[1] 196 | domain = intent2domain[lab] 197 | if not domain in dataset2domain[datasetName]: 198 | dataset2domain[datasetName][domain] = [] 199 | dataField = [utt, [lab]] 200 | dataset2domain[datasetName][domain].append(dataField) 201 | 202 | return dataset2domain 203 | 204 | 205 | def read_data(file_path): 206 | with open(file_path) as json_file: 207 | data = json.load(json_file) 208 | return data 209 | 210 | # read in data 211 | #dataPath = "/data1/haode/projects/EMDIntentFewShot/SPIN_refactor/data/refactor_OOS/dataset.json" 212 | dataPath = "./dataset.json" 213 | print("Loading data ...", dataPath) 214 | # read lines, collect data count for different classes 215 | data = read_data(dataPath) 216 | 217 | display_data(data) 218 | print("Display ... done") 219 | -------------------------------------------------------------------------------- /data/hint3/original/readme.rd: -------------------------------------------------------------------------------- 1 | The original HINT3 has two versions, v1 and v2. V2 corrects some false labels in v1. 2 | 3 | Here, we use v2. 4 | -------------------------------------------------------------------------------- /data/hint3/original/v2/train/sofmattress_train.csv: -------------------------------------------------------------------------------- 1 | sentence,label 2 | You guys provide EMI option?,EMI 3 | Do you offer Zero Percent EMI payment options?,EMI 4 | 0% EMI.,EMI 5 | EMI,EMI 6 | I want in installment,EMI 7 | I want it on 0% interest,EMI 8 | How to get in EMI,EMI 9 | what about emi options,EMI 10 | I need emi payment. 
,EMI 11 | How to EMI,EMI 12 | I want to buy on EMI,EMI 13 | I want to buy this in installments,EMI 14 | 0% Emi,EMI 15 | Down payments,EMI 16 | Installments,EMI 17 | How about Paisa finance,EMI 18 | Paisa finance service available,EMI 19 | What is minimum down payment,EMI 20 | Paisa Finance is available,EMI 21 | Do you accept Paisa EMI card,EMI 22 | Can we buy through Paisa finance,EMI 23 | EMI Option,EMI 24 | Paisa Finance,EMI 25 | No cost EMI is available?,EMI 26 | Is EMI available,EMI 27 | COD option is availble?,COD 28 | Do you offer COD to my pincode?,COD 29 | Can I do COD?,COD 30 | COD,COD 31 | Is it possible to COD,COD 32 | Cash on delivery is acceptable?,COD 33 | Can pay later on delivery ,COD 34 | DO you have COD option,COD 35 | Can it deliver by COD,COD 36 | Is COD option available,COD 37 | Can I get COD option?,COD 38 | Cash on delivery is available,COD 39 | Features of Ortho mattress,ORTHO_FEATURES 40 | What are the key features of the SOF Ortho mattress,ORTHO_FEATURES 41 | SOF ortho,ORTHO_FEATURES 42 | Ortho mattress,ORTHO_FEATURES 43 | Ortho features,ORTHO_FEATURES 44 | ortho,ORTHO_FEATURES 45 | Tell me about SOF Ortho mattress,ORTHO_FEATURES 46 | Have a back problem,ORTHO_FEATURES 47 | I have back pain issue,ORTHO_FEATURES 48 | Back Pain,ORTHO_FEATURES 49 | Is it orthopaedic,ORTHO_FEATURES 50 | Neck pain and back pain,ORTHO_FEATURES 51 | back ache issue,ORTHO_FEATURES 52 | I m looking mattress for slip disc problem,ORTHO_FEATURES 53 | Do we have anything for backache,ORTHO_FEATURES 54 | I am cervical and Lombard section problem,ORTHO_FEATURES 55 | Is there orthopedic mattress available,ORTHO_FEATURES 56 | What are the key features of the SOF Ergo mattress,ERGO_FEATURES 57 | Features of Ergo mattress,ERGO_FEATURES 58 | SOF ergo,ERGO_FEATURES 59 | Ergo mattress,ERGO_FEATURES 60 | Ergo features,ERGO_FEATURES 61 | Ergo,ERGO_FEATURES 62 | Tell me about SOF Ergo mattress,ERGO_FEATURES 63 | What about ergo,ERGO_FEATURES 64 | What is responsive foam,ERGO_FEATURES 65 | SOF ergo features,ERGO_FEATURES 66 | Does this have ergonomic support?,ERGO_FEATURES 67 | What is the difference between the Ergo & Ortho variants,COMPARISON 68 | Difference between Ergo & Ortho Mattress,COMPARISON 69 | Difference between the products,COMPARISON 70 | Compare the 2 mattresses,COMPARISON 71 | Product comparison,COMPARISON 72 | Comparison,COMPARISON 73 | Which mattress to buy?,COMPARISON 74 | Is the mattress good for my back,COMPARISON 75 | Mattress comparison,COMPARISON 76 | Compare ergo & ortho variants,COMPARISON 77 | I wanna know the difference,COMPARISON 78 | What is the warranty period?,WARRANTY 79 | Warranty,WARRANTY 80 | Does mattress cover is included in warranty,WARRANTY 81 | How long is the warranty you offer on your mattresses and what does it cover,WARRANTY 82 | Do you offer warranty ,WARRANTY 83 | Tell me about the product warranty,WARRANTY 84 | Share warranty information,WARRANTY 85 | Need to know the warranty details,WARRANTY 86 | Want to know about warranty,WARRANTY 87 | would interested in warranty details,WARRANTY 88 | How does the 100 night trial work,100_NIGHT_TRIAL_OFFER 89 | What is the 100-night offer,100_NIGHT_TRIAL_OFFER 90 | Trial details,100_NIGHT_TRIAL_OFFER 91 | How to enroll for trial,100_NIGHT_TRIAL_OFFER 92 | 100 night trial,100_NIGHT_TRIAL_OFFER 93 | Is the 100 night return trial applicable for custom size as well,100_NIGHT_TRIAL_OFFER 94 | 100 night,100_NIGHT_TRIAL_OFFER 95 | Trial offer on customisation,100_NIGHT_TRIAL_OFFER 96 | do you provide 
exchange,100_NIGHT_TRIAL_OFFER 97 | Can I try a mattress first,100_NIGHT_TRIAL_OFFER 98 | I want to check offers,100_NIGHT_TRIAL_OFFER 99 | What is 100 Night trial offer,100_NIGHT_TRIAL_OFFER 100 | Need 100 days trial,100_NIGHT_TRIAL_OFFER 101 | 100 days trial,100_NIGHT_TRIAL_OFFER 102 | Can I get free trial,100_NIGHT_TRIAL_OFFER 103 | 100 Nights trial version,100_NIGHT_TRIAL_OFFER 104 | Can you give me 100 night trial ,100_NIGHT_TRIAL_OFFER 105 | 100 free Nights,100_NIGHT_TRIAL_OFFER 106 | I want to change the size of the mattress.,SIZE_CUSTOMIZATION 107 | Need some help in changing size of the mattress,SIZE_CUSTOMIZATION 108 | How can I order a custom sized mattress,SIZE_CUSTOMIZATION 109 | Custom size,SIZE_CUSTOMIZATION 110 | Customise size ,SIZE_CUSTOMIZATION 111 | Customisation is possible?,SIZE_CUSTOMIZATION 112 | Will I get an option to Customise the size,SIZE_CUSTOMIZATION 113 | Can mattress size be customised?,SIZE_CUSTOMIZATION 114 | Mattress size change,SIZE_CUSTOMIZATION 115 | Can you help with the size?,WHAT_SIZE_TO_ORDER 116 | How do I know what size to order?,WHAT_SIZE_TO_ORDER 117 | How do I know the size of my bed?,WHAT_SIZE_TO_ORDER 118 | Inches,WHAT_SIZE_TO_ORDER 119 | Length,WHAT_SIZE_TO_ORDER 120 | Feet,WHAT_SIZE_TO_ORDER 121 | 6*3,WHAT_SIZE_TO_ORDER 122 | Help me with the size chart,WHAT_SIZE_TO_ORDER 123 | What are the available sizes?,WHAT_SIZE_TO_ORDER 124 | Can I please have the size chart?,WHAT_SIZE_TO_ORDER 125 | Share the size structure,WHAT_SIZE_TO_ORDER 126 | What are the sizes available?,WHAT_SIZE_TO_ORDER 127 | Want to know the custom size chart,WHAT_SIZE_TO_ORDER 128 | Mattress size,WHAT_SIZE_TO_ORDER 129 | What size to order?,WHAT_SIZE_TO_ORDER 130 | Show me all available sizes,WHAT_SIZE_TO_ORDER 131 | What are the sizes,WHAT_SIZE_TO_ORDER 132 | What are the available mattress sizes,WHAT_SIZE_TO_ORDER 133 | King Size,WHAT_SIZE_TO_ORDER 134 | Inches,WHAT_SIZE_TO_ORDER 135 | Get in Touch,LEAD_GEN 136 | Want to talk to an live agent,LEAD_GEN 137 | Please call me,LEAD_GEN 138 | Do you have Live agent,LEAD_GEN 139 | Want to get in touch,LEAD_GEN 140 | Schedule a callback ,LEAD_GEN 141 | Arrange a call back ,LEAD_GEN 142 | How to get in touch?,LEAD_GEN 143 | Interested in buying,LEAD_GEN 144 | I want to buy this,LEAD_GEN 145 | How to buy?,LEAD_GEN 146 | How to order,LEAD_GEN 147 | I want to order,LEAD_GEN 148 | Connect to an agent,LEAD_GEN 149 | Get In Touch,LEAD_GEN 150 | I want a call back,LEAD_GEN 151 | I need a call back,LEAD_GEN 152 | Please call me back,LEAD_GEN 153 | Call me now,LEAD_GEN 154 | I want to buy,LEAD_GEN 155 | Need a call from your representative,LEAD_GEN 156 | Do you deliver to my pincode,CHECK_PINCODE 157 | Check pincode,CHECK_PINCODE 158 | Is delivery possible on this pincode,CHECK_PINCODE 159 | Will you be able to deliver here,CHECK_PINCODE 160 | Can you make delivery on this pin code?,CHECK_PINCODE 161 | Can you deliver on my pincode,CHECK_PINCODE 162 | Do you deliver to,CHECK_PINCODE 163 | Can you please deliver on my pincode,CHECK_PINCODE 164 | Need a delivery on this pincode,CHECK_PINCODE 165 | Can I get delivery on this pincode,CHECK_PINCODE 166 | Do you have any showrooms in Delhi state,DISTRIBUTORS 167 | Do you have any distributors in Mumbai city,DISTRIBUTORS 168 | Do you have any retailers in Pune city,DISTRIBUTORS 169 | Where can I see the product before I buy,DISTRIBUTORS 170 | What is the price for size (x ft x y ft)? 
What is the price for size (x inches x y inches)?,DISTRIBUTORS 171 | Distributors ,DISTRIBUTORS 172 | Distributors/Retailers/Showrooms,DISTRIBUTORS 173 | Where is your showroom,DISTRIBUTORS 174 | Do you have a showroom,DISTRIBUTORS 175 | Can I visit SOF mattress showroom,DISTRIBUTORS 176 | Where is your showroom,DISTRIBUTORS 177 | Nearby Show room,DISTRIBUTORS 178 | Do you have Showroom,DISTRIBUTORS 179 | Need store,DISTRIBUTORS 180 | Need Nearby Store,DISTRIBUTORS 181 | Nearest shop,DISTRIBUTORS 182 | Demo store,DISTRIBUTORS 183 | Is there any offline stores ,DISTRIBUTORS 184 | Head office,DISTRIBUTORS 185 | Shops nearby,DISTRIBUTORS 186 | Where was this shop,DISTRIBUTORS 187 | Offline stores,DISTRIBUTORS 188 | Do you have store,DISTRIBUTORS 189 | Where is the shop ,DISTRIBUTORS 190 | Is it available in shops,DISTRIBUTORS 191 | You have any branch,DISTRIBUTORS 192 | Store in,DISTRIBUTORS 193 | We want dealer ship,DISTRIBUTORS 194 | Any shop that I can visit,DISTRIBUTORS 195 | Dealership,DISTRIBUTORS 196 | Shop near by,DISTRIBUTORS 197 | Need dealership,DISTRIBUTORS 198 | Outlet,DISTRIBUTORS 199 | Store near me,DISTRIBUTORS 200 | Price of mattress,MATTRESS_COST 201 | Mattress cost,MATTRESS_COST 202 | Cost of mattress,MATTRESS_COST 203 | How much does a SOF mattress cost,MATTRESS_COST 204 | Cost,MATTRESS_COST 205 | Custom size cost,MATTRESS_COST 206 | What does the mattress cost,MATTRESS_COST 207 | I need price,MATTRESS_COST 208 | Price Range,MATTRESS_COST 209 | Mattress price,MATTRESS_COST 210 | Price of Mattress,MATTRESS_COST 211 | Want to know the price,MATTRESS_COST 212 | How Much Cost,MATTRESS_COST 213 | Price,MATTRESS_COST 214 | Rate,MATTRESS_COST 215 | MRP,MATTRESS_COST 216 | Low price,MATTRESS_COST 217 | What will be the price,MATTRESS_COST 218 | Cost of Bed,MATTRESS_COST 219 | Cost of Mattress,MATTRESS_COST 220 | Mattress cost,MATTRESS_COST 221 | What is the cost,MATTRESS_COST 222 | What are the product variants,PRODUCT_VARIANTS 223 | Product Variants,PRODUCT_VARIANTS 224 | Help me with different products,PRODUCT_VARIANTS 225 | What are the mattress variants,PRODUCT_VARIANTS 226 | I want to check products,PRODUCT_VARIANTS 227 | What are the SOF mattress products,PRODUCT_VARIANTS 228 | Show me products,PRODUCT_VARIANTS 229 | Which product is best,PRODUCT_VARIANTS 230 | Which mattress is best,PRODUCT_VARIANTS 231 | I want to buy a mattress,PRODUCT_VARIANTS 232 | Products,PRODUCT_VARIANTS 233 | View products,PRODUCT_VARIANTS 234 | Type of foam used,PRODUCT_VARIANTS 235 | Mattress variants,PRODUCT_VARIANTS 236 | Mattress Features,PRODUCT_VARIANTS 237 | I am looking the mattress,PRODUCT_VARIANTS 238 | Type of mattress,PRODUCT_VARIANTS 239 | Show more mattress,PRODUCT_VARIANTS 240 | What are the mattress features,PRODUCT_VARIANTS 241 | What are your products,PRODUCT_VARIANTS 242 | Tell me about SOF mattress features,PRODUCT_VARIANTS 243 | How is SOF different from other mattress brands,ABOUT_SOF_MATTRESS 244 | Why SOF mattress,ABOUT_SOF_MATTRESS 245 | About SOF Mattress,ABOUT_SOF_MATTRESS 246 | What is SOF mattress,ABOUT_SOF_MATTRESS 247 | Tell me about SOF mattresses,ABOUT_SOF_MATTRESS 248 | Who are SOF mattress,ABOUT_SOF_MATTRESS 249 | What is SOF,ABOUT_SOF_MATTRESS 250 | How is SOF mattress different from,ABOUT_SOF_MATTRESS 251 | Tell me about SOF Mattress,ABOUT_SOF_MATTRESS 252 | Tell me about company,ABOUT_SOF_MATTRESS 253 | Who is SOF mattress,ABOUT_SOF_MATTRESS 254 | It's been a month,DELAY_IN_DELIVERY 255 | Why so long?,DELAY_IN_DELIVERY 256 | I did not receive my order 
yet,DELAY_IN_DELIVERY 257 | Why so many days,DELAY_IN_DELIVERY 258 | It's delayed,DELAY_IN_DELIVERY 259 | Almost 1 month over,DELAY_IN_DELIVERY 260 | It's been 30 days my product haven't received,DELAY_IN_DELIVERY 261 | It's been so many days,DELAY_IN_DELIVERY 262 | It's already one month,DELAY_IN_DELIVERY 263 | Delivery is delayed,DELAY_IN_DELIVERY 264 | It's too late to get delivered,DELAY_IN_DELIVERY 265 | Order Status,ORDER_STATUS 266 | What is my order status?,ORDER_STATUS 267 | Order related,ORDER_STATUS 268 | Status of my order,ORDER_STATUS 269 | What about my order,ORDER_STATUS 270 | My order Number,ORDER_STATUS 271 | Where is my order,ORDER_STATUS 272 | Order #,ORDER_STATUS 273 | Track order,ORDER_STATUS 274 | I want updates of my order,ORDER_STATUS 275 | I need my order status,ORDER_STATUS 276 | Status of my order,ORDER_STATUS 277 | State of this order,ORDER_STATUS 278 | Order status,ORDER_STATUS 279 | Present status,ORDER_STATUS 280 | What is the status?,ORDER_STATUS 281 | Current state of my order,ORDER_STATUS 282 | What is the order status?,ORDER_STATUS 283 | Where is my product,ORDER_STATUS 284 | When will the order be delivered to me?,ORDER_STATUS 285 | When can we expect,ORDER_STATUS 286 | Need my money back,RETURN_EXCHANGE 287 | I want refund,RETURN_EXCHANGE 288 | Refund,RETURN_EXCHANGE 289 | Not happy with the product please help me to return,RETURN_EXCHANGE 290 | Help me with exchange process,RETURN_EXCHANGE 291 | How do I return it,RETURN_EXCHANGE 292 | I want to return my mattress,RETURN_EXCHANGE 293 | Exchange,RETURN_EXCHANGE 294 | Looking to exchange,RETURN_EXCHANGE 295 | How can I replace the mattress.,RETURN_EXCHANGE 296 | How do I return It,RETURN_EXCHANGE 297 | Return,RETURN_EXCHANGE 298 | Return my product,RETURN_EXCHANGE 299 | Replacement policy,RETURN_EXCHANGE 300 | I want to cancel my order,CANCEL_ORDER 301 | How can I cancel my order,CANCEL_ORDER 302 | Cancel order,CANCEL_ORDER 303 | Cancellation status,CANCEL_ORDER 304 | I want cancel my order,CANCEL_ORDER 305 | Cancel the order,CANCEL_ORDER 306 | Cancel my order,CANCEL_ORDER 307 | Need to cancel my order,CANCEL_ORDER 308 | Process of cancelling order,CANCEL_ORDER 309 | Can I cancel my order here,CANCEL_ORDER 310 | Can I get pillows?,PILLOWS 311 | Do you sell pillows?,PILLOWS 312 | Pillows,PILLOWS 313 | I want to buy pillows,PILLOWS 314 | Are Pillows available,PILLOWS 315 | Need pair of Pillows,PILLOWS 316 | Can I buy pillows from here,PILLOWS 317 | Do you have cushions,PILLOWS 318 | Can I also have pillows,PILLOWS 319 | Is pillows available,PILLOWS 320 | Offers,OFFERS 321 | What are the available offers,OFFERS 322 | Give me some discount,OFFERS 323 | Any discounts,OFFERS 324 | Discount,OFFERS 325 | May I please know about the offers,OFFERS 326 | Available offers,OFFERS 327 | Is offer available,OFFERS 328 | Want to know the discount ,OFFERS 329 | Tell me about the latest offers,OFFERS 330 | -------------------------------------------------------------------------------- /data/hint3/showDataset.py: -------------------------------------------------------------------------------- 1 | # this script extracts a sub word embedding from the entire one 2 | from gensim.models.keyedvectors import KeyedVectors 3 | import numpy as np 4 | import scipy.io as sio 5 | import nltk 6 | import pdb 7 | import random 8 | import csv 9 | import string 10 | import contractions 11 | import json 12 | import random 13 | 14 | def getDomainIntent(domainLabFile): 15 | domain2lab = {} 16 | lab2domain = {} 17 | currentDomain = None 18 | 
with open(domainLabFile,'r') as f: 19 | for line in f: 20 | if ':' in line and currentDomain == None: 21 | currentDomain = cleanUpSentence(line) 22 | domain2lab[currentDomain] = [] 23 | elif line == "\n": 24 | currentDomain = None 25 | else: 26 | intent = cleanUpSentence(line) 27 | domain2lab[currentDomain].append(intent) 28 | 29 | for key in domain2lab: 30 | domain = key 31 | labList = domain2lab[key] 32 | for lab in labList: 33 | lab2domain[lab] = domain 34 | 35 | return domain2lab, lab2domain 36 | 37 | def cleanUpSentence(sentence): 38 | # sentence: a string, like " Hello, do you like apple? I hate it!! " 39 | 40 | # strip 41 | sentence = sentence.strip() 42 | 43 | # lower case 44 | sentence = sentence.lower() 45 | 46 | # fix contractions 47 | sentence = contractions.fix(sentence) 48 | 49 | # remove '_' and '-' 50 | sentence = sentence.replace('-',' ') 51 | sentence = sentence.replace('_',' ') 52 | 53 | # remove all punctuations 54 | sentence = ''.join(ch for ch in sentence if ch not in string.punctuation) 55 | 56 | return sentence 57 | 58 | 59 | def check_data_format(file_path): 60 | for line in open(file_path,'rb'): 61 | arr =str(line.strip(),'utf-8') 62 | arr = arr.split('\t') 63 | label = [w for w in arr[0].split(' ')] 64 | question = [w for w in arr[1].split(' ')] 65 | 66 | if len(label) == 0 or len(question) == 0: 67 | print("[ERROR] Find empty data: ", label, question) 68 | return False 69 | 70 | return True 71 | 72 | 73 | def save_data(data, file_path): 74 | # save data to disk 75 | with open(file_path, 'w') as f: 76 | json.dump(data, f) 77 | return 78 | 79 | def save_domain_intent(data, file_path): 80 | domain2intent = {} 81 | for line in data: 82 | domain = line[0] 83 | intent = line[1] 84 | 85 | if not domain in domain2intent: 86 | domain2intent[domain] = set() 87 | 88 | domain2intent[domain].add(intent) 89 | 90 | # save data to disk 91 | print("Saving domain intent out ... 
format: domain \t intent") 92 | with open(file_path,"w") as f: 93 | for domain in domain2intent: 94 | intentSet = domain2intent[domain] 95 | for intent in intentSet: 96 | f.write("%s\t%s\n" % (domain, intent)) 97 | return 98 | 99 | def display_data(data): 100 | # dataset count 101 | print("[INFO] We have %d dataset."%(len(data))) 102 | 103 | datasetName = 'HINT3' 104 | data = data[datasetName] 105 | 106 | # domain count 107 | domainName = set() 108 | for domain in data: 109 | domainName.add(domain) 110 | print("[INFO] There are %d domains."%(len(domainName))) 111 | print(domainName) 112 | 113 | # intent count 114 | intentName = set() 115 | for domain in data: 116 | for d in data[domain]: 117 | lab = d[1][0] 118 | intentName.add(lab) 119 | intentName = list(intentName) 120 | intentName.sort() 121 | print("[INFO] There are %d intent."%(len(intentName))) 122 | print(intentName) 123 | 124 | # data count 125 | count = 0 126 | for domain in data: 127 | for d in data[domain]: 128 | count = count+1 129 | print("[INFO] Data count: %d"%(count)) 130 | 131 | # intent for each domain 132 | domain2intentDict = {} 133 | for domain in data: 134 | if not domain in domain2intentDict: 135 | domain2intentDict[domain] = set() 136 | 137 | for d in data[domain]: 138 | lab = d[1][0] 139 | domain2intentDict[domain].add(lab) 140 | print("[INFO] Intent for each domain.") 141 | print(domain2intentDict) 142 | 143 | # data for each intent 144 | intent2count = {} 145 | for domain in data: 146 | for d in data[domain]: 147 | lab = d[1][0] 148 | if not lab in intent2count: 149 | intent2count[lab] = 0 150 | intent2count[lab] = intent2count[lab]+1 151 | print("[INFO] Intent count") 152 | print(intent2count) 153 | 154 | # examples of data 155 | exampleNum = 3 156 | while not exampleNum == 0: 157 | for domain in data: 158 | for d in data[domain]: 159 | lab = d[1] 160 | utt = d[0] 161 | if random.random() < 0.001: 162 | print("[INFO] Example:--%s, %s, %s, %s"%(datasetName, domain, lab, utt)) 163 | exampleNum = exampleNum-1 164 | break 165 | if (exampleNum==0): 166 | break 167 | 168 | return None 169 | 170 | 171 | ## 172 | # @brief clean up data, including intent and utterance 173 | # 174 | # @param data a list of data 175 | # 176 | # @return 177 | def cleanData(data): 178 | newData = [] 179 | for d in data: 180 | utt = d[0] 181 | lab = d[1] 182 | 183 | uttClr = cleanUpSentence(utt) 184 | labClr = cleanUpSentence(lab) 185 | newData.append([labClr, uttClr]) 186 | 187 | return newData 188 | 189 | def constructData(data, intent2domain): 190 | dataset2domain = {} 191 | datasetName = 'HINT3' 192 | dataset2domain[datasetName] = {} 193 | for d in data: 194 | lab = d[0] 195 | utt = d[1] 196 | domain = intent2domain[lab] 197 | if not domain in dataset2domain[datasetName]: 198 | dataset2domain[datasetName][domain] = [] 199 | dataField = [utt, [lab]] 200 | dataset2domain[datasetName][domain].append(dataField) 201 | 202 | return dataset2domain 203 | 204 | 205 | def read_data(file_path): 206 | with open(file_path) as json_file: 207 | data = json.load(json_file) 208 | return data 209 | 210 | # read in data 211 | #dataPath = "/data1/haode/projects/EMDIntentFewShot/SPIN_refactor/data/refactor_OOS/dataset.json" 212 | dataPath = "./dataset.json" 213 | print("Loading data ...", dataPath) 214 | # read lines, collect data count for different classes 215 | data = read_data(dataPath) 216 | 217 | display_data(data) 218 | print("Display.. 
done") 219 | -------------------------------------------------------------------------------- /data/hwu64/showDataset.py: -------------------------------------------------------------------------------- 1 | # this script extracts a sub word embedding from the entire one 2 | from gensim.models.keyedvectors import KeyedVectors 3 | import numpy as np 4 | import scipy.io as sio 5 | import nltk 6 | import pdb 7 | import random 8 | import csv 9 | import string 10 | import contractions 11 | import json 12 | import random 13 | 14 | def getDomainIntent(domainLabFile): 15 | domain2lab = {} 16 | lab2domain = {} 17 | currentDomain = None 18 | with open(domainLabFile,'r') as f: 19 | for line in f: 20 | if ':' in line and currentDomain == None: 21 | currentDomain = cleanUpSentence(line) 22 | domain2lab[currentDomain] = [] 23 | elif line == "\n": 24 | currentDomain = None 25 | else: 26 | intent = cleanUpSentence(line) 27 | domain2lab[currentDomain].append(intent) 28 | 29 | for key in domain2lab: 30 | domain = key 31 | labList = domain2lab[key] 32 | for lab in labList: 33 | lab2domain[lab] = domain 34 | 35 | return domain2lab, lab2domain 36 | 37 | def cleanUpSentence(sentence): 38 | # sentence: a string, like " Hello, do you like apple? I hate it!! " 39 | 40 | # strip 41 | sentence = sentence.strip() 42 | 43 | # lower case 44 | sentence = sentence.lower() 45 | 46 | # fix contractions 47 | sentence = contractions.fix(sentence) 48 | 49 | # remove '_' and '-' 50 | sentence = sentence.replace('-',' ') 51 | sentence = sentence.replace('_',' ') 52 | 53 | # remove all punctuations 54 | sentence = ''.join(ch for ch in sentence if ch not in string.punctuation) 55 | 56 | return sentence 57 | 58 | 59 | def check_data_format(file_path): 60 | for line in open(file_path,'rb'): 61 | arr =str(line.strip(),'utf-8') 62 | arr = arr.split('\t') 63 | label = [w for w in arr[0].split(' ')] 64 | question = [w for w in arr[1].split(' ')] 65 | 66 | if len(label) == 0 or len(question) == 0: 67 | print("[ERROR] Find empty data: ", label, question) 68 | return False 69 | 70 | return True 71 | 72 | 73 | def save_data(data, file_path): 74 | # save data to disk 75 | with open(file_path, 'w') as f: 76 | json.dump(data, f) 77 | return 78 | 79 | def save_domain_intent(data, file_path): 80 | domain2intent = {} 81 | for line in data: 82 | domain = line[0] 83 | intent = line[1] 84 | 85 | if not domain in domain2intent: 86 | domain2intent[domain] = set() 87 | 88 | domain2intent[domain].add(intent) 89 | 90 | # save data to disk 91 | print("Saving domain intent out ... 
format: domain \t intent") 92 | with open(file_path,"w") as f: 93 | for domain in domain2intent: 94 | intentSet = domain2intent[domain] 95 | for intent in intentSet: 96 | f.write("%s\t%s\n" % (domain, intent)) 97 | return 98 | 99 | def display_data(data): 100 | # dataset count 101 | print("[INFO] We have %d dataset."%(len(data))) 102 | 103 | datasetName = 'HWU64' 104 | data = data[datasetName] 105 | 106 | # domain count 107 | domainName = set() 108 | for domain in data: 109 | domainName.add(domain) 110 | print("[INFO] There are %d domains."%(len(domainName))) 111 | print(domainName) 112 | 113 | # intent count 114 | intentName = set() 115 | for domain in data: 116 | for d in data[domain]: 117 | lab = d[1][0] 118 | intentName.add(lab) 119 | intentName = list(intentName) 120 | intentName.sort() 121 | print("[INFO] There are %d intent."%(len(intentName))) 122 | print(intentName) 123 | 124 | # data count 125 | count = 0 126 | for domain in data: 127 | for d in data[domain]: 128 | count = count+1 129 | print("[INFO] Data count: %d"%(count)) 130 | 131 | # intent for each domain 132 | domain2intentDict = {} 133 | for domain in data: 134 | if not domain in domain2intentDict: 135 | domain2intentDict[domain] = set() 136 | 137 | for d in data[domain]: 138 | lab = d[1][0] 139 | domain2intentDict[domain].add(lab) 140 | print("[INFO] Intent for each domain.") 141 | print(domain2intentDict) 142 | 143 | # data for each intent 144 | intent2count = {} 145 | for domain in data: 146 | for d in data[domain]: 147 | lab = d[1][0] 148 | if not lab in intent2count: 149 | intent2count[lab] = 0 150 | intent2count[lab] = intent2count[lab]+1 151 | print("[INFO] Intent count") 152 | print(intent2count) 153 | 154 | # examples of data 155 | exampleNum = 3 156 | while not exampleNum == 0: 157 | for domain in data: 158 | for d in data[domain]: 159 | lab = d[1] 160 | utt = d[0] 161 | if random.random() < 0.001: 162 | print("[INFO] Example:--%s, %s, %s, %s"%(datasetName, domain, lab, utt)) 163 | exampleNum = exampleNum-1 164 | break 165 | if (exampleNum==0): 166 | break 167 | 168 | return None 169 | 170 | 171 | ## 172 | # @brief clean up data, including intent and utterance 173 | # 174 | # @param data a list of data 175 | # 176 | # @return 177 | def cleanData(data): 178 | newData = [] 179 | for d in data: 180 | utt = d[0] 181 | lab = d[1] 182 | 183 | uttClr = cleanUpSentence(utt) 184 | labClr = cleanUpSentence(lab) 185 | newData.append([labClr, uttClr]) 186 | 187 | return newData 188 | 189 | def constructData(data, intent2domain): 190 | dataset2domain = {} 191 | datasetName = 'CLINC150' 192 | dataset2domain[datasetName] = {} 193 | for d in data: 194 | lab = d[0] 195 | utt = d[1] 196 | domain = intent2domain[lab] 197 | if not domain in dataset2domain[datasetName]: 198 | dataset2domain[datasetName][domain] = [] 199 | dataField = [utt, [lab]] 200 | dataset2domain[datasetName][domain].append(dataField) 201 | 202 | return dataset2domain 203 | 204 | 205 | def read_data(file_path): 206 | with open(file_path) as json_file: 207 | data = json.load(json_file) 208 | return data 209 | 210 | # read in data 211 | #dataPath = "/data1/haode/projects/EMDIntentFewShot/SPIN_refactor/data/refactor_OOS/dataset.json" 212 | dataPath = "./dataset.json" 213 | print("Loading data ...", dataPath) 214 | # read lines, collect data count for different classes 215 | data = read_data(dataPath) 216 | 217 | display_data(data) 218 | print("Display.. 
done") 219 | -------------------------------------------------------------------------------- /data/mcid/original/README: -------------------------------------------------------------------------------- 1 | Data Description 2 | The following directory contains intent classification data(train, test and eval splits) for Covid-19 queries in 3 languages: 3 | 1. English 4 | 2. Spanish 5 | 3. French 6 | 4. German 7 | Additionally, there is also a Spanglish directory which contains a code-switched test set for Spanglish queries. 8 | 9 | Data Format 10 | 11 | The data is provided in the form of Tab Separated files containing two columns: 12 | 1. Column 1 contains the query utterance 13 | 2. Column 2 contains the intent/label for the query 14 | -------------------------------------------------------------------------------- /data/mcid/original/en/eval.tsv: -------------------------------------------------------------------------------- 1 | is there a cure for the coronavirus intent:what_are_treatment_options 2 | flu treatment options viable to fight corona intent:what_are_treatment_options 3 | i want treatment options for covid-19 intent:what_are_treatment_options 4 | is there a cure for the virus intent:what_are_treatment_options 5 | what are some updates on the cure for covid-19 intent:what_are_treatment_options 6 | does acetaminophen cure coronavirus intent:what_are_treatment_options 7 | where can i get the covid vaccine intent:what_are_treatment_options 8 | who offers a vaccine for corona virus intent:what_are_treatment_options 9 | any homeopathic remedies for covid 19 intent:what_are_treatment_options 10 | please tell me if it's a myth that soup gives you corona intent:myths 11 | covid debunked myths intent:myths 12 | myths about desinfect your shoes every day intent:myths 13 | do 5g antennas spread the coronavirus intent:myths 14 | is it true that i can catch covid-19 from animals intent:myths 15 | covid-19 origin myths intent:myths 16 | myth about half the population dying from covid 19 intent:myths 17 | is it a myth that i can get coronavirus from my cat intent:myths 18 | can i travel to go hiking intent:travel 19 | can i visit china now intent:travel 20 | advice for traveling in dc intent:travel 21 | is air travel safe yet intent:travel 22 | will be the hotels reopen in 40 days intent:travel 23 | i wanna see the latest news about covid in my city intent:news_and_press 24 | tell me the latest corona news for europe intent:news_and_press 25 | recent updates in bahrain intent:news_and_press 26 | anything new about new york intent:news_and_press 27 | latest coronavirus news intent:news_and_press 28 | any news on coronavirus from today intent:news_and_press 29 | are there any updates in texas intent:news_and_press 30 | covid updates from the latest press release intent:news_and_press 31 | share this intent:share 32 | share this bot intent:share 33 | please share this intent:share 34 | this would be great to share intent:share 35 | i want to share this assistant with my contacts intent:share 36 | share this article intent:share 37 | share this info intent:share 38 | share this report intent:share 39 | send link to sister intent:share 40 | how are you today bot intent:hi 41 | hello covid bot intent:hi 42 | are you still sleepy bot intent:hi 43 | hey there intent:hi 44 | whaddup doctor intent:hi 45 | hey there doc intent:hi 46 | hey doc intent:hi 47 | what's up doc intent:hi 48 | how's it going intent:hi 49 | great info, thanks intent:okay_thanks 50 | got it, thanks bot intent:okay_thanks 51 | 
thanks for the info intent:okay_thanks 52 | i appreciate your help, dear corona intent:okay_thanks 53 | i appreciate it doc intent:okay_thanks 54 | i appreciate your help with travel safety tips intent:okay_thanks 55 | thanks for letting me know about high risk areas intent:okay_thanks 56 | thanks so much, corona doc intent:okay_thanks 57 | thank you for everything, doctor intent:okay_thanks 58 | thanks for all the info intent:okay_thanks 59 | thanks for the information doctor intent:okay_thanks 60 | thanks for the details intent:okay_thanks 61 | thank you for your help corona bot intent:okay_thanks 62 | how can i donate to new york city intent:donate 63 | how can i donate to charity intent:donate 64 | what's the best way to make a donation in my community intent:donate 65 | i want to fund nyc intent:donate 66 | donate to china intent:donate 67 | donate for african relief intent:donate 68 | donate to a local charity intent:donate 69 | what's a good cause to donate to intent:donate 70 | funding for covid 19 response chicago intent:donate 71 | how many covid19 cases are there now intent:latest_numbers 72 | how many confirmed cases are there in new york intent:latest_numbers 73 | can you show me the ranking by confirmed cases within the u.s intent:latest_numbers 74 | what are the statistics for elderly people infected by the virus intent:latest_numbers 75 | how many us states have a shelter in place order intent:latest_numbers 76 | total recoveries from covid intent:latest_numbers 77 | how many tests has the us done intent:latest_numbers 78 | does staying home protect my family from getting infected with the virus intent:protect_yourself 79 | do gloves help intent:protect_yourself 80 | should i sanitize my hands after grocery shopping intent:protect_yourself 81 | what kind of hand sanitizer should i use intent:protect_yourself 82 | how many times a day should i wash my hands intent:protect_yourself 83 | what should essential workers do for prevention of the virus intent:protect_yourself 84 | how long is the virus alive on surfaces intent:protect_yourself 85 | how to kill the spread of corona virus on surfaces intent:protect_yourself 86 | how can i protect myself from coronavirus intent:protect_yourself 87 | do face masks help stop coronavirus intent:protect_yourself 88 | are there any medications i should take if i have the virus intent:protect_yourself 89 | do pets need protection if going outdoors intent:protect_yourself 90 | should i wear gloves intent:protect_yourself 91 | what gloves are best for protection intent:protect_yourself 92 | enlighten me about the virus intent:what_is_corona 93 | brief me about the corona intent:what_is_corona 94 | describe corona intent:what_is_corona 95 | give me more info on the pandemic intent:what_is_corona 96 | tell me what i should know about covid intent:what_is_corona 97 | what does covid-19 stand for intent:what_is_corona 98 | is diarrhea a symptom intent:what_are_symptoms 99 | what are the symptoms of covid-19 intent:what_are_symptoms 100 | tell me all the possible symptoms of covid-19 intent:what_are_symptoms 101 | how can i get better intent:what_are_symptoms 102 | how can i tell the difference between having allergies or coronavirus intent:what_are_symptoms 103 | what are possible danger signs for it intent:what_are_symptoms 104 | what are some of the symptoms associated with corona virus intent:what_are_symptoms 105 | what are the late stage symptoms intent:what_are_symptoms 106 | how many different symptoms are there intent:what_are_symptoms 
107 | can corona virus cause a headache intent:what_are_symptoms 108 | i have a sore throat, is this a symptom of the covid-19 intent:what_are_symptoms 109 | what main symptoms are clear signs of the disease intent:what_are_symptoms 110 | what symptom is the easiest to notice when someone has the disease intent:what_are_symptoms 111 | is covid-19 causing my shortness of breath intent:what_are_symptoms 112 | ive been sneezing a lot is that caused buy the corona intent:what_are_symptoms 113 | does corona virus spread through farts intent:how_does_corona_spread 114 | can i get the corona from my kids ipad even after 3 days intent:how_does_corona_spread 115 | can i get infected with covid if someone sneezes on me intent:how_does_corona_spread 116 | how likely does covid spread on amazon packages intent:how_does_corona_spread 117 | is coronavirus spread on doorknobs intent:how_does_corona_spread 118 | can i get infected with covid through indirect contact intent:how_does_corona_spread 119 | how can touching shared surfuces put you at risk for corona virus intent:how_does_corona_spread 120 | can the virus spread by skin to skin interaction intent:how_does_corona_spread 121 | can the virus travel by touch intent:how_does_corona_spread 122 | should i have contact with my pet if i have the coronavirus disease intent:can_i_get_from_feces_animal_pets 123 | can pets get the coronavirus, and can we catch it from them intent:can_i_get_from_feces_animal_pets 124 | how do i avoid getting covid from animals intent:can_i_get_from_feces_animal_pets 125 | is there a vaccination against the covid-19 coronavirus that my pet can receive intent:can_i_get_from_feces_animal_pets 126 | can i test my animal through a private veterinary laboratory intent:can_i_get_from_feces_animal_pets 127 | can i get corona virus from eating sushi intent:can_i_get_from_feces_animal_pets 128 | how can i tell if my cat has corona intent:can_i_get_from_feces_animal_pets 129 | can snakes get the virus intent:can_i_get_from_feces_animal_pets 130 | can horses get the virus intent:can_i_get_from_feces_animal_pets 131 | do mosquitos carry covid intent:can_i_get_from_feces_animal_pets 132 | how long does covid19 stay on public surfaces intent:can_i_get_from_packages_surfaces 133 | can i get coronavirus from touching groceries intent:can_i_get_from_packages_surfaces 134 | have people got covid from deliveries from china intent:can_i_get_from_packages_surfaces 135 | should i stop getting deliveries intent:can_i_get_from_packages_surfaces 136 | will bleach kill coronavirus on my stove intent:can_i_get_from_packages_surfaces 137 | is covid good at living on steel intent:can_i_get_from_packages_surfaces 138 | does covid-19 stay on furniture intent:can_i_get_from_packages_surfaces 139 | does covid-19 stay on ceilings intent:can_i_get_from_packages_surfaces 140 | which cities in the united states are at most risk for coronavirus right now intent:what_if_i_visited_high_risk_area 141 | is the hospital considered a high risk area intent:what_if_i_visited_high_risk_area 142 | which country is currently the highest risk intent:what_if_i_visited_high_risk_area 143 | should i be concerned since i just visited china intent:what_if_i_visited_high_risk_area 144 | what city in california has the highest risk intent:what_if_i_visited_high_risk_area 145 | should i travel within the u.s during the pandemic intent:what_if_i_visited_high_risk_area 146 | which countries are safe for nonessential international travel intent:what_if_i_visited_high_risk_area 147 | 
will i get in trouble visiting a high risk area intent:what_if_i_visited_high_risk_area 148 | which essential business has the highest risk of infection intent:what_if_i_visited_high_risk_area 149 | -------------------------------------------------------------------------------- /data/mcid/original/es/eval.tsv: -------------------------------------------------------------------------------- 1 | me gustaría saber el número de muertes de corona virus en florida intent:latest_numbers 2 | dime cuántos casos confirmados de covid hay en méxico intent:latest_numbers 3 | estadísticas de pacientes recuperados a fecha de hoy en el mundo intent:latest_numbers 4 | me facilitás la cantidad de recuperados de corona virus en mi ciudad intent:latest_numbers 5 | cuántos casos confirmados de covid19 hay en el mundo intent:latest_numbers 6 | cuáles son los países con más muertos por covid19 intent:latest_numbers 7 | estadísticas mundiales sobre el coronavirus intent:latest_numbers 8 | dígame cuántos casos hay actualmente en cuba intent:latest_numbers 9 | cuántos contagiados hay en argentina intent:latest_numbers 10 | cómo puedo proteger a los menores frente al corona virus intent:protect_yourself 11 | me dices, por favor, qué debo hacer para no contagiarme de corona virus intent:protect_yourself 12 | si puedes dime la mejor forma de prevenir contagios en los menores intent:protect_yourself 13 | qué puedo usar si no tengo cubrebocas en mi casa intent:protect_yourself 14 | cómo puedo prevenir el contagio de corona virus intent:protect_yourself 15 | cómo evito que el virus entre a mi casa si salgo a comprar comida intent:protect_yourself 16 | cuánta lejía tengo que usar para limpiar los productos del covid19 intent:protect_yourself 17 | medidas de cuidado para que adultos mayores no se contagien de covid19 intent:protect_yourself 18 | cómo puedo hacer para protegerme del coronavirus intent:protect_yourself 19 | a cuántos metros puedo estar de otra persona sin que me contagie intent:protect_yourself 20 | qué prevenciones deben tomar los adultos mayores frente al covid-19 intent:protect_yourself 21 | cómo hago para no atrapar el corona virus intent:protect_yourself 22 | no me quiero enfermar. 
¿qué puedo hacer intent:protect_yourself 23 | cómo evito contagiarme del covid 19 intent:protect_yourself 24 | es necesario desinfectar los calzados al volver de lugares públicos intent:protect_yourself 25 | quiero aprender más del coronavirus intent:what_is_corona 26 | dime todo del coronavirus intent:what_is_corona 27 | qué tan mortal es el covid-19 intent:what_is_corona 28 | muéstrame información del coronavirus, por favor intent:what_is_corona 29 | me dices la diferencia entre coronavirus y sars intent:what_is_corona 30 | qué enfermedades están relacionadas con el covid intent:what_is_corona 31 | hay otro virus que se llame corona virus intent:what_is_corona 32 | cuál es la diferencia entre el corona virus y el de la influenza intent:what_is_corona 33 | me explicás qué es exactamente el covid 19 intent:what_is_corona 34 | en qué consiste el virus del covid19 intent:what_is_corona 35 | dame datos básicos del coronavirus intent:what_is_corona 36 | quiero que me des información científica sobre el coronavirus intent:what_is_corona 37 | dame una definición de covid 19 intent:what_is_corona 38 | define corona virus intent:what_is_corona 39 | qué tal si me das un resumen sobre el corona virus intent:what_is_corona 40 | estoy interesada en saber cuáles son los síntomas del corona virus, por favor intent:what_are_symptoms 41 | tengo mucha necesidad por conocer cuáles son los principales sintomas del corona virus intent:what_are_symptoms 42 | crees que el dolor de garganta es un síntoma de corona virus intent:what_are_symptoms 43 | como puedo saber si tengo corona virus intent:what_are_symptoms 44 | cuáles son los principales síntomas del corona virus intent:what_are_symptoms 45 | cómo afecta la covid al cuerpo intent:what_are_symptoms 46 | qué ataca el covid intent:what_are_symptoms 47 | diferencias entre la influenza y el covid intent:what_are_symptoms 48 | cuáles son los síntomas de alguien que tiene coronavirus intent:what_are_symptoms 49 | es posible que sólo tenga síntomas estomacales y sea corona virus intent:what_are_symptoms 50 | cuál es el síntoma más evidente que indica corona virus intent:what_are_symptoms 51 | dime si tener diarrea es un síntoma suficiente para hacerse un examen intent:what_are_symptoms 52 | necesito una lista de síntomas frecuentes de corona virus intent:what_are_symptoms 53 | afectaciones de alguien con coronavirus intent:what_are_symptoms 54 | cómo saber si estoy infectada de coronavirus intent:what_are_symptoms 55 | lista de síntomas del coronavirus intent:what_are_symptoms 56 | cómo se contagia el corona virus intent:how_does_corona_spread 57 | búscame información de cómo se contagia el corona virus, por favor intent:how_does_corona_spread 58 | cuáles son los factores de riesgo del covid intent:how_does_corona_spread 59 | puedo contagiarme por caminar en la calle intent:how_does_corona_spread 60 | por tocar una superficie me puedo contagiar de covid19 intent:how_does_corona_spread 61 | el coronavirus, ¿es una enfermedad de transmición sexual intent:how_does_corona_spread 62 | el virus le puede dar a uno por contacto de piel intent:how_does_corona_spread 63 | quisiera informacion sobre si los animales pueden contagiar el corona virus intent:can_i_get_from_feces_animal_pets 64 | realmente necesito saber si mediante los animales podemos contagiarnos de corona virus intent:can_i_get_from_feces_animal_pets 65 | el coronavirus se transmite entre especies intent:can_i_get_from_feces_animal_pets 66 | le puede dar coronavirus a mi mascota 
intent:can_i_get_from_feces_animal_pets 67 | a mis gatos les puede dar coronavirus intent:can_i_get_from_feces_animal_pets 68 | si me como un murciélago tendré coronavirus intent:can_i_get_from_feces_animal_pets 69 | las palomas propagan el corona virus intent:can_i_get_from_feces_animal_pets 70 | por qué hay tigres con corona virus intent:can_i_get_from_feces_animal_pets 71 | necesito proteger a mi perro del virus intent:can_i_get_from_feces_animal_pets 72 | es cierto que el covid19 viene de los murciélagos intent:can_i_get_from_feces_animal_pets 73 | los cerdos se pueden enfermar del virus intent:can_i_get_from_feces_animal_pets 74 | puede un perro darle corona virus a un gato intent:can_i_get_from_feces_animal_pets 75 | montar un caballo puede pasarme el virus intent:can_i_get_from_feces_animal_pets 76 | tienes información de cuánto tiempo dura el corona virus en el cartón intent:can_i_get_from_packages_surfaces 77 | quiero que me digas si el corona virus sobrevive mucho tiempo en el papel intent:can_i_get_from_packages_surfaces 78 | el coronavirus aguanta el calor intent:can_i_get_from_packages_surfaces 79 | el coronavirus puede quedarse en los muebles intent:can_i_get_from_packages_surfaces 80 | cuánto tiempo vive el virus en la madera intent:can_i_get_from_packages_surfaces 81 | cuánto tiempo vive el virus en la mesada de mi cocina intent:can_i_get_from_packages_surfaces 82 | dime si las superficies del interior de un vehículo pueden propagar el virus intent:can_i_get_from_packages_surfaces 83 | qué puede pasarme si visito una zona de riesgo por corona virus intent:what_if_i_visited_high_risk_area 84 | reglas para visitar zonas de riesgo de covid intent:what_if_i_visited_high_risk_area 85 | pude contraer coronavirus si estuve en una zona de riesgo intent:what_if_i_visited_high_risk_area 86 | cuáles son las normativas que debo seguir si volví de una zona de riesgo intent:what_if_i_visited_high_risk_area 87 | que debe tener en cuenta alguien que estuvo hace poco en españa intent:what_if_i_visited_high_risk_area 88 | españa es zona de alto riesgo de covid 19 intent:what_if_i_visited_high_risk_area 89 | qué pasa si he estado de viaje en roma intent:what_if_i_visited_high_risk_area 90 | cuál es el protocolo si estuve en tránsito en zona de alto riesgo intent:what_if_i_visited_high_risk_area 91 | hay algún tratamiento eficaz contra el corona virus intent:what_are_treatment_options 92 | cómo dar tratamiento a un infectado de covid intent:what_are_treatment_options 93 | cómo tratan los hospitales a los enfermos del covid intent:what_are_treatment_options 94 | si me da coronavirus, ¿cómo me voy a curar intent:what_are_treatment_options 95 | cómo están curando a los pacientes intent:what_are_treatment_options 96 | puedo tomarme un té para elevar mis defensas contra el coronavirus intent:what_are_treatment_options 97 | la vacuna de la influenza sirve para el coronavirus intent:what_are_treatment_options 98 | sirve la hidroxicloroquina para matar el coronavirus intent:what_are_treatment_options 99 | es efectiva la vacuna contra la gripe para lidiar con el corona virus intent:what_are_treatment_options 100 | qué información sobre covid resultó ser mito intent:myths 101 | es cierto que el coronavirus es un invento del gobierno para matar a los viejitos intent:myths 102 | quiero saber cuáles son los mitos más importantes sobre el virus y su desmentida intent:myths 103 | decime el top tres de mitos sobre el covid-19 intent:myths 104 | los anteojos de sol, ¿ayudan a evitar el contagio del virus 
intent:myths 105 | necesito saber qué información es ficticia sobre el coronavirus intent:myths 106 | que noticias sobre el coronavirus son mitos intent:myths 107 | te es posible decirme si los servicios de tren están cancelados en italia intent:travel 108 | medidas de seguridad para viajeros intent:travel 109 | cuándo la gente va a poder viajar otra vez intent:travel 110 | en septiembre ya habrá viajes a europa intent:travel 111 | nombrar las recomendaciones para ir al aeropuerto intent:travel 112 | están funcionando los taxis estos días intent:travel 113 | qué tengo que hacer para viajar en avión de forma segura en una pandemia intent:travel 114 | están más baratos los vuelos a causa del covid19 intent:travel 115 | el subterráneo sigue funcionando intent:travel 116 | tienes consejos para viajar durante la epidemia intent:travel 117 | es buena idea pedir un lyft ahora que existe el virus intent:travel 118 | cómo va el coronavirus? ¿qué dicen las noticias intent:news_and_press 119 | cuáles son las noticias destacadas de hoy en el mundo sobre covid intent:news_and_press 120 | dame los titulares de las principales noticias sobre coronavirus intent:news_and_press 121 | cuáles fueron las noticias sobre el coronavirus hoy en argentina intent:news_and_press 122 | muéstreme las últimas noticias sobre el virus en colombia intent:news_and_press 123 | quiero tener actualizaciones diarias sobre el virus en canadá intent:news_and_press 124 | dame info sobre el virus en australia intent:news_and_press 125 | comparte en mi muro intent:share 126 | dale like a este comentario y compártelo intent:share 127 | podrías compartir esto que me acabás de decir intent:share 128 | compartamos esto en mis redes intent:share 129 | necesito que todos sepan lo que acabas de contarme, compartámoslo intent:share 130 | lo que me estás contando, compártelo intent:share 131 | usa mis redes sociales para compartir este servicio intent:share 132 | compartir la noticia con mis amigos intent:share 133 | puedes compartirle a mis familiares esta información intent:share 134 | cómo se hace para que todos en mis redes tengan este servicio intent:share 135 | compartir el vínculo intent:share 136 | compártale el asistente de corona virus a mi familia intent:share 137 | buenas noches, compadre intent:hi 138 | cómo está mi informante intent:hi 139 | qué más, coronin intent:hi 140 | cómo va eso, corona intent:hi 141 | hola, asistente corona virus intent:hi 142 | ¡qué tal intent:hi 143 | cómo dice que le va, doctor digital intent:hi 144 | ¡hola, corona intent:hi 145 | millones de gracias por la información intent:okay_thanks 146 | excelente, gracias intent:okay_thanks 147 | me queda claro, muchas gracias intent:okay_thanks 148 | gracias por toda la información intent:okay_thanks 149 | oye, doctor, que sepas que agradezco la información intent:okay_thanks 150 | sabes que siento agradecimiento hacia ti intent:okay_thanks 151 | estoy muy agradecida por los datos que me diste intent:okay_thanks 152 | gracias por decirme, muy útil intent:okay_thanks 153 | un profundo agradecimiento intent:okay_thanks 154 | dame información de cómo puedo hacer una donación intent:donate 155 | quisiera realizar una donación intent:donate 156 | por favor, aplica una donación de 1000 pesos a esta propuesta intent:donate 157 | quiero donar elementos de limpieza al pueblo de mi mamá intent:donate 158 | querría que hicieras una donación a la provincia de buenos aires, por favor intent:donate 159 | me ayudarías a donar parte de mis ahorros intent:donate 160 | ayúdame 
a enviar dinero a san juan intent:donate 161 | me gustaría regalarles dinero a los más afectados por el covid intent:donate 162 | -------------------------------------------------------------------------------- /data/mcid/original/fr/eval.tsv: -------------------------------------------------------------------------------- 1 | combien d'enfants infectés par coronavirus en europe intent:latest_numbers 2 | combien de gens sont hospitalisés en russie intent:latest_numbers 3 | derniers chiffres coronavirus intent:latest_numbers 4 | pourcentage de personnes guéries ontario canada intent:latest_numbers 5 | statistiques covid-19 les plus récentes par pays intent:latest_numbers 6 | nouveaux cas de covid-19 cette semaine en chine intent:latest_numbers 7 | statistiques du covid-19 par état aux états-unis intent:latest_numbers 8 | évolution nombre d'hospitalisations france intent:latest_numbers 9 | dernières statistiques d'hospitalisations au canada intent:latest_numbers 10 | quel est le nombre de personnes infectées à travers le monde intent:latest_numbers 11 | masque pour enfant intent:protect_yourself 12 | combien de temps dois-je garder un masque intent:protect_yourself 13 | puis-je réutiliser un masque intent:protect_yourself 14 | les masques en tissus sont-ils efficaces intent:protect_yourself 15 | a combien de dégrés laver les habits pour tuer le virus intent:protect_yourself 16 | secourir une personne en période de distanciation sociale intent:protect_yourself 17 | traitements préventifs intent:protect_yourself 18 | méthodes de protection covid-19 intent:protect_yourself 19 | merci encore une fois intent:okay_thanks 20 | je te remercie énormement doc covid intent:okay_thanks 21 | merci beaucoup docteur bot intent:okay_thanks 22 | tu es très efficace, merci intent:okay_thanks 23 | merci pour les renseignements intent:okay_thanks 24 | merci pour l'explication intent:okay_thanks 25 | parfait merci bien intent:okay_thanks 26 | je te remercie pour toutes ces infos intent:okay_thanks 27 | mille mercis pour ces infos intent:okay_thanks 28 | merci énormément mon ami intent:okay_thanks 29 | cimer le doc intent:okay_thanks 30 | je te remercie intent:okay_thanks 31 | merci infiniment intent:okay_thanks 32 | dis-donc, t'as réponse à tout toi intent:okay_thanks 33 | merci de ton aide intent:okay_thanks 34 | merci je te revaudrai ça intent:okay_thanks 35 | don pour les victimes du coronavirus au kenya intent:donate 36 | faire un don pour lutter contre le covid-19 intent:donate 37 | comment faire un don pour les personnes touchées intent:donate 38 | urgence sanitaire, comment aider financièrement intent:donate 39 | faire une contribution pour la corée intent:donate 40 | je veux verser une somme pour soutenir la recherche sur le coronavirus intent:donate 41 | aide financière coronavirus intent:donate 42 | donne à la fondation louis pasteur intent:donate 43 | pays origine coronavirus intent:what_is_corona 44 | début du covid intent:what_is_corona 45 | c’est une maladie respiratoire intent:what_is_corona 46 | que veut dire covid-19 intent:what_is_corona 47 | c'est quoi le nouveau coronavirus intent:what_is_corona 48 | différence coronavirus grippe intent:what_is_corona 49 | déclaration pandémie oms covid-19 intent:what_is_corona 50 | peut-on guérir du coronavirus intent:what_is_corona 51 | test covid-19 france intent:what_is_corona 52 | quel traitement contre le coronavirus intent:what_is_corona 53 | puis-je avoir des renseignements sur le covid19 intent:what_is_corona 54 | le coronavirus est-il plus mortel 
que la grippe saisonnière intent:what_is_corona 55 | c'est quoi les premiers symptômes lors de l'infection intent:what_are_symptoms 56 | peut-on tomber malade du coronavirus une deuxième fois intent:what_are_symptoms 57 | quand apparaissent les symptômes intent:what_are_symptoms 58 | est-ce que j'ai le coronavirus intent:what_are_symptoms 59 | pas de symptômes corona intent:what_are_symptoms 60 | quelles différences entre les symptômes de la grippe et du coronavirus intent:what_are_symptoms 61 | symptômes les plus fréquents intent:what_are_symptoms 62 | puis-je m'autodiagnostiquer du covid19 intent:what_are_symptoms 63 | peut-on attraper le virus en jouant dans le sable intent:how_does_corona_spread 64 | comment les enfants partagent le virus intent:how_does_corona_spread 65 | dis-moi comment se transmet le covid-19 intent:how_does_corona_spread 66 | le coronavirus peut-il se propager via la nourriture intent:how_does_corona_spread 67 | qu'est-ce que la transmission communautaire intent:how_does_corona_spread 68 | informations sur la propagation du coronavirus intent:how_does_corona_spread 69 | propagation du virus dans le milieu médical intent:how_does_corona_spread 70 | que veut dire la charge virale du virus intent:how_does_corona_spread 71 | quels sont les moyens de propagation confirmés pour le covid-19 intent:how_does_corona_spread 72 | le covid-19 se transmet-il par la transpiration intent:how_does_corona_spread 73 | est-ce qu’il faut tester mon chien intent:can_i_get_from_feces_animal_pets 74 | chien et chat covid 19 intent:can_i_get_from_feces_animal_pets 75 | dois-je me laver les mains après avoir joué avec mon chien intent:can_i_get_from_feces_animal_pets 76 | contamination au covid par voie alimentaire intent:can_i_get_from_feces_animal_pets 77 | quels animaux peuvent contaminer un homme du coronavirus intent:can_i_get_from_feces_animal_pets 78 | est-ce que les lapins transmettent le coronavirus intent:can_i_get_from_feces_animal_pets 79 | quel est le risque d'attraper le covid-19 par un animal de compagnie intent:can_i_get_from_feces_animal_pets 80 | durée de vie coronavirus sur le béton intent:can_i_get_from_packages_surfaces 81 | durée de vie du coronavirus sur plastique intent:can_i_get_from_packages_surfaces 82 | coronavirus, durée vie, verre intent:can_i_get_from_packages_surfaces 83 | combien de temps le coronavirus reste sur les objets intent:can_i_get_from_packages_surfaces 84 | survie du corona sur les vêtements intent:can_i_get_from_packages_surfaces 85 | temps de survie corona boîtes en carton intent:can_i_get_from_packages_surfaces 86 | le virus survit-il sur les plantes intent:can_i_get_from_packages_surfaces 87 | le covid-19 survit-il au frigo intent:can_i_get_from_packages_surfaces 88 | est-il conseillé de désinfecter mes courses avant de les ranger intent:can_i_get_from_packages_surfaces 89 | combien de temps le coronavirus survit-il sur une matière non organique intent:can_i_get_from_packages_surfaces 90 | le coronavirus survit-il sur des surfaces très chaudes intent:can_i_get_from_packages_surfaces 91 | nombre de pays à haut risque intent:what_if_i_visited_high_risk_area 92 | coronavirus endroits à éviter intent:what_if_i_visited_high_risk_area 93 | la chine est-elle encore une zone à haut risque intent:what_if_i_visited_high_risk_area 94 | puis-je retourner sans risques de france à wuhan après le déconfinement là-bas intent:what_if_i_visited_high_risk_area 95 | quels sont les procédures de confinement pour les citoyens français rapatriés de zones à risque 
comme la ville de new-york intent:what_if_i_visited_high_risk_area 96 | où sont les zones covid 19 à haut risques intent:what_if_i_visited_high_risk_area 97 | quelles villes sont les plus risquées en france pour le covid 19 intent:what_if_i_visited_high_risk_area 98 | où est -il interdit d'aller car trop de risque à cause du coronavirus intent:what_if_i_visited_high_risk_area 99 | remède maison coronavirus intent:what_are_treatment_options 100 | aliments pouvant combattre le coronavirus intent:what_are_treatment_options 101 | faire du sport peut-il aider à guérir le coronavirus intent:what_are_treatment_options 102 | l'ibuprofène est-il recommandé pour traiter la fièvre si on est atteint par le covid-19 intent:what_are_treatment_options 103 | les remèdes naturels guérissent-ils le virus intent:what_are_treatment_options 104 | où en sont les recherches de vaccin contre le coronavirus intent:what_are_treatment_options 105 | est-ce la chloroquine utilisée pour traiter les le coronavirus intent:what_are_treatment_options 106 | exercices yoga toux douloureuse intent:what_are_treatment_options 107 | combien de temps de repos après infection au coronavirus intent:what_are_treatment_options 108 | quelles sont mes options pour traiter une infection virale intent:what_are_treatment_options 109 | immunostimulants à essayer pour se proteger du covid-19 intent:what_are_treatment_options 110 | covid-19 et aspirine intent:what_are_treatment_options 111 | quels médicaments peut-on prendre quand on a le coronavirus intent:what_are_treatment_options 112 | est-ce que je peux prendre un anti-migraineux alors que je suis contaminée intent:what_are_treatment_options 113 | quand peut-on espérer avoir un vaccin intent:what_are_treatment_options 114 | quels médicaments efficaces contre corona intent:what_are_treatment_options 115 | est-ce que les huiles essentielles sont efficaces contre le coronavirus intent:what_are_treatment_options 116 | antibiotiques les plus efficaces contre corona intent:what_are_treatment_options 117 | est-ce un mythe que le virus est faible dans les pays chaud intent:myths 118 | le coronavirus est-il réellement créé dans le labo p4 intent:myths 119 | manger de l'ail empêche la contamination intent:myths 120 | le pangolin est-il un mythe dans l'incubation du virus intent:myths 121 | quels sont les mythes sur la propagation du virus en inde intent:myths 122 | il faut boire beaucoup … d’alcool – il tue le coronavirus intent:myths 123 | le covid 19 provient des chauves-souris intent:myths 124 | des états totalitaires ont importé et ont fait circuler de manière délibérée le coronavirus dans leur pays pour se débarrasser de l'opposition intent:myths 125 | mythe coronavirus a été créé dans un laboratoire intent:myths 126 | corona pangolin ou chauve-souris intent:myths 127 | puis-je prendre le métro à toronto intent:travel 128 | comment faire rembourser mon trajet vers une région à risque ? 
paris intent:travel 129 | conseils de sécurité pour les trajets professionnels intent:travel 130 | quel est le moyen de transport à éviter durant l'épidémie de covid-19 intent:travel 131 | puis-je quitter ma ville pendant l'épidémie de covid 19 intent:travel 132 | risque contamination voyage métro intent:travel 133 | est-ce que je peux aller en chine demain intent:travel 134 | voyager en europe cet été intent:travel 135 | de combien dois-je repousser mon voyage intent:travel 136 | remboursement vol annulé corona intent:travel 137 | peux-tu me communiquer les dernières informations disponibles au sujet du covid19 intent:news_and_press 138 | dernières mises à jour corona en chine intent:news_and_press 139 | peux-tu me donner l'actualité sur le covid19 intent:news_and_press 140 | informations de dernière minute sur le coronavirus intent:news_and_press 141 | derniers communiqués de presse concernant le coronavirus intent:news_and_press 142 | donne-moi les infos les plus récentes sur l'épidémie de corona en asie intent:news_and_press 143 | peux-tu me trouver les dernières informations sur le coronavirus aux etats-unis intent:news_and_press 144 | fais-moi voir les infos du covid d’hier intent:news_and_press 145 | news d'aujourd'hui sur le corona virus intent:news_and_press 146 | est-ce qu’il y a du nouveau sur le virus intent:news_and_press 147 | que s’est-il passé ces dernières heures intent:news_and_press 148 | partage ce lien sur snap intent:share 149 | partager ce service sur fb intent:share 150 | partager ces infos sur insta intent:share 151 | peux-tu publier ta réponse sur mon facebook intent:share 152 | peux-tu transmettre cela à mon père s'il te plaît intent:share 153 | comment faire pour partager tes réponses avec mes amis sur twitter intent:share 154 | lance-toi sur les réseaux sociaux intent:share 155 | comment te partager avec mes amis intent:share 156 | partage ton lien sur instagram intent:share 157 | partage les services de ce bot avec mes contacts intent:share 158 | il faut mettre cet article corona sur mon mur facebook intent:share 159 | coucou docteur intent:hi 160 | bonjour médecin covid intent:hi 161 | coucou corona intent:hi 162 | salut, comment ça va bot intent:hi 163 | tu vas bien docteur covid intent:hi 164 | comment vas-tu doc intent:hi 165 | hey monsieur info corona intent:hi 166 | bonjour infocorona intent:hi 167 | hey virus info intent:hi 168 | y'a t'il un docteur ici intent:hi 169 | salut spécialiste du corona intent:hi 170 | salut mon pote le bot intent:hi 171 | ave, corona, ceux qui vont mourir te saluent intent:hi 172 | hey médecin coronavirus, comment vas tu intent:hi 173 | ça va doc intent:hi 174 | -------------------------------------------------------------------------------- /data/mcid/showDataset.py: -------------------------------------------------------------------------------- 1 | # this script extracts a sub word embedding from the entire one 2 | from gensim.models.keyedvectors import KeyedVectors 3 | import numpy as np 4 | import scipy.io as sio 5 | import nltk 6 | import pdb 7 | import random 8 | import csv 9 | import string 10 | import contractions 11 | import json 12 | import random 13 | 14 | def getDomainIntent(domainLabFile): 15 | domain2lab = {} 16 | lab2domain = {} 17 | currentDomain = None 18 | with open(domainLabFile,'r') as f: 19 | for line in f: 20 | if ':' in line and currentDomain == None: 21 | currentDomain = cleanUpSentence(line) 22 | domain2lab[currentDomain] = [] 23 | elif line == "\n": 24 | currentDomain = None 25 | else: 26 | intent = 
cleanUpSentence(line) 27 | domain2lab[currentDomain].append(intent) 28 | 29 | for key in domain2lab: 30 | domain = key 31 | labList = domain2lab[key] 32 | for lab in labList: 33 | lab2domain[lab] = domain 34 | 35 | return domain2lab, lab2domain 36 | 37 | def cleanUpSentence(sentence): 38 | # sentence: a string, like " Hello, do you like apple? I hate it!! " 39 | 40 | # strip 41 | sentence = sentence.strip() 42 | 43 | # lower case 44 | sentence = sentence.lower() 45 | 46 | # fix contractions 47 | sentence = contractions.fix(sentence) 48 | 49 | # remove '_' and '-' 50 | sentence = sentence.replace('-',' ') 51 | sentence = sentence.replace('_',' ') 52 | 53 | # remove all punctuations 54 | sentence = ''.join(ch for ch in sentence if ch not in string.punctuation) 55 | 56 | return sentence 57 | 58 | 59 | def check_data_format(file_path): 60 | for line in open(file_path,'rb'): 61 | arr =str(line.strip(),'utf-8') 62 | arr = arr.split('\t') 63 | label = [w for w in arr[0].split(' ')] 64 | question = [w for w in arr[1].split(' ')] 65 | 66 | if len(label) == 0 or len(question) == 0: 67 | print("[ERROR] Find empty data: ", label, question) 68 | return False 69 | 70 | return True 71 | 72 | 73 | def save_data(data, file_path): 74 | # save data to disk 75 | with open(file_path, 'w') as f: 76 | json.dump(data, f) 77 | return 78 | 79 | def save_domain_intent(data, file_path): 80 | domain2intent = {} 81 | for line in data: 82 | domain = line[0] 83 | intent = line[1] 84 | 85 | if not domain in domain2intent: 86 | domain2intent[domain] = set() 87 | 88 | domain2intent[domain].add(intent) 89 | 90 | # save data to disk 91 | print("Saving domain intent out ... format: domain \t intent") 92 | with open(file_path,"w") as f: 93 | for domain in domain2intent: 94 | intentSet = domain2intent[domain] 95 | for intent in intentSet: 96 | f.write("%s\t%s\n" % (domain, intent)) 97 | return 98 | 99 | def display_data(data): 100 | # dataset count 101 | print("[INFO] We have %d dataset."%(len(data))) 102 | 103 | datasetName = 'MCID' 104 | data = data[datasetName] 105 | 106 | # domain count 107 | domainName = set() 108 | for domain in data: 109 | domainName.add(domain) 110 | print("[INFO] There are %d domains."%(len(domainName))) 111 | print(domainName) 112 | 113 | # intent count 114 | intentName = set() 115 | for domain in data: 116 | for d in data[domain]: 117 | lab = d[1][0] 118 | intentName.add(lab) 119 | intentName = list(intentName) 120 | intentName.sort() 121 | print("[INFO] There are %d intent."%(len(intentName))) 122 | print(intentName) 123 | 124 | # data count 125 | count = 0 126 | for domain in data: 127 | for d in data[domain]: 128 | count = count+1 129 | print("[INFO] Data count: %d"%(count)) 130 | 131 | # intent for each domain 132 | domain2intentDict = {} 133 | for domain in data: 134 | if not domain in domain2intentDict: 135 | domain2intentDict[domain] = set() 136 | 137 | for d in data[domain]: 138 | lab = d[1][0] 139 | domain2intentDict[domain].add(lab) 140 | print("[INFO] Intent for each domain.") 141 | print(domain2intentDict) 142 | 143 | # data for each intent 144 | intent2count = {} 145 | for domain in data: 146 | for d in data[domain]: 147 | lab = d[1][0] 148 | if not lab in intent2count: 149 | intent2count[lab] = 0 150 | intent2count[lab] = intent2count[lab]+1 151 | print("[INFO] Intent count") 152 | print(intent2count) 153 | 154 | # examples of data 155 | exampleNum = 3 156 | while not exampleNum == 0: 157 | for domain in data: 158 | for d in data[domain]: 159 | lab = d[1] 160 | utt = d[0] 161 | if 
random.random() < 0.001:
162 | print("[INFO] Example:--%s, %s, %s, %s"%(datasetName, domain, lab, utt))
163 | exampleNum = exampleNum-1
164 | break
165 | if (exampleNum==0):
166 | break
167 | 
168 | return None
169 | 
170 | 
171 | ##
172 | # @brief clean up data, including intent and utterance
173 | #
174 | # @param data a list of data
175 | #
176 | # @return
177 | def cleanData(data):
178 | newData = []
179 | for d in data:
180 | utt = d[0]
181 | lab = d[1]
182 | 
183 | uttClr = cleanUpSentence(utt)
184 | labClr = cleanUpSentence(lab)
185 | newData.append([labClr, uttClr])
186 | 
187 | return newData
188 | 
189 | def constructData(data, intent2domain):
190 | dataset2domain = {}
191 | datasetName = 'MCID'
192 | dataset2domain[datasetName] = {}
193 | for d in data:
194 | lab = d[0]
195 | utt = d[1]
196 | domain = intent2domain[lab]
197 | if not domain in dataset2domain[datasetName]:
198 | dataset2domain[datasetName][domain] = []
199 | dataField = [utt, [lab]]
200 | dataset2domain[datasetName][domain].append(dataField)
201 | 
202 | return dataset2domain
203 | 
204 | 
205 | def read_data(file_path):
206 | with open(file_path) as json_file:
207 | data = json.load(json_file)
208 | return data
209 | 
210 | # read in data
211 | #dataPath = "/data1/haode/projects/EMDIntentFewShot/SPIN_refactor/data/refactor_OOS/dataset.json"
212 | dataPath = "./dataset.json"
213 | print("Loading data ...", dataPath)
214 | # read lines, collect data count for different classes
215 | data = read_data(dataPath)
216 | 
217 | display_data(data)
218 | print("Display.. done")
219 | 
-------------------------------------------------------------------------------- /data/oos/original/domain_intent.txt: --------------------------------------------------------------------------------
1 | banking:
2 | transfer
3 | transactions
4 | balance
5 | freeze account
6 | pay bill
7 | bill balance
8 | bill due
9 | interest rate
10 | routing
11 | min payment
12 | order checks
13 | pin change
14 | report fraud
15 | spending history
16 | account blocked
17 | 
18 | credit cards:
19 | credit limit
20 | credit limit change
21 | international fees
22 | expiration date
23 | credit score
24 | replacement card duration
25 | card declined
26 | improve credit score
27 | report lost card
28 | damaged card
29 | redeem rewards
30 | apr
31 | application status
32 | new card
33 | rewards balance
34 | 
35 | kitchen & dining:
36 | recipe
37 | restaurant reservation
38 | restaurant reviews
39 | nutrition info
40 | ingredients list
41 | cancel reservation
42 | how busy
43 | ingredient substitution
44 | confirm reservation
45 | meal suggestion
46 | accept reservations
47 | restaurant suggestion
48 | cook time
49 | calories
50 | food last
51 | 
52 | home:
53 | shopping list
54 | shopping list update
55 | smart home
56 | reminder
57 | reminder update
58 | order
59 | order status
60 | what song
61 | todo list
62 | todo list update
63 | next song
64 | play music
65 | update playlist
66 | calendar
67 | calendar update
68 | 
69 | auto & commute:
70 | traffic
71 | tire change
72 | tire pressure
73 | last maintenance
74 | schedule maintenance
75 | uber
76 | jump start
77 | oil change how
78 | oil change when
79 | mpg
80 | current location
81 | distance
82 | gas
83 | gas type
84 | directions
85 | 
86 | travel:
87 | travel alert
88 | flight status
89 | travel notification
90 | travel suggestion
91 | exchange rate
92 | plug type
93 | lost luggage
94 | international visa
95 | translate
96 | timezone
97 | vaccines
98 | book flight
99 | book hotel
100 | car rental
101 | carry on
102 |
103 | utility: 104 | roll dice 105 | definition 106 | flip coin 107 | measurement conversion 108 | timer 109 | spelling 110 | date 111 | text 112 | weather 113 | make call 114 | time 115 | share location 116 | calculator 117 | alarm 118 | find phone 119 | 120 | work: 121 | direct deposit 122 | pto balance 123 | insurance 124 | insurance change 125 | next holiday 126 | w2 127 | payday 128 | taxes 129 | meeting schedule 130 | pto request 131 | income 132 | rollover 401k 133 | schedule meeting 134 | pto request status 135 | pto used 136 | 137 | small talk: 138 | goodbye 139 | where are you from 140 | greeting 141 | tell joke 142 | how old are you 143 | what is your name 144 | who made you 145 | thank you 146 | what can i ask you 147 | do you have pets 148 | what are your hobbies 149 | are you a bot 150 | meaning of life 151 | who do you work for 152 | fun fact 153 | 154 | meta: 155 | cancel 156 | change accent 157 | change ai name 158 | change language 159 | change speed 160 | change user name 161 | change volume 162 | maybe 163 | no 164 | repeat 165 | reset settings 166 | sync device 167 | user name 168 | whisper mode 169 | yes 170 | -------------------------------------------------------------------------------- /data/oos/original/oos-eval-master/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.json~ 3 | *.csv~ 4 | .DS_Store 5 | -------------------------------------------------------------------------------- /data/oos/original/oos-eval-master/README.md: -------------------------------------------------------------------------------- 1 | [![Clinc](clinc_logo.png)](https://clinc.com) 2 | 3 | # An Evaluation Dataset for Intent Classification and Out-of-Scope Prediction 4 | Repository that accompanies [An Evaluation Dataset for Intent Classification and Out-of-Scope Prediction](https://www.aclweb.org/anthology/D19-1131/). 5 | 6 | 7 | ## FAQs 8 | ### 1. What are the relevant files? 9 | See `data/data_full.json` for the "full" dataset. This is the dataset used in Table 1 (the "Full" columns). This file contains 150 "in-scope" intent classes, each with 100 train, 20 validation, and 30 test samples. There are 100 train and validation out-of-scope samples, and 1000 out-of-scope test samples. 10 | 11 | ### 2. What is the name of the dataset? 12 | The dataset was not given a name in the original paper, but [others](https://arxiv.org/pdf/2003.04807.pdf) have called it `CLINC150`. 13 | 14 | ### 3. What is this dataset for? 15 | This dataset is for evaluating the performance of intent classification systems in the presence of "out-of-scope" queries. By "out-of-scope", we mean queries that do not fall into any of the system-supported intent classes. Most datasets include only data that is "in-scope". Our dataset includes both in-scope and out-of-scope data. You might also know the term "out-of-scope" by other terms, including "out-of-domain" or "out-of-distribution". 16 | 17 | ### 4. What language is the dataset in? 18 | All queries are in English. 19 | 20 | ### 5. How does your dataset/evaluation handle multi-intent queries? 21 | All samples/queries in our dataset are single-intent samples. We consider the problem of multi-intent classification to be future work. 22 | 23 | ### 6. How did you gather the dataset? 24 | We used crowdsourcing to generate the dataset. We asked crowd workers to either paraphrase "seed" phrases, or respond to scenarios (e.g. "pretend you need to book a flight, what would you say?"). 
We used crowdsourcing to generate data for both in-scope and out-of-scope data. 25 | 26 | ## Citation 27 | 28 | If you find our dataset useful, please be sure to cite: 29 | 30 | ``` 31 | @inproceedings{larson-etal-2019-evaluation, 32 | title = "An Evaluation Dataset for Intent Classification and Out-of-Scope Prediction", 33 | author = "Larson, Stefan and 34 | Mahendran, Anish and 35 | Peper, Joseph J. and 36 | Clarke, Christopher and 37 | Lee, Andrew and 38 | Hill, Parker and 39 | Kummerfeld, Jonathan K. and 40 | Leach, Kevin and 41 | Laurenzano, Michael A. and 42 | Tang, Lingjia and 43 | Mars, Jason", 44 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)", 45 | year = "2019", 46 | url = "https://www.aclweb.org/anthology/D19-1131" 47 | } 48 | ``` 49 | -------------------------------------------------------------------------------- /data/oos/original/oos-eval-master/clinc_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/oos/original/oos-eval-master/clinc_logo.png -------------------------------------------------------------------------------- /data/oos/original/oos-eval-master/hyperparameters.csv: -------------------------------------------------------------------------------- 1 | ,,,,,,,,,,, 2 | SVM-OH Hyperparameters,,,,,,,,,,, 3 | ,Full,classifier,vectorizer,kernel,C,,,,,, 4 | ,,svm,onehot,linear,1,,,,,, 5 | ,,,,,,,,,,, 6 | ,Small,classifier,vectorizer,kernel,C,,,,,, 7 | ,,svm,onehot,linear,1,,,,,, 8 | ,,,,,,,,,,, 9 | ,Imbalanced,classifier,vectorizer,kernel,C,,,,,, 10 | ,,svm,onehot,linear,1,,,,,, 11 | ,,,,,,,,,,, 12 | ,OOS+,classifier,vectorizer,kernel,C,,,,,, 13 | ,,svm,onehot,linear,1,,,,,, 14 | ,,,,,,,,,,, 15 | ,UnderSample Binary,classifier,vectorizer,kernel,C,,,,,, 16 | ,,svm,onehot,linear,1,,,,,, 17 | ,,,,,,,,,,, 18 | ,Wiki Aug Binary,classifier,vectorizer,kernel,C,,,,,, 19 | ,,svm,onehot,linear,1,,,,,, 20 | ,,,,,,,,,,, 21 | ,,,,,,,,,,, 22 | FastText Hyperparameters,,,,,,,,,,, 23 | ,Full,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,, 24 | ,,fasttext,1,100,3,100,3,ns,,, 25 | ,,,,,,,,,,, 26 | ,Small,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,, 27 | ,,fasttext,0.1,200,1,200,7,softmax,,, 28 | ,,,,,,,,,,, 29 | ,Imbalanced,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,, 30 | ,,fasttext,1,400,2,200,7,ns,,, 31 | ,,,,,,,,,,, 32 | ,OOS+,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,, 33 | ,,fasttext,1,50,2,200,7,ns,,, 34 | ,,,,,,,,,,, 35 | ,UnderSample Binary,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,, 36 | ,,fasttext,0.1,200,4,100,3,ns,,, 37 | ,,,,,,,,,,, 38 | ,Wiki Aug Binary,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,, 39 | ,,fasttext,1,400,4,100,3,ns,,, 40 | ,,,,,,,,,,, 41 | ,,,,,,,,,,, 42 | CNN Hyperparameters,,,,,,,,,,, 43 | ,Full,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size 44 | ,,cnn,200,softmax,relu,1,valid,300,0.6,softmax,16 45 | ,,,,,,,,,,, 46 | ,Small,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size 47 | ,,cnn,150,softmax,relu,1,valid,200,0.6,softmax,16 48 | ,,,,,,,,,,, 49 | 
,Imbalanced,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size
50 | ,,cnn,150,softmax,relu,1,valid,300,0.6,softmax,16
51 | ,,,,,,,,,,,
52 | ,OOS+,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size
53 | ,,cnn,200,softmax,relu,1,valid,200,0.6,softmax,16
54 | ,,,,,,,,,,,
55 | ,UnderSample Binary,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size
56 | ,,cnn,150,softmax,relu,1,valid,100,0.2,softmax,16
57 | ,,,,,,,,,,,
58 | ,Wiki Aug Binary,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size
59 | ,,cnn,200,softmax,relu,1,valid,300,0.2,softmax,16
60 | ,,,,,,,,,,,
61 | ,,,,,,,,,,,
62 | MLP Hyperparameters,,,,,,,,,,,
63 | ,Full,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
64 | ,,mlp,tanh,softmax,400,use,1,0,,,
65 | ,,,,,,,,,,,
66 | ,Small,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
67 | ,,mlp,tanh,softmax,200,use,1,0.1,,,
68 | ,,,,,,,,,,,
69 | ,Imbalanced,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
70 | ,,mlp,tanh,softmax,200,use,64,0,,,
71 | ,,,,,,,,,,,
72 | ,OOS+,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
73 | ,,mlp,tanh,softmax,200,use,16,0.1,,,
74 | ,,,,,,,,,,,
75 | ,UnderSample Binary,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
76 | ,,mlp,tanh,softmax,100,use,64,0,,,
77 | ,,,,,,,,,,,
78 | ,Wiki Aug Binary,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
79 | ,,mlp,tanh,softmax,300,use,16,0,,,
80 | ,,,,,,,,,,,
81 | ,,,,,,,,,,,
82 | BERT Hyperparameters,,,,,,,,,,,
83 | ,Full,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
84 | ,,bert,4.00E-05,0.1,32,5,1,bert-large-uncased,,,
85 | ,,,,,,,,,,,
86 | ,Small,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
87 | ,,bert,4.00E-05,0.1,32,5,1,bert-large-uncased,,,
88 | ,,,,,,,,,,,
89 | ,Imbalanced,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
90 | ,,bert,4.00E-05,0.1,32,5,1,bert-large-uncased,,,
91 | ,,,,,,,,,,,
92 | ,OOS+,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
93 | ,,bert,4.00E-05,0.1,32,5,1,bert-large-uncased,,,
94 | ,,,,,,,,,,,
95 | ,UnderSample Binary,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
96 | ,,bert,2.00E-05,0.1,32,5,1,bert-large-uncased,,,
97 | ,,,,,,,,,,,
98 | ,Wiki Aug Binary,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
99 | ,,bert,4.00E-05,0.1,32,5,1,bert-large-uncased,,,
--------------------------------------------------------------------------------
/data/oos/original/oos-eval-master/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/oos/original/oos-eval-master/paper.pdf
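Note that hyperparameters.csv above is block-structured rather than a flat table: each model family is introduced by a `<Name> Hyperparameters` banner row, followed by one header/value row pair per benchmark (Full, Small, Imbalanced, OOS+, UnderSample Binary, Wiki Aug Binary). Below is a minimal parsing sketch, not part of the original repo; the function name `load_hparam_blocks` and the dropping of empty cells (safe here, since no block has an empty mid-row value) are assumptions of this sketch.

```python
import csv

def load_hparam_blocks(path="data/oos/original/oos-eval-master/hyperparameters.csv"):
    """Parse the block-structured sheet into {section: {benchmark: {column: value}}}."""
    blocks = {}
    section = None
    pending = None  # (benchmark name, column names) from the most recent header row
    with open(path, newline="") as f:
        for row in csv.reader(f):
            cells = [c.strip() for c in row]
            if cells and cells[0].endswith("Hyperparameters"):
                section = cells[0]                    # e.g. "BERT Hyperparameters"
                blocks[section] = {}
            elif section and len(cells) > 2 and cells[1] and cells[2]:
                # header row: benchmark name in column 2, column names from column 3 on
                pending = (cells[1], [c for c in cells[2:] if c])
            elif section and pending and len(cells) > 2 and cells[2]:
                # value row: values line up with the header row just seen
                name, cols = pending
                blocks[section][name] = dict(zip(cols, [c for c in cells[2:] if c]))
                pending = None
    return blocks
```

For instance, `load_hparam_blocks()["BERT Hyperparameters"]["Full"]["learning_rate"]` should return `'4.00E-05'`.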
-------------------------------------------------------------------------------- /data/oos/original/oos-eval-master/poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/oos/original/oos-eval-master/poster.pdf -------------------------------------------------------------------------------- /data/oos/original/oos-eval-master/supplementary.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/oos/original/oos-eval-master/supplementary.pdf -------------------------------------------------------------------------------- /data/oos/showDataset.py: -------------------------------------------------------------------------------- 1 | # this script extracts a sub word embedding from the entire one 2 | from gensim.models.keyedvectors import KeyedVectors 3 | import numpy as np 4 | import scipy.io as sio 5 | import nltk 6 | import pdb 7 | import random 8 | import csv 9 | import string 10 | import contractions 11 | import json 12 | import random 13 | 14 | def getDomainIntent(domainLabFile): 15 | domain2lab = {} 16 | lab2domain = {} 17 | currentDomain = None 18 | with open(domainLabFile,'r') as f: 19 | for line in f: 20 | if ':' in line and currentDomain == None: 21 | currentDomain = cleanUpSentence(line) 22 | domain2lab[currentDomain] = [] 23 | elif line == "\n": 24 | currentDomain = None 25 | else: 26 | intent = cleanUpSentence(line) 27 | domain2lab[currentDomain].append(intent) 28 | 29 | for key in domain2lab: 30 | domain = key 31 | labList = domain2lab[key] 32 | for lab in labList: 33 | lab2domain[lab] = domain 34 | 35 | return domain2lab, lab2domain 36 | 37 | def cleanUpSentence(sentence): 38 | # sentence: a string, like " Hello, do you like apple? I hate it!! " 39 | 40 | # strip 41 | sentence = sentence.strip() 42 | 43 | # lower case 44 | sentence = sentence.lower() 45 | 46 | # fix contractions 47 | sentence = contractions.fix(sentence) 48 | 49 | # remove '_' and '-' 50 | sentence = sentence.replace('-',' ') 51 | sentence = sentence.replace('_',' ') 52 | 53 | # remove all punctuations 54 | sentence = ''.join(ch for ch in sentence if ch not in string.punctuation) 55 | 56 | return sentence 57 | 58 | 59 | def check_data_format(file_path): 60 | for line in open(file_path,'rb'): 61 | arr =str(line.strip(),'utf-8') 62 | arr = arr.split('\t') 63 | label = [w for w in arr[0].split(' ')] 64 | question = [w for w in arr[1].split(' ')] 65 | 66 | if len(label) == 0 or len(question) == 0: 67 | print("[ERROR] Find empty data: ", label, question) 68 | return False 69 | 70 | return True 71 | 72 | 73 | def save_data(data, file_path): 74 | # save data to disk 75 | with open(file_path, 'w') as f: 76 | json.dump(data, f) 77 | return 78 | 79 | def save_domain_intent(data, file_path): 80 | domain2intent = {} 81 | for line in data: 82 | domain = line[0] 83 | intent = line[1] 84 | 85 | if not domain in domain2intent: 86 | domain2intent[domain] = set() 87 | 88 | domain2intent[domain].add(intent) 89 | 90 | # save data to disk 91 | print("Saving domain intent out ... 
format: domain \t intent") 92 | with open(file_path,"w") as f: 93 | for domain in domain2intent: 94 | intentSet = domain2intent[domain] 95 | for intent in intentSet: 96 | f.write("%s\t%s\n" % (domain, intent)) 97 | return 98 | 99 | def display_data(data): 100 | # dataset count 101 | print("[INFO] We have %d dataset."%(len(data))) 102 | 103 | datasetName = 'CLINC150' 104 | data = data[datasetName] 105 | 106 | # domain count 107 | domainName = set() 108 | for domain in data: 109 | domainName.add(domain) 110 | print("[INFO] There are %d domains."%(len(domainName))) 111 | print(domainName) 112 | 113 | # intent count 114 | intentName = set() 115 | for domain in data: 116 | for d in data[domain]: 117 | lab = d[1][0] 118 | intentName.add(lab) 119 | intentName = list(intentName) 120 | intentName.sort() 121 | print("[INFO] There are %d intent."%(len(intentName))) 122 | print(intentName) 123 | 124 | # data count 125 | count = 0 126 | for domain in data: 127 | for d in data[domain]: 128 | count = count+1 129 | print("[INFO] Data count: %d"%(count)) 130 | 131 | # intent for each domain 132 | domain2intentDict = {} 133 | for domain in data: 134 | if not domain in domain2intentDict: 135 | domain2intentDict[domain] = set() 136 | 137 | for d in data[domain]: 138 | lab = d[1][0] 139 | domain2intentDict[domain].add(lab) 140 | print("[INFO] Intent for each domain.") 141 | print(domain2intentDict) 142 | 143 | # data for each intent 144 | intent2count = {} 145 | for domain in data: 146 | for d in data[domain]: 147 | lab = d[1][0] 148 | if not lab in intent2count: 149 | intent2count[lab] = 0 150 | intent2count[lab] = intent2count[lab]+1 151 | print("[INFO] Intent count") 152 | print(intent2count) 153 | 154 | # examples of data 155 | exampleNum = 3 156 | while not exampleNum == 0: 157 | for domain in data: 158 | for d in data[domain]: 159 | lab = d[1] 160 | utt = d[0] 161 | if random.random() < 0.001: 162 | print("[INFO] Example:--%s, %s, %s, %s"%(datasetName, domain, lab, utt)) 163 | exampleNum = exampleNum-1 164 | break 165 | if (exampleNum==0): 166 | break 167 | 168 | return None 169 | 170 | 171 | ## 172 | # @brief clean up data, including intent and utterance 173 | # 174 | # @param data a list of data 175 | # 176 | # @return 177 | def cleanData(data): 178 | newData = [] 179 | for d in data: 180 | utt = d[0] 181 | lab = d[1] 182 | 183 | uttClr = cleanUpSentence(utt) 184 | labClr = cleanUpSentence(lab) 185 | newData.append([labClr, uttClr]) 186 | 187 | return newData 188 | 189 | def constructData(data, intent2domain): 190 | dataset2domain = {} 191 | datasetName = 'CLINC150' 192 | dataset2domain[datasetName] = {} 193 | for d in data: 194 | lab = d[0] 195 | utt = d[1] 196 | domain = intent2domain[lab] 197 | if not domain in dataset2domain[datasetName]: 198 | dataset2domain[datasetName][domain] = [] 199 | dataField = [utt, [lab]] 200 | dataset2domain[datasetName][domain].append(dataField) 201 | 202 | return dataset2domain 203 | 204 | 205 | def read_data(file_path): 206 | with open(file_path) as json_file: 207 | data = json.load(json_file) 208 | return data 209 | 210 | # read in data 211 | #dataPath = "/data1/haode/projects/EMDIntentFewShot/SPIN_refactor/data/refactor_OOS/dataset.json" 212 | dataPath = "./dataset.json" 213 | print("Loading data ...", dataPath) 214 | # read lines, collect data count for different classes 215 | data = read_data(dataPath) 216 | 217 | display_data(data) 218 | print("Display.. 
done") 219 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # This file assembles three popular metric learnign baselines, matching network, prototype network and relation network. 2 | # This file is coded based on train_matchingNet.py. 3 | # coding=utf-8 4 | import torch 5 | import argparse 6 | import time 7 | from transformers import AutoTokenizer 8 | 9 | from utils.models import IntentBERT 10 | from utils.IntentDataset import IntentDataset 11 | from utils.Evaluator import FewShotEvaluator 12 | from utils.commonVar import * 13 | from utils.printHelper import * 14 | from utils.tools import * 15 | from utils.Logger import logger 16 | 17 | def set_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | torch.manual_seed(seed) 21 | 22 | def parseArgument(): 23 | # ==== parse argument ==== 24 | parser = argparse.ArgumentParser(description='Evaluate few-shot performance') 25 | 26 | # ==== model ==== 27 | parser.add_argument('--seed', default=1, type=int) 28 | parser.add_argument('--mode', default='multi-class', 29 | help='Choose from multi-class') 30 | parser.add_argument('--tokenizer', default='bert-base-uncased', 31 | help="Name of tokenizer") 32 | parser.add_argument('--LMName', default='bert-base-uncased', 33 | help='Name for models and path to saved model') 34 | parser.add_argument('--multi_label', action="store_true") 35 | 36 | # ==== dataset ==== 37 | parser.add_argument('--dataDir', 38 | help="Dataset names included in this experiment and separated by comma. " 39 | "For example:'OOS,bank77,hwu64'") 40 | parser.add_argument('--targetDomain', 41 | help='Target domain names and separated by comma') 42 | 43 | # ==== evaluation task ==== 44 | parser.add_argument('--way', type=int, default=5) 45 | parser.add_argument('--shot', type=int, default=2) 46 | parser.add_argument('--query', type=int, default=5) 47 | parser.add_argument('--clsFierName', default='Linear', 48 | help="Classifer name for few-shot evaluation" 49 | "Choose from Linear|SVM|NN|Cosine|MultiLabel") 50 | 51 | # ==== training arguments ==== 52 | parser.add_argument('--disableCuda', action="store_true") 53 | parser.add_argument('--taskNum', type=int, default=500) 54 | 55 | # ==== other things ==== 56 | parser.add_argument('--loggingLevel', default='INFO', 57 | help="python logging level") 58 | 59 | args = parser.parse_args() 60 | 61 | return args 62 | 63 | def main(): 64 | # ======= process arguments ====== 65 | args = parseArgument() 66 | print(args) 67 | 68 | if args.multi_label: 69 | args.clsFierName = "MultiLabel" 70 | 71 | # ==== setup logger ==== 72 | if args.loggingLevel == LOGGING_LEVEL_INFO: 73 | loggingLevel = logging.INFO 74 | elif args.loggingLevel == LOGGING_LEVEL_DEBUG: 75 | loggingLevel = logging.DEBUG 76 | else: 77 | raise NotImplementedError("Not supported logging level %s", args.loggingLevel) 78 | logger.setLevel(loggingLevel) 79 | 80 | # ======= process data ====== 81 | # tokenizer 82 | tok = AutoTokenizer.from_pretrained(args.tokenizer) 83 | # load raw dataset 84 | logger.info(f"Loading data from {args.dataDir}") 85 | dataset = IntentDataset(multi_label=args.multi_label) 86 | dataset.loadDataset(splitName(args.dataDir)) 87 | dataset.tokenize(tok) 88 | logger.info("----- Testing Data -----") 89 | testData = dataset.splitDomain(splitName(args.targetDomain), multi_label=args.multi_label) 90 | 91 | # ======= prepare model ====== 92 | # initialize model 93 | modelConfig = {} 94 
| modelConfig['device'] = torch.device('cuda:0' if not args.disableCuda else 'cpu') 95 | modelConfig['clsNumber'] = args.shot 96 | modelConfig['LMName'] = args.LMName 97 | model = IntentBERT(modelConfig) 98 | logger.info("----- IntentBERT initialized -----") 99 | 100 | # setup evaluator 101 | valParam = {"evalTaskNum": args.taskNum, "clsFierName": args.clsFierName, 'multi_label':args.multi_label} 102 | valTaskParam = {"way":args.way, "shot":args.shot, "query":args.query} 103 | tester = FewShotEvaluator(valParam, valTaskParam, testData) 104 | 105 | # set up model 106 | logger.info("Evaluating model ...") 107 | # evaluate before finetuning begins 108 | tester.evaluate(model, tok, args.mode, logLevel='INFO') 109 | # print config 110 | logger.info(args) 111 | logger.info(time.asctime()) 112 | 113 | if __name__ == "__main__": 114 | main() 115 | exit(0) 116 | -------------------------------------------------------------------------------- /images/combined.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/images/combined.png -------------------------------------------------------------------------------- /images/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/images/main.png -------------------------------------------------------------------------------- /mlm.py: -------------------------------------------------------------------------------- 1 | # This file assembles three popular metric learnign baselines, matching network, prototype network and relation network. 2 | # This file is coded based on train_matchingNet.py. 3 | # coding=utf-8 4 | import os 5 | import torch 6 | import torch.optim as optim 7 | import argparse 8 | import time 9 | from transformers import AutoTokenizer 10 | 11 | from utils.models import IntentBERT 12 | from utils.IntentDataset import IntentDataset 13 | from utils.Trainer import MLMOnlyTrainer 14 | from utils.Evaluator import FewShotEvaluator 15 | from utils.commonVar import * 16 | from utils.printHelper import * 17 | from utils.tools import * 18 | from utils.Logger import logger 19 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 20 | 21 | def set_seed(seed): 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | torch.manual_seed(seed) 25 | 26 | def parseArgument(): 27 | # ==== parse argument ==== 28 | parser = argparse.ArgumentParser(description='Train with MLM loss') 29 | 30 | # ==== model ==== 31 | parser.add_argument('--seed', default=1, type=int) 32 | parser.add_argument('--mode', default='multi-class', 33 | help='Choose from multi-class') 34 | parser.add_argument('--tokenizer', default='bert-base-uncased', 35 | help="Name of tokenizer") 36 | parser.add_argument('--LMName', default='bert-base-uncased', 37 | help='Name for models and path to saved model') 38 | 39 | # ==== dataset ==== 40 | parser.add_argument('--dataDir', 41 | help="Dataset names included in this experiment and separated by comma. 
" 42 | "For example:'OOS,bank77,hwu64'") 43 | parser.add_argument('--targetDomain', 44 | help='Target domain names and separated by comma') 45 | 46 | # ==== evaluation task ==== 47 | parser.add_argument('--way', type=int, default=5) 48 | parser.add_argument('--shot', type=int, default=2) 49 | parser.add_argument('--query', type=int, default=5) 50 | parser.add_argument('--clsFierName', default='Linear', 51 | help="Classifer name for few-shot evaluation" 52 | "Choose from Linear|SVM|NN|Cosine") 53 | 54 | # ==== optimizer ==== 55 | parser.add_argument('--optimizer', default='Adam', 56 | help='Choose from SGD|Adam') 57 | parser.add_argument('--learningRate', type=float, default=2e-5) 58 | parser.add_argument('--weightDecay', type=float, default=0) 59 | 60 | # ==== training arguments ==== 61 | parser.add_argument('--disableCuda', action="store_true") 62 | parser.add_argument('--epochs', type=int, default=50) 63 | parser.add_argument('--batch_size', type=int, default=32) 64 | parser.add_argument('--taskNum', type=int, default=500) 65 | 66 | # ==== other things ==== 67 | parser.add_argument('--loggingLevel', default='INFO', 68 | help="python logging level") 69 | parser.add_argument('--saveModel', action='store_true', 70 | help="Whether to save pretrained model") 71 | parser.add_argument('--saveName', default='none', 72 | help="Specify a unique name to save your model" 73 | "If none, then there will be a specific name controlled by how the model is trained") 74 | parser.add_argument('--tensorboard', action='store_true', 75 | help="Enable tensorboard to log training and validation accuracy") 76 | 77 | args = parser.parse_args() 78 | 79 | return args 80 | 81 | def main(): 82 | # ======= process arguments ====== 83 | args = parseArgument() 84 | print(args) 85 | 86 | if not args.saveModel: 87 | logger.info("The model will not be saved after training!") 88 | 89 | # ==== setup logger ==== 90 | if args.loggingLevel == LOGGING_LEVEL_INFO: 91 | loggingLevel = logging.INFO 92 | elif args.loggingLevel == LOGGING_LEVEL_DEBUG: 93 | loggingLevel = logging.DEBUG 94 | else: 95 | raise NotImplementedError("Not supported logging level %s", args.loggingLevel) 96 | logger.setLevel(loggingLevel) 97 | 98 | # ======= process data ====== 99 | # tokenizer 100 | tok = AutoTokenizer.from_pretrained(args.tokenizer) 101 | # load raw dataset 102 | logger.info(f"Loading data from {args.dataDir}") 103 | dataset = IntentDataset() 104 | dataset.loadDataset(splitName(args.dataDir)) 105 | dataset.tokenize(tok) 106 | # spit data into training, validation and testing 107 | logger.info("----- Testing Data -----") 108 | testData = dataset.splitDomain(splitName(args.targetDomain)) 109 | 110 | # ======= prepare model ====== 111 | # initialize model 112 | modelConfig = {} 113 | modelConfig['device'] = torch.device('cuda:0' if not args.disableCuda else 'cpu') 114 | modelConfig['clsNumber'] = testData.getLabNum() 115 | modelConfig['LMName'] = args.LMName 116 | model = IntentBERT(modelConfig) 117 | logger.info("----- IntentBERT initialized -----") 118 | 119 | # setup validator 120 | valParam = {"evalTaskNum": args.taskNum, "clsFierName": args.clsFierName, "multi_label": False} 121 | valTaskParam = {"way":args.way, "shot":args.shot, "query":args.query} 122 | tester = FewShotEvaluator(valParam, valTaskParam, testData) 123 | 124 | # setup trainer 125 | optimizer = None 126 | if args.optimizer == OPTER_ADAM: 127 | optimizer = optim.Adam(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay) 128 | elif args.optimizer == 
OPTER_SGD:
129 |         optimizer = optim.SGD(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay)
130 |     else:
131 |         raise NotImplementedError("Not supported optimizer %s"%(args.optimizer))
132 | 
133 |     trainingParam = {"epoch" : args.epochs, \
134 |                      "batch" : args.batch_size, \
135 |                      "tensorboard": args.tensorboard}
136 |     trainer = MLMOnlyTrainer(trainingParam, optimizer, testData, testData, tester)
137 | 
138 |     # train
139 |     trainer.train(model, tok)
140 | 
141 |     # evaluate once more to show results
142 |     tester.evaluate(model, tok, args.mode, logLevel='INFO')
143 | 
144 |     # save model into disk
145 |     if args.saveModel:
146 |         if args.saveName == 'none':
147 |             prefix = "MLMOnly"
148 |             save_path = os.path.join(SAVE_PATH, f'{prefix}_{args.targetDomain}')
149 |         else:
150 |             save_path = os.path.join(SAVE_PATH, args.saveName)
151 |         logger.info("Saving model.pth into folder: %s", save_path)
152 |         if not os.path.exists(save_path):
153 |             os.mkdir(save_path)
154 |         model.save(save_path)
155 | 
156 |     # print config
157 |     logger.info(args)
158 |     logger.info(time.asctime())
159 | 
160 | if __name__ == "__main__":
161 |     main()
162 |     exit(0)
163 | 
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo usage:
3 | echo scriptName.sh : run in normal mode
4 | echo scriptName.sh debug : run in debug mode
5 | 
6 | # hardware
7 | cudaID=$2
8 | 
9 | # debug mode
10 | if [[ $# != 0 ]] && [[ $1 == "debug" ]]
11 | then
12 |     debug=true
13 | else
14 |     debug=false
15 | fi
16 | 
17 | seed=1
18 | 
19 | # dataset
20 | # dataDir='bank77'
21 | # targetDomain="BANKING"
22 | # dataDir=mcid
23 | # targetDomain="MEDICAL"
24 | dataDir=hint3
25 | targetDomain='curekart,powerplay11,sofmattress'
26 | 
27 | # setting
28 | shot=2
29 | way=5   # matches the default --way in eval.py; used in the log file name
30 | 
31 | # model initialization
32 | LMName=intent-bert-base-uncased
33 | # LMName=joint-intent-bert-base-uncased-bank77
34 | # LMName=joint-intent-bert-base-uncased-mcid
35 | # LMName=joint-intent-bert-base-uncased-hint3
36 | 
37 | # modify arguments if it's debug mode
38 | RED='\033[0;31m'
39 | GRN='\033[0;32m'
40 | NC='\033[0m' # No Color
41 | if $debug
42 | then
43 |     echo -e "Run in ${RED} debug ${NC} mode."
44 | else
45 |     echo -e "Run in ${GRN} normal ${NC} mode."
46 | fi
47 | 
48 | echo "Start Experiment ..."
49 | logFolder=./log/
50 | mkdir -p ${logFolder}
51 | logFile=${logFolder}/eval_${dataDir}_${way}way_${shot}shot_LMName${LMName}.log
52 | if $debug
53 | then
54 |     logFile=${logFolder}/logDebug.log
55 | fi
56 | 
57 | export CUDA_VISIBLE_DEVICES=${cudaID}
58 | python eval.py \
59 |     --seed ${seed} \
60 |     --targetDomain ${targetDomain} \
61 |     --dataDir ${dataDir} \
62 |     --shot ${shot} \
63 |     --LMName ${LMName} \
64 |     | tee "${logFile}"
65 | echo "Experiment finished."
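# Usage sketch, inferred from the argument handling above:
#   bash scripts/eval.sh normal 0    # normal mode on GPU 0
#   bash scripts/eval.sh debug 0     # quick debug run on GPU 0
# Any first argument other than "debug" selects normal mode; the second
# argument is exported as CUDA_VISIBLE_DEVICES.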
66 | 
--------------------------------------------------------------------------------
/scripts/mlm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo usage:
3 | echo scriptName.sh : run in normal mode
4 | echo scriptName.sh debug : run in debug mode
5 | 
6 | # hardware
7 | cudaID=$2
8 | 
9 | # debug mode
10 | if [[ $# != 0 ]] && [[ $1 == "debug" ]]
11 | then
12 |     debug=true
13 | else
14 |     debug=false
15 | fi
16 | 
17 | seed=1
18 | 
19 | # dataset
20 | dataDir='bank77'
21 | # dataDir='mcid'
22 | # dataDir='hint3'
23 | targetDomain="BANKING"
24 | # targetDomain='MEDICAL'
25 | # targetDomain='curekart,powerplay11,sofmattress'
26 | 
27 | # setting
28 | shot=2
29 | 
30 | # training
31 | tensorboard=
32 | saveModel=--saveModel
33 | saveName=none
34 | 
35 | # model setting
36 | # common
37 | LMName=bert-base-uncased
38 | 
39 | # modify arguments if it's debug mode
40 | RED='\033[0;31m'
41 | GRN='\033[0;32m'
42 | NC='\033[0m' # No Color
43 | if $debug
44 | then
45 |     echo -e "Run in ${RED} debug ${NC} mode."
46 |     # validationTaskNum=10
47 |     # testTaskNum=10
48 |     epochs=1
49 |     # repeatNum=1
50 | else
51 |     echo -e "Run in ${GRN} normal ${NC} mode."
52 |     epochs=50   # default --epochs in mlm.py
53 | fi
54 | 
55 | echo "Start Experiment ..."
56 | logFolder=./log/
57 | mkdir -p ${logFolder}
58 | logFile=${logFolder}/mlm_${dataDir}_${targetDomain}_${shot}shot.log
59 | if $debug
60 | then
61 |     logFile=${logFolder}/logDebug.log
62 | fi
63 | 
64 | export CUDA_VISIBLE_DEVICES=${cudaID}
65 | python mlm.py \
66 |     --seed ${seed} \
67 |     --targetDomain ${targetDomain} \
68 |     ${tensorboard} \
69 |     --dataDir ${dataDir} \
70 |     --shot ${shot} \
71 |     --epochs ${epochs} \
72 |     ${saveModel} \
73 |     --LMName ${LMName} \
74 |     --saveName ${saveName} \
75 |     | tee "${logFile}"
76 | echo "Experiment finished."
--------------------------------------------------------------------------------
/scripts/transfer.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo usage:
3 | echo scriptName.sh : run in normal mode
4 | echo scriptName.sh debug : run in debug mode
5 | 
6 | # hardware
7 | cudaID=$2
8 | 
9 | # debug mode
10 | if [[ $# != 0 ]] && [[ $1 == "debug" ]]
11 | then
12 |     debug=true
13 | else
14 |     debug=false
15 | fi
16 | 
17 | # seed=4321
18 | seed=1
19 | 
20 | # dataset
21 | dataDir='oos,hwu64,bank77'
22 | # dataDir="oos,hwu64,mcid"
23 | # dataDir='oos,hwu64,hint3'
24 | 
25 | sourceDomain="utility,auto_commute,kitchen_dining,work,home,meta,small_talk,travel"
26 | # valDomain="alarm,audio,iot,calendar,play,datetime,takeaway,news,music,weather,qa,social,recommendation,cooking,email,transport,lists"
27 | valDomain="iot,play,qa"
28 | 
29 | # below is only for evaluation
30 | targetDomain="BANKING"
31 | # targetDomain="MEDICAL"
32 | # targetDomain="curekart,powerplay11,sofmattress"
33 | 
34 | # setting
35 | shot=2
36 | way=5   # matches the default --way in transfer.py; used in the log file name
37 | 
38 | # training
39 | tensorboard=
40 | saveModel=--saveModel
41 | saveName=none
42 | validation=--validation
43 | mlm=
44 | # learningRate=1e-5
45 | learningRate=5e-6
46 | 
47 | # model setting
48 | # common
49 | LMName=bert-base-uncased
50 | 
51 | # modify arguments if it's debug mode
52 | RED='\033[0;31m'
53 | GRN='\033[0;32m'
54 | NC='\033[0m' # No Color
55 | if $debug
56 | then
57 |     echo -e "Run in ${RED} debug ${NC} mode."
58 |     epochs=1
59 | else
60 |     echo -e "Run in ${GRN} normal ${NC} mode."
61 |     epochs=10   # default --epochs in transfer.py
62 | fi
63 | 
64 | echo "Start Experiment ..."
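# Note on the settings above: pre-training data comes from the oos (CLINC150)
# domains listed in sourceDomain, few-shot validation tasks are drawn from the
# hwu64 domains in valDomain, and targetDomain is used only for the final
# few-shot evaluation. Switching the commented dataDir/targetDomain pairs
# above transfers to mcid (MEDICAL) or hint3 instead of bank77 (BANKING).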
65 | logFolder=./log/
66 | mkdir -p ${logFolder}
67 | logFile=${logFolder}/transfer_${sourceDomain}_to_${targetDomain}_${way}way_${shot}shot.log
68 | # logFile=${logFolder}/transfer_mlm_${sourceDomain}_to_${targetDomain}_${way}way_${shot}shot.log
69 | if $debug
70 | then
71 |     logFile=${logFolder}/logDebug.log
72 | fi
73 | 
74 | export CUDA_VISIBLE_DEVICES=${cudaID}
75 | python transfer.py \
76 |     --seed ${seed} \
77 |     --valDomain ${valDomain} \
78 |     --sourceDomain ${sourceDomain} \
79 |     --targetDomain ${targetDomain} \
80 |     ${tensorboard} \
81 |     --dataDir ${dataDir} \
82 |     --shot ${shot} \
83 |     --epochs ${epochs} \
84 |     ${saveModel} \
85 |     ${validation} \
86 |     ${mlm} \
87 |     --learningRate ${learningRate} \
88 |     --LMName ${LMName} \
89 |     --saveName ${saveName} \
90 |     | tee "${logFile}"
91 | echo "Experiment finished."
92 | 
--------------------------------------------------------------------------------
/transfer.py:
--------------------------------------------------------------------------------
1 | # Supervised pre-training of IntentBERT on source-domain intent data, optionally
2 | # jointly with an MLM loss, followed by few-shot evaluation on a target domain.
3 | # coding=utf-8
4 | import os
5 | import torch
6 | import torch.optim as optim
7 | import argparse
8 | import time
9 | import copy
10 | from transformers import AutoTokenizer
11 | import random
12 | 
13 | from utils.models import IntentBERT
14 | from utils.IntentDataset import IntentDataset
15 | from utils.Trainer import TransferTrainer
16 | from utils.Evaluator import FewShotEvaluator
17 | from utils.commonVar import *
18 | from utils.printHelper import *
19 | from utils.tools import *
20 | from utils.Logger import logger
21 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
22 | 
23 | def set_seed(seed):
24 |     random.seed(seed)
25 |     np.random.seed(seed)
26 |     torch.manual_seed(seed)
27 | 
28 | def parseArgument():
29 |     # ==== parse argument ====
30 |     parser = argparse.ArgumentParser(description='Train IntentBERT')
31 | 
32 |     # ==== model ====
33 |     parser.add_argument('--seed', default=1, type=int)
34 |     parser.add_argument('--mode', default='multi-class',
35 |                         help='Choose from multi-class')
36 |     parser.add_argument('--tokenizer', default='bert-base-uncased',
37 |                         help="Name of tokenizer")
38 |     parser.add_argument('--LMName', default='bert-base-uncased',
39 |                         help='Name for models and path to saved model')
40 | 
41 |     # ==== dataset ====
42 |     parser.add_argument('--dataDir',
43 |                         help="Dataset names included in this experiment, separated by comma. "
44 |                         "For example:'OOS,bank77,hwu64'")
45 |     parser.add_argument('--sourceDomain',
46 |                         help="Source domain names, separated by comma. 
" 47 | "For example:'travel,banking,home'") 48 | parser.add_argument('--valDomain', 49 | help='Validation domain names and separated by comma') 50 | parser.add_argument('--targetDomain', 51 | help='Target domain names and separated by comma') 52 | 53 | # ==== evaluation task ==== 54 | parser.add_argument('--way', type=int, default=5) 55 | parser.add_argument('--shot', type=int, default=2) 56 | parser.add_argument('--query', type=int, default=5) 57 | parser.add_argument('--clsFierName', default='Linear', 58 | help="Classifer name for few-shot evaluation" 59 | "Choose from Linear|SVM|NN|Cosine|MultiLabel") 60 | 61 | # ==== optimizer ==== 62 | parser.add_argument('--optimizer', default='Adam', 63 | help='Choose from SGD|Adam') 64 | parser.add_argument('--learningRate', type=float, default=2e-5) 65 | parser.add_argument('--weightDecay', type=float, default=0) 66 | 67 | # ==== training arguments ==== 68 | parser.add_argument('--disableCuda', action="store_true") 69 | parser.add_argument('--validation', action="store_true") 70 | parser.add_argument('--epochs', type=int, default=10) 71 | parser.add_argument('--batch_size', type=int, default=32) 72 | parser.add_argument('--taskNum', type=int, default=500) 73 | parser.add_argument('--patience', type=int, default=3, 74 | help="Early stop when performance does not go better") 75 | parser.add_argument('--mlm', action='store_true', 76 | help="If use mlm as auxiliary loss") 77 | parser.add_argument('--lambda_mlm', type=float, default=1.0, 78 | help="The weight for mlm loss") 79 | parser.add_argument('--mlm_data', default='target', type=str, 80 | help="Data for mlm. Choose from target|source") 81 | parser.add_argument('--shuffle_mlm', action="store_true") 82 | parser.add_argument('--shuffle', action="store_true") 83 | parser.add_argument('--regression', action="store_true", 84 | help="If the pretrain task is a regression task") 85 | 86 | # ==== other things ==== 87 | parser.add_argument('--loggingLevel', default='INFO', 88 | help="python logging level") 89 | parser.add_argument('--saveModel', action='store_true', 90 | help="Whether to save pretrained model") 91 | parser.add_argument('--saveName', default='none', 92 | help="Specify a unique name to save your model" 93 | "If none, then there will be a specific name controlled by how the model is trained") 94 | parser.add_argument('--tensorboard', action='store_true', 95 | help="Enable tensorboard to log training and validation accuracy") 96 | 97 | args = parser.parse_args() 98 | 99 | return args 100 | 101 | def main(): 102 | # ======= process arguments ====== 103 | args = parseArgument() 104 | print(args) 105 | 106 | if not args.saveModel: 107 | logger.info("The model will not be saved after training!") 108 | 109 | # ==== setup logger ==== 110 | if args.loggingLevel == LOGGING_LEVEL_INFO: 111 | loggingLevel = logging.INFO 112 | elif args.loggingLevel == LOGGING_LEVEL_DEBUG: 113 | loggingLevel = logging.DEBUG 114 | else: 115 | raise NotImplementedError("Not supported logging level %s", args.loggingLevel) 116 | logger.setLevel(loggingLevel) 117 | 118 | # ==== set seed ==== 119 | set_seed(args.seed) 120 | 121 | # ======= process data ====== 122 | # tokenizer 123 | tok = AutoTokenizer.from_pretrained(args.tokenizer) 124 | # load raw dataset 125 | logger.info(f"Loading data from {args.dataDir}") 126 | dataset = IntentDataset(regression=args.regression) 127 | dataset.loadDataset(splitName(args.dataDir)) 128 | dataset.tokenize(tok) 129 | # spit data into training, validation and testing 130 | logger.info("----- 
Training Data -----") 131 | trainData = dataset.splitDomain(splitName(args.sourceDomain), regression=args.regression) 132 | logger.info("----- Validation Data -----") 133 | valData = dataset.splitDomain(splitName(args.valDomain)) 134 | logger.info("----- Testing Data -----") 135 | testData = dataset.splitDomain(splitName(args.targetDomain)) 136 | # shuffle word order 137 | if args.shuffle: 138 | trainData.shuffle_words() 139 | 140 | # ======= prepare model ====== 141 | # initialize model 142 | modelConfig = {} 143 | modelConfig['device'] = torch.device('cuda:0' if not args.disableCuda else 'cpu') 144 | if args.regression: 145 | modelConfig['clsNumber'] = 1 146 | else: 147 | modelConfig['clsNumber'] = trainData.getLabNum() 148 | modelConfig['LMName'] = args.LMName 149 | model = IntentBERT(modelConfig) 150 | logger.info("----- IntentBERT initialized -----") 151 | 152 | # setup validator 153 | valParam = {"evalTaskNum": args.taskNum, "clsFierName": args.clsFierName, "multi_label": False} 154 | valTaskParam = {"way":args.way, "shot":args.shot, "query":args.query} 155 | validator = FewShotEvaluator(valParam, valTaskParam, valData) 156 | tester = FewShotEvaluator(valParam, valTaskParam, testData) 157 | 158 | # setup trainer 159 | optimizer = None 160 | if args.optimizer == OPTER_ADAM: 161 | optimizer = optim.Adam(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay) 162 | elif args.optimizer == OPTER_SGD: 163 | optimizer = optim.SGD(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay) 164 | else: 165 | raise NotImplementedError("Not supported optimizer %s"%(args.optimizer)) 166 | 167 | if args.mlm and args.mlm_data == "target": 168 | args.validation = False 169 | trainingParam = {"epoch" : args.epochs, \ 170 | "batch" : args.batch_size, \ 171 | "validation" : args.validation, \ 172 | "patience" : args.patience, \ 173 | "tensorboard": args.tensorboard, \ 174 | "mlm" : args.mlm, \ 175 | "lambda mlm" : args.lambda_mlm, \ 176 | "regression" : args.regression} 177 | unlabeledData = None 178 | if args.mlm_data == "source": 179 | unlabeledData = copy.deepcopy(trainData) 180 | elif args.mlm_data == "target": 181 | unlabeledData = copy.deepcopy(testData) 182 | if args.shuffle_mlm: 183 | unlabeledData.shuffle_words() 184 | trainer = TransferTrainer(trainingParam, optimizer, trainData, unlabeledData, validator, tester) 185 | 186 | # train 187 | trainer.train(model, tok, args.mode) 188 | 189 | # load best model 190 | bestModelStateDict = trainer.getBestModelStateDict() 191 | model.load_state_dict(bestModelStateDict) 192 | 193 | # evaluate once more to show results 194 | tester.evaluate(model, tok, args.mode, logLevel='INFO') 195 | 196 | # save model into disk 197 | if args.saveModel: 198 | # decide the save name 199 | if args.saveName == 'none': 200 | prefix = "STMLM" if args.mlm else "ST" 201 | if args.mlm: 202 | if args.mlm_data == 'target': 203 | prefix += f"_{args.targetDomain}" 204 | elif args.mlm_data == 'source': 205 | prefix += "_source" 206 | save_path = os.path.join(SAVE_PATH, f'{prefix}_{args.mode}_{args.sourceDomain}') 207 | if args.shuffle: 208 | save_path += "_shuffle" 209 | if args.shuffle_mlm: 210 | save_path += "_shuffle_mlm" 211 | else: 212 | save_path = os.path.join(SAVE_PATH, args.saveName) 213 | # save 214 | logger.info("Saving model.pth into folder: %s", save_path) 215 | if not os.path.exists(save_path): 216 | os.mkdir(save_path) 217 | model.save(save_path) 218 | 219 | # print config 220 | logger.info(args) 221 | logger.info(time.asctime()) 222 | 
223 | if __name__ == "__main__":
224 |     main()
225 |     exit(0)
226 | 
--------------------------------------------------------------------------------
/utils/Evaluator.py:
--------------------------------------------------------------------------------
1 | from utils.IntentDataset import IntentDataset
2 | from utils.TaskSampler import MultiLabTaskSampler, UniformTaskSampler
3 | from utils.tools import makeEvalExamples
4 | from utils.printHelper import *
5 | from utils.Logger import logger
6 | from utils.commonVar import *
7 | import logging
8 | import torch
9 | import numpy as np
10 | from tqdm import tqdm
11 | import copy
12 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support
13 | 
14 | ##
15 | # @brief base class of evaluator
16 | class EvaluatorBase():
17 |     def __init__(self):
18 |         self.roundN = 4
19 |         pass
20 | 
21 |     def round(self, floatNum):
22 |         return round(floatNum, self.roundN)
23 | 
24 |     def evaluate(self):
25 |         raise NotImplementedError("evaluate() is not implemented.")
26 | 
27 | ##
28 | # @brief FewShotEvaluator used to do meta evaluation. Tasks are sampled and the model is evaluated task by task.
29 | class FewShotEvaluator(EvaluatorBase):
30 |     def __init__(self, evalParam, taskParam, dataset: IntentDataset):
31 |         super(FewShotEvaluator, self).__init__()
32 |         self.way = taskParam['way']
33 |         self.shot = taskParam['shot']
34 |         self.query = taskParam['query']
35 | 
36 |         self.dataset = dataset
37 | 
38 |         self.multi_label = evalParam['multi_label']
39 |         self.clsFierName = evalParam['clsFierName']
40 |         self.evalTaskNum = evalParam['evalTaskNum']
41 |         logger.info("In evaluator classifier %s is used.", self.clsFierName)
42 | 
43 |         if self.multi_label:
44 |             self.taskSampler = MultiLabTaskSampler(self.dataset, self.shot, self.query)
45 |         else:
46 |             self.taskSampler = UniformTaskSampler(self.dataset, self.way, self.shot, self.query)
47 | 
48 |     def evaluate(self, model, tokenizer, mode='multi-class', logLevel='DEBUG'):
49 |         model.eval()
50 | 
51 |         performList = [] # acc, pre, rec, fsc
52 |         with torch.no_grad():
53 |             for _ in range(self.evalTaskNum):
54 |                 # sample a task
55 |                 task = self.taskSampler.sampleOneTask()
56 | 
57 |                 # collect data
58 |                 supportX = task[META_TASK_SHOT_TOKEN]
59 |                 queryX = task[META_TASK_QUERY_TOKEN]
60 |                 if mode == 'multi-class':
61 |                     supportY = task[META_TASK_SHOT_LOC_LABID]
62 |                     queryY = task[META_TASK_QUERY_LOC_LABID]
63 |                 else:
64 |                     logger.error("Invalid mode %s"%(mode))
65 | 
66 |                 # padding
67 |                 supportX, supportY, queryX, queryY =\
68 |                     makeEvalExamples(supportX, supportY, queryX, queryY, tokenizer, mode=mode)
69 | 
70 |                 # forward
71 |                 queryPrediction = model.fewShotPredict(supportX.to(model.device),
72 |                                                        supportY,
73 |                                                        queryX.to(model.device),
74 |                                                        self.clsFierName,
75 |                                                        mode=mode)
76 | 
77 |                 # calculate acc
78 |                 acc = accuracy_score(queryY, queryPrediction) # acc
79 |                 if self.multi_label:
80 |                     performDetail = precision_recall_fscore_support(queryY, queryPrediction, average='micro', warn_for=tuple())
81 |                 else:
82 |                     performDetail = precision_recall_fscore_support(queryY, queryPrediction, average='macro', warn_for=tuple())
83 | 
84 |                 performList.append([acc, performDetail[0], performDetail[1], performDetail[2]])
85 | 
86 |         # performance mean and std
87 |         performMean = np.mean(np.stack(performList, 0), 0)
88 |         performStd = np.std(np.stack(performList, 0), 0)
89 | 
90 |         if logLevel == 'DEBUG':
91 |             itemList = ["acc", "pre", "rec", "fsc"]
92 |             logger.debug("Evaluate statistics: ")
93 |             printMeanStd(performMean, performStd, itemList, debugLevel=logging.DEBUG)
94 |         else:
95 |             itemList = ["acc", "pre", "rec", "fsc"]
96 |             logger.info("Evaluate statistics: ")
97 |             printMeanStd(performMean, performStd, itemList, debugLevel=logging.INFO)
98 | 
99 |         # acc, pre, rec, F1
100 |         return performMean[0], performMean[1], performMean[2], performMean[3]
101 | 
102 | 
103 | ##
104 | # @brief FineTuneEvaluator fine-tunes the model on each task's support set before testing on its query set.
105 | class FineTuneEvaluator(EvaluatorBase):
106 |     def __init__(self, evalParam, taskParam, optimizer, dataset: IntentDataset):
107 |         super(FineTuneEvaluator, self).__init__()
108 |         self.way = taskParam['way']
109 |         self.shot = taskParam['shot']
110 |         self.query = taskParam['query']
111 | 
112 |         self.dataset = dataset
113 |         self.optimizer = optimizer
114 | 
115 |         self.finetuneSteps = evalParam['finetuneSteps']
116 |         self.evalTaskNum = evalParam['evalTaskNum']
117 | 
118 |         self.taskSampler = UniformTaskSampler(self.dataset, self.way, self.shot, self.query)
119 | 
120 |     def evaluate(self, model, tokenizer, mode='multi-class', logLevel='DEBUG'):
121 |         performList = [] # acc, pre, rec, fsc
122 |         initial_model = copy.deepcopy(model.state_dict())  # deep copy: fine-tuning updates the tensors in place
123 |         initial_optim = copy.deepcopy(self.optimizer.state_dict())
124 | 
125 |         for _ in tqdm(range(self.evalTaskNum)):
126 |             # sample a task
127 |             task = self.taskSampler.sampleOneTask()
128 | 
129 |             # collect data
130 |             supportX = task[META_TASK_SHOT_TOKEN]
131 |             queryX = task[META_TASK_QUERY_TOKEN]
132 |             if mode == 'multi-class':
133 |                 supportY = task[META_TASK_SHOT_LOC_LABID]
134 |                 queryY = task[META_TASK_QUERY_LOC_LABID]
135 |             else:
136 |                 logger.error("Invalid mode %s"%(mode))
137 | 
138 |             # padding
139 |             supportX, supportY, queryX, queryY =\
140 |                 makeEvalExamples(supportX, supportY, queryX, queryY, tokenizer, mode=mode)
141 | 
142 |             # finetune
143 |             model.train()
144 |             for _ in range(self.finetuneSteps):
145 |                 logits = model(supportX.to(model.device))
146 |                 loss = model.loss_ce(logits, torch.tensor(supportY).to(model.device))
147 |                 self.optimizer.zero_grad()
148 |                 loss.backward()
149 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
150 |                 self.optimizer.step()
151 | 
152 |             model.eval()
153 |             with torch.no_grad():
154 |                 if mode == 'multi-class':
155 |                     queryPrediction = model(queryX.to(model.device)).argmax(-1)
156 |                 else:
157 |                     logger.error("Invalid mode %s"%(mode))
158 | 
159 |             queryPrediction = queryPrediction.cpu().numpy()
160 | 
161 |             # calculate acc
162 |             acc = accuracy_score(queryY, queryPrediction) # acc
163 |             performDetail = precision_recall_fscore_support(queryY, queryPrediction, average='macro', warn_for=tuple())
164 | 
165 |             performList.append([acc, performDetail[0], performDetail[1], performDetail[2]])
166 | 
167 |             # restore the pre-fine-tuning weights and optimizer state before the next task
168 |             model.load_state_dict(initial_model)
169 |             self.optimizer.load_state_dict(initial_optim)
170 | 
171 |         # performance mean and std
172 |         performMean = np.mean(np.stack(performList, 0), 0)
173 |         performStd = np.std(np.stack(performList, 0), 0)
174 | 
175 |         if logLevel == 'DEBUG':
176 |             itemList = ["acc", "pre", "rec", "fsc"]
177 |             logger.debug("Evaluate statistics: ")
178 |             printMeanStd(performMean, performStd, itemList, debugLevel=logging.DEBUG)
179 |         else:
180 |             itemList = ["acc", "pre", "rec", "fsc"]
181 |             logger.info("Evaluate statistics: ")
182 |             printMeanStd(performMean, performStd, itemList, debugLevel=logging.INFO)
183 | 
184 |         # acc, pre, rec, F1
185 |         return performMean[0], performMean[1], performMean[2], performMean[3]
--------------------------------------------------------------------------------
/utils/IntentDataset.py: 
-------------------------------------------------------------------------------- 1 | #coding=utf-8 2 | import nltk 3 | from nltk.corpus import stopwords 4 | nltk.download('stopwords') 5 | 6 | from utils.commonVar import * 7 | import json 8 | from utils.Logger import logger 9 | import os 10 | import copy 11 | import random 12 | 13 | 14 | class IntentDataset(): 15 | def __init__(self, 16 | domList=None, 17 | labList=None, 18 | uttList=None, 19 | tokList=None, 20 | regression=False, 21 | multi_label=False): 22 | self.regression = regression 23 | self.multi_label = multi_label 24 | 25 | self.domList = [] if domList is None else domList 26 | self.labList = [] if labList is None else labList 27 | self.uttList = [] if uttList is None else uttList 28 | self.tokList = [] if tokList is None else tokList 29 | 30 | if (self.labList is not None) and (not self.regression): 31 | self.createLabID() 32 | if self.regression: 33 | self.labIDList = self.labList 34 | if not self.multi_label: 35 | self.convertLabs() 36 | self.labID2DataInd = None 37 | self.dataInd2LabID = None 38 | 39 | def getDomList(self): 40 | return self.domList 41 | 42 | def getLabList(self): 43 | return self.labList 44 | 45 | def getUttList(self): 46 | return self.uttList 47 | 48 | def getTokList(self): 49 | return self.tokList 50 | 51 | def getAllData(self): 52 | return self.domList, self.labList, self.uttList, self.tokList 53 | 54 | def getLabNum(self): 55 | labSet = set() 56 | for lab in self.labList: 57 | if self.multi_label: 58 | for l in lab: 59 | labSet.add(l) 60 | else: 61 | labSet.add(lab) 62 | return len(labSet) 63 | 64 | def getLabID(self): 65 | return self.labIDList 66 | 67 | def checkData(self, utt: str, label: str): 68 | if not self.regression: 69 | if len(label) == 0 or len(utt) == 0: 70 | logger.warning("Illegal label %s or utterance %s, 0 length", label, utt) 71 | return 1 72 | return 0 73 | 74 | def loadDataset(self, dataDirList): 75 | self.domList, self.labList, self.uttList = [], [], [] 76 | dataFilePathList = \ 77 | [os.path.join(DATA_PATH, dataDir, FILE_NAME_DATASET) for dataDir in dataDirList] 78 | 79 | dataList = [] 80 | for dataFilePath in dataFilePathList: 81 | with open(dataFilePath, 'r') as json_file: 82 | dataList.append(json.load(json_file)) 83 | 84 | delDataNum = 0 85 | for data in dataList: 86 | for datasetName in data: 87 | dataset = data[datasetName] 88 | for domainName in dataset: 89 | domain = dataset[domainName] 90 | for dataItem in domain: 91 | utt = dataItem[0] 92 | labList = dataItem[1] 93 | 94 | if self.multi_label: 95 | lab = labList 96 | else: 97 | lab = labList[0] 98 | 99 | if not self.checkData(utt, lab) == 0: 100 | logger.warning("Illegal label %s or utterance %s, too short length", lab, utt) 101 | delDataNum = delDataNum+1 102 | else: 103 | self.domList.append(domainName) 104 | self.labList.append(lab) 105 | self.uttList.append(utt) 106 | 107 | # report deleted data number 108 | if (delDataNum>0): 109 | logger.warning("%d data is deleted from dataset.", delDataNum) 110 | 111 | # sanity check 112 | countSet = set() 113 | countSet.add(len(self.domList)) 114 | countSet.add(len(self.labList)) 115 | countSet.add(len(self.uttList)) 116 | if len(countSet) > 1: 117 | logger.error("Unaligned data list. 
Length of data list: dataset %d, domain %d, lab %d, utterance %d", len(self.domainList), len(self.labList), len(self.uttList)) 118 | exit(1) 119 | if not self.regression: 120 | self.createLabID() 121 | else: 122 | self.labIDList = self.labList 123 | logger.info(f"{countSet} data collected") 124 | return 0 125 | 126 | def removeStopWord(self): 127 | raise NotImplementedError 128 | 129 | # print info 130 | logger.info("Removing stop words ...") 131 | logger.info("Before removing stop words: data count is %d", len(self.uttList)) 132 | 133 | # remove stop word 134 | stopwordsEnglish = stopwords.words('english') 135 | uttListNew = [] 136 | labListNew = [] 137 | delLabListNew = [] 138 | delUttListNew = [] # Utt for utterance 139 | maxLen = -1 140 | for lab, utt in zip(self.labList, self.uttList): 141 | uttWordListNew = [w for w in utt.split(' ') if not word in stopwordsEnglish] 142 | uttNew = ' '.join(uttWordListNew) 143 | 144 | uttNewLen = len(uttWordListNew) 145 | if uttNewLen <= 0: # too short utterance, delete it from dataset 146 | delLabListNew.append(lab) 147 | delUttListNew.append(uttNew) 148 | else: # utt with normal length 149 | if uttNewLen > maxLen: 150 | maxLen = uttNewLen 151 | labListNew.append(lab) 152 | uttListNew.append(uttNew) 153 | self.labList = labListNew 154 | self.uttListNew = uttListNew 155 | self.delLabList.append(delLabListNew) 156 | self.delUttList.append(delUttListNew) 157 | 158 | # update data list 159 | logger.info("After removing stop words: data count is %d", len(self.uttList)) 160 | logger.info("Removing stop words ... done.") 161 | 162 | return 0 163 | 164 | def splitDomain(self, domainName: list, regression=False, multi_label=False): 165 | domList = self.getDomList() 166 | 167 | # collect index 168 | indList = [] 169 | for ind, domain in enumerate(domList): 170 | if domain in domainName: 171 | indList.append(ind) 172 | 173 | # sanity check 174 | dataCount = len(indList) 175 | if dataCount<1: 176 | logger.error("Empty data for domain %s", domainName) 177 | exit(1) 178 | 179 | logger.info("For domain %s, %d data is selected from %d data in the dataset.", domainName, dataCount, len(domList)) 180 | 181 | # get all data from dataset 182 | domList, labList, uttList, tokList = self.getAllData() 183 | domDomList = [domList[i] for i in indList] 184 | domLabList = [labList[i] for i in indList] 185 | domUttList = [uttList[i] for i in indList] 186 | if self.tokList: 187 | domTokList = [tokList[i] for i in indList] 188 | else: 189 | domTokList = [] 190 | domDataset = IntentDataset(domDomList, domLabList, domUttList, domTokList, regression=regression, multi_label=multi_label) 191 | 192 | return domDataset 193 | 194 | def tokenize(self, tokenizer): 195 | self.tokList = [] 196 | for u in self.uttList: 197 | ut = tokenizer(u) 198 | if 'token_type_ids' not in ut: 199 | ut['token_type_ids'] = [0]*len(ut['input_ids']) 200 | self.tokList.append(ut) 201 | 202 | def shuffle_words(self): 203 | newList = [] 204 | for u in self.uttList: 205 | replace = copy.deepcopy(u) 206 | replace = replace.split(' ') 207 | random.shuffle(replace) 208 | replace = ' '.join(replace) 209 | newList.append(replace) 210 | self.uttList = newList 211 | 212 | # convert label names to label IDs: 0, 1, 2, 3 213 | def createLabID(self): 214 | # get unique label 215 | labSet = set() 216 | for lab in self.labList: 217 | if self.multi_label: 218 | for l in lab: 219 | labSet.add(l) 220 | else: 221 | labSet.add(lab) 222 | 223 | # get number 224 | self.labNum = len(labSet) 225 | sortedLabList = list(labSet) 226 | 
        sortedLabList.sort()

        # fill up dict: lab -> labID
        self.name2LabID = {}
        for ind, lab in enumerate(sortedLabList):
            if lab not in self.name2LabID:
                self.name2LabID[lab] = ind

        # fill up the label ID list
        self.labIDList = []
        for lab in self.labList:
            if self.multi_label:
                labID = []
                for l in lab:
                    labID.append(self.name2LabID[l])
                self.labIDList.append(labID)
            else:
                self.labIDList.append(self.name2LabID[lab])

        # sanity check
        if not len(self.labIDList) == len(self.uttList):
            logger.error("createLabID error: inconsistent label ID list length and utterance list length.")
            exit(1)

    def getLabID2dataInd(self):
        if self.labID2DataInd is not None:
            return self.labID2DataInd
        else:
            self.labID2DataInd = {}
            for dataInd, labID in enumerate(self.labIDList):
                if self.multi_label:
                    for l in labID:
                        if l not in self.labID2DataInd:
                            self.labID2DataInd[l] = []
                        self.labID2DataInd[l].append(dataInd)
                else:
                    if labID not in self.labID2DataInd:
                        self.labID2DataInd[labID] = []
                    self.labID2DataInd[labID].append(dataInd)

            # sanity check
            if not self.multi_label:
                dataCount = 0
                for labID in self.labID2DataInd:
                    dataCount = dataCount + len(self.labID2DataInd[labID])
                if not dataCount == len(self.uttList):
                    logger.error("Inconsistent data counts %d and %d when generating dict labID2DataInd", dataCount, len(self.uttList))
                    exit(1)

            return self.labID2DataInd

    def getDataInd2labID(self):
        if self.dataInd2LabID is not None:
            return self.dataInd2LabID
        else:
            self.dataInd2LabID = {}
            for dataInd, labID in enumerate(self.labIDList):
                self.dataInd2LabID[dataInd] = labID
            return self.dataInd2LabID

    def convertLabs(self):
        # when the dataset is not multi-label, convert each label from a list to a single instance
        if self.labList:
            if isinstance(self.labList[0], list):
                newList = []
                for l in self.labList:
                    newList.append(l[0])
                self.labList = newList
        if self.labIDList:
            if isinstance(self.labIDList[0], list):
                newList = []
                for l in self.labIDList:
                    newList.append(l[0])
                self.labIDList = newList
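
# A minimal usage sketch (the directory and domain names are illustrative; the
# real choices live under ./data and in the experiment scripts):
#
#   ds = IntentDataset()
#   ds.loadDataset(['oos'])              # reads ./data/oos/dataset.json
#   sub = ds.splitDomain(['banking'])    # keep a single domain
#   sub.tokenize(tokenizer)              # cache BERT tokens for the samplers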
--------------------------------------------------------------------------------
/utils/Logger.py:
--------------------------------------------------------------------------------
import logging
import sys

logging.basicConfig(stream=sys.stdout, format='[%(levelname)s] %(message)s')
logger = logging.getLogger('globalLogger')

--------------------------------------------------------------------------------
/utils/TaskSampler.py:
--------------------------------------------------------------------------------
from utils.IntentDataset import IntentDataset
from utils.Logger import logger
import random
from utils.commonVar import *
import numpy as np
import copy
from sklearn.preprocessing import MultiLabelBinarizer
# random.seed(0)

# base class for task samplers:
# sample meta-tasks from a dataset for training and evaluation
class TaskSampler():
    def __init__(self, dataset: IntentDataset):
        self.dataset = dataset

    def sampleOneTask(self):
        raise NotImplementedError("sampleOneTask() is not implemented.")

class UniformTaskSampler(TaskSampler):
    def __init__(self, dataset: IntentDataset, way, shot, query):
        super(UniformTaskSampler, self).__init__(dataset)
        self.way = way
        self.shot = shot
        self.query = query
        self.taskPool = None
        self.dataset = dataset

    ##
    # @brief sample data indices for one task; global class IDs are sampled too.
    #
    # @return a dict, print it to see what's there
    def sampleClassIDsDataInd(self):
        taskInfo = {}
        glbLabNum = self.dataset.getLabNum()
        labID2DataInd = self.dataset.getLabID2dataInd()

        uniqueGlbLabIDs = list(range(glbLabNum))
        # randomly sample `way` global label IDs
        taskGlbLabIDs = random.sample(uniqueGlbLabIDs, self.way)
        taskInfo[META_TASK_GLB_LABID] = taskGlbLabIDs

        # sample data for each labID
        taskInfo[META_TASK_SHOT_GLB_LABID] = []
        taskInfo[META_TASK_QUERY_GLB_LABID] = []
        taskInfo[META_TASK_SHOT_DATAIND] = []
        taskInfo[META_TASK_QUERY_DATAIND] = []
        for labID in taskGlbLabIDs:
            # randomly sample disjoint support and query data
            dataInds = random.sample(labID2DataInd[labID], self.shot + self.query)
            random.shuffle(dataInds)
            shotDataInds = dataInds[:self.shot]
            queryDataInds = dataInds[(-self.query):]

            taskInfo[META_TASK_SHOT_GLB_LABID].extend([labID] * self.shot)
            taskInfo[META_TASK_SHOT_DATAIND].extend(shotDataInds)
            taskInfo[META_TASK_QUERY_GLB_LABID].extend([labID] * self.query)
            taskInfo[META_TASK_QUERY_DATAIND].extend(queryDataInds)

        return taskInfo

    ##
    # @brief works with sampleClassIDsDataInd(): takes the taskInfo it returns and collects the actual data for the task.
    #
    # @param taskDataInds a dict containing data indices
    #
    # @return a dict containing task data, print it to see what's there
    def collectDataForTask(self, taskDataInds):
        task = {}

        # compose local labIDs from global labIDs
        glbLabIDList = taskDataInds[META_TASK_GLB_LABID]
        glbLabID2LocLabID = {}
        for pos, glbLabID in enumerate(glbLabIDList):
            glbLabID2LocLabID[glbLabID] = pos
        tokList = self.dataset.getTokList()
        labList = self.dataset.getLabList()

        # support
        task[META_TASK_SHOT_LOC_LABID] = [glbLabID2LocLabID[glbLabID] for glbLabID in taskDataInds[META_TASK_SHOT_GLB_LABID]]
        task[META_TASK_SHOT_TOKEN] = [tokList[i] for i in taskDataInds[META_TASK_SHOT_DATAIND]]
        task[META_TASK_SHOT_LAB] = [labList[i] for i in taskDataInds[META_TASK_SHOT_DATAIND]]

        # query
        task[META_TASK_QUERY_LOC_LABID] = [glbLabID2LocLabID[glbLabID] for glbLabID in taskDataInds[META_TASK_QUERY_GLB_LABID]]
        task[META_TASK_QUERY_TOKEN] = [tokList[i] for i in taskDataInds[META_TASK_QUERY_DATAIND]]
        task[META_TASK_QUERY_LAB] = [labList[i] for i in taskDataInds[META_TASK_QUERY_DATAIND]]

        return task

    def sampleOneTask(self):
        # 1. sample classes and data indices
        taskDataInds = self.sampleClassIDsDataInd()

        # 2. according to the data indices, collect data such as tokens, lengths, label names, etc.
        task = self.collectDataForTask(taskDataInds)

        return task
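
# Illustrative episode sampling with UniformTaskSampler (assumes `sub` holds a
# tokenized IntentDataset as sketched in IntentDataset.py):
#
#   sampler = UniformTaskSampler(sub, way=5, shot=2, query=5)
#   task = sampler.sampleOneTask()
#   len(task[META_TASK_SHOT_TOKEN])    # 10 support examples = way * shot
#   task[META_TASK_QUERY_LOC_LABID]    # query labels remapped into [0, way)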

class MultiLabTaskSampler(TaskSampler):
    def __init__(self, dataset: IntentDataset, shot, query):
        super(MultiLabTaskSampler, self).__init__(dataset)
        self.shot = shot
        self.query = query
        self.taskPool = None
        self.dataset = dataset

    ##
    # @brief sample data indices for one task; global class IDs are collected too.
    #
    # @return a dict, print it to see what's there
    def sampleClassIDsDataInd(self):
        taskInfo = {}
        glbLabNum = self.dataset.getLabNum()
        labID2DataInd = self.dataset.getLabID2dataInd()

        uniqueGlbLabIDs = list(range(glbLabNum))
        # multi-label tasks use all global label IDs instead of sampling a subset
        taskInfo[META_TASK_GLB_LABID] = uniqueGlbLabIDs

        # sample data for each labID
        taskInfo[META_TASK_SHOT_GLB_LABID] = []
        taskInfo[META_TASK_QUERY_GLB_LABID] = []
        taskInfo[META_TASK_SHOT_DATAIND] = []
        taskInfo[META_TASK_QUERY_DATAIND] = []
        shotDataInds, queryDataInds = [], []
        for labID in uniqueGlbLabIDs:
            # randomly sample support data
            dataInds = random.sample(labID2DataInd[labID], self.shot)
            shotDataInds += dataInds
            # randomly sample query data from the remainder
            remain = list(set(labID2DataInd[labID]) - set(dataInds))
            dataInds = random.sample(remain, self.query)
            queryDataInds += dataInds
        shotDataInds = list(set(shotDataInds))
        queryDataInds = list(set(queryDataInds))
        shotDataInds = self.checkForDuplicate(shotDataInds, required_num=self.shot)
        queryDataInds = self.checkForDuplicate(queryDataInds, required_num=self.query)

        taskInfo[META_TASK_SHOT_DATAIND] = shotDataInds
        taskInfo[META_TASK_QUERY_DATAIND] = queryDataInds

        dataInd2LabID = self.dataset.getDataInd2labID()
        for d in shotDataInds:
            taskInfo[META_TASK_SHOT_GLB_LABID].append(dataInd2LabID[d])
        for d in queryDataInds:
            taskInfo[META_TASK_QUERY_GLB_LABID].append(dataInd2LabID[d])
        return taskInfo

    def checkForDuplicate(self, dataInds, required_num):
        dataInd2LabID = self.dataset.getDataInd2labID()
        label_lists = []
        for di in dataInds:
            label_lists.extend(dataInd2LabID[di])
        label_names, counts = np.unique(label_lists, return_counts=True)
        shot_counts = {ln: c for ln, c in zip(label_names, counts)}
        loopInds = copy.deepcopy(dataInds)
        for di in loopInds:
            can_remove = True
            for l in dataInd2LabID[di]:
                if (l in shot_counts) and (shot_counts[l] - 1 < required_num):
                    can_remove = False
            if can_remove:
                dataInds.remove(di)
                for l in dataInd2LabID[di]:
                    shot_counts[l] -= 1
        return dataInds
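
    # Worked intuition for checkForDuplicate (numbers illustrative): with
    # shot=2, an item labeled {a, b} may be dropped only if labels a and b
    # each still appear in at least 2 other selected items; the loop above
    # greedily removes such redundant items, so every label keeps at least
    # `required_num` examples while duplicates are pruned.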

    ##
    # @brief works with sampleClassIDsDataInd(): takes the taskInfo it returns and collects the actual data for the task.
    #
    # @param taskDataInds a dict containing data indices
    #
    # @return a dict containing task data, print it to see what's there
    def collectDataForTask(self, taskDataInds):
        task = {}

        # binarize the global labIDs into multi-hot local targets
        tokList = self.dataset.getTokList()
        labList = self.dataset.getLabList()

        mlb = MultiLabelBinarizer()

        # support
        task[META_TASK_SHOT_LOC_LABID] = mlb.fit_transform(taskDataInds[META_TASK_SHOT_GLB_LABID]).tolist()
        task[META_TASK_SHOT_TOKEN] = [tokList[i] for i in taskDataInds[META_TASK_SHOT_DATAIND]]
        task[META_TASK_SHOT_LAB] = [labList[i] for i in taskDataInds[META_TASK_SHOT_DATAIND]]

        # query
        task[META_TASK_QUERY_LOC_LABID] = mlb.fit_transform(taskDataInds[META_TASK_QUERY_GLB_LABID]).tolist()
        task[META_TASK_QUERY_TOKEN] = [tokList[i] for i in taskDataInds[META_TASK_QUERY_DATAIND]]
        task[META_TASK_QUERY_LAB] = [labList[i] for i in taskDataInds[META_TASK_QUERY_DATAIND]]

        return task

    def sampleOneTask(self):
        # 1. sample classes and data indices
        taskDataInds = self.sampleClassIDsDataInd()

        # 2. according to the data indices, collect data such as tokens, lengths, label names, etc.
        task = self.collectDataForTask(taskDataInds)

        return task

--------------------------------------------------------------------------------
/utils/Trainer.py:
--------------------------------------------------------------------------------
from utils.IntentDataset import IntentDataset
from utils.Evaluator import EvaluatorBase
from utils.Logger import logger
from utils.commonVar import *
from utils.tools import mask_tokens, makeTrainExamples
import time
import torch
from torch.utils.data import DataLoader
import numpy as np
import copy
from sklearn.metrics import accuracy_score, r2_score
from torch.utils.tensorboard import SummaryWriter

##
# @brief base class for trainers
class TrainerBase():
    def __init__(self):
        self.finished = False
        self.bestModelStateDict = None
        self.roundN = 4

    def round(self, floatNum):
        return round(floatNum, self.roundN)

    def train(self):
        raise NotImplementedError("train() is not implemented.")

    def getBestModelStateDict(self):
        return self.bestModelStateDict
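
# Hedged sketch of the trainingParam dict consumed by TransferTrainer below
# (the key names come from the constructor; the values are illustrative):
#
#   trainingParam = {'epoch': 10, 'batch': 32, 'validation': True,
#                    'patience': 3, 'tensorboard': False, 'mlm': True,
#                    'lambda mlm': 1.0, 'regression': False}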

##
# @brief TransferTrainer performs transfer training. Training is fully supervised and uses all available data at once; by contrast, meta-training proceeds task by task.
class TransferTrainer(TrainerBase):
    def __init__(self,
                 trainingParam: dict,
                 optimizer,
                 dataset: IntentDataset,
                 unlabeled: IntentDataset,
                 valEvaluator: EvaluatorBase,
                 testEvaluator: EvaluatorBase):
        super(TransferTrainer, self).__init__()
        self.epoch = trainingParam['epoch']
        self.batch_size = trainingParam['batch']
        self.validation = trainingParam['validation']
        self.patience = trainingParam['patience']
        self.tensorboard = trainingParam['tensorboard']
        self.mlm = trainingParam['mlm']
        self.lambda_mlm = trainingParam['lambda mlm']
        self.regression = trainingParam['regression']

        self.dataset = dataset
        self.unlabeled = unlabeled
        self.optimizer = optimizer
        self.valEvaluator = valEvaluator
        self.testEvaluator = testEvaluator

        if self.tensorboard:
            self.writer = SummaryWriter()

    def train(self, model, tokenizer, mode='multi-class'):
        self.bestModelStateDict = copy.deepcopy(model.state_dict())
        durationOverallTrain = 0.0
        durationOverallVal = 0.0
        valBestAcc = -1
        accumulateStep = 0

        # evaluate before training
        valAcc, valPre, valRec, valFsc = self.valEvaluator.evaluate(model, tokenizer, mode)
        teAcc, tePre, teRec, teFsc = self.testEvaluator.evaluate(model, tokenizer, mode)
        logger.info('---- Before training ----')
        logger.info("ValAcc %f, Val pre %f, Val rec %f , Val Fsc %f", valAcc, valPre, valRec, valFsc)
        logger.info("TestAcc %f, Test pre %f, Test rec %f, Test Fsc %f", teAcc, tePre, teRec, teFsc)

        if mode == 'multi-class':
            labTensorData = makeTrainExamples(self.dataset.getTokList(), tokenizer, self.dataset.getLabID(), mode=mode)
        else:
            logger.error("Invalid mode %s", mode)
        dataloader = DataLoader(labTensorData, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

        if self.mlm:
            unlabTensorData = makeTrainExamples(self.unlabeled.getTokList(), tokenizer, mode='unlabel')
            unlabeledloader = DataLoader(unlabTensorData, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)
            unlabelediter = iter(unlabeledloader)

        for epoch in range(self.epoch):  # one full pass over the supervised training data
            model.train()
            batchTrAccSum = 0.0
            batchTrLossSPSum = 0.0
            batchTrLossMLMSum = 0.0
            timeEpochStart = time.time()

            for batch in dataloader:
                # task data
                Y, ids, types, masks = batch
                X = {'input_ids': ids.to(model.device),
                     'token_type_ids': types.to(model.device),
                     'attention_mask': masks.to(model.device)}

                # forward
                logits = model(X)
                # loss
                if self.regression:
                    lossSP = model.loss_mse(logits, Y.to(model.device))
                else:
                    lossSP = model.loss_ce(logits, Y.to(model.device))

                if self.mlm:
                    # draw an unlabeled batch, restarting the iterator when exhausted
                    try:
                        ids, types, masks = next(unlabelediter)
                    except StopIteration:
                        unlabelediter = iter(unlabeledloader)
                        ids, types, masks = next(unlabelediter)
                    X_un = {'input_ids': ids.to(model.device),
                            'token_type_ids': types.to(model.device),
                            'attention_mask': masks.to(model.device)}
                    mask_ids, mask_lb = mask_tokens(X_un['input_ids'].cpu(), tokenizer)
                    X_un = {'input_ids': mask_ids.to(model.device),
                            'token_type_ids': X_un['token_type_ids'],
                            'attention_mask': X_un['attention_mask']}
                    lossMLM = model.mlmForward(X_un, mask_lb.to(model.device))
                    lossTOT = lossSP + self.lambda_mlm * lossMLM
                else:
                    lossTOT = lossSP

                # backward
                self.optimizer.zero_grad()
                lossTOT.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                self.optimizer.step()

                # calculate train accuracy
                YTensor = Y.cpu()
                logits = logits.detach().clone()
                if torch.cuda.is_available():
                    logits = logits.cpu()
                if self.regression:
                    predictResult = torch.sigmoid(logits).numpy()
                    acc = r2_score(YTensor, predictResult)
                else:
                    logits = logits.numpy()
                    predictResult = np.argmax(logits, 1)
                    acc = accuracy_score(YTensor, predictResult)

                # accumulate statistics
                batchTrAccSum += acc
                batchTrLossSPSum += lossSP.item()
                if self.mlm:
                    batchTrLossMLMSum += lossMLM.item()

            # current epoch's training done, collect statistics
            durationTrain = self.round(time.time() - timeEpochStart)
            durationOverallTrain += durationTrain
            batchTrAccAvrg = self.round(batchTrAccSum / len(dataloader))
            batchTrLossSPAvrg = batchTrLossSPSum / len(dataloader)
            batchTrLossMLMAvrg = batchTrLossMLMSum / len(dataloader)

            valAcc, valPre, valRec, valFsc = self.valEvaluator.evaluate(model, tokenizer, mode)
            teAcc, tePre, teRec, teFsc = self.testEvaluator.evaluate(model, tokenizer, mode)

            # display the current epoch's info
            logger.info("---- epoch: %d/%d, train_time %f ----", epoch, self.epoch, durationTrain)
            logger.info("SPLoss %f, MLMLoss %f, TrainAcc %f", batchTrLossSPAvrg, batchTrLossMLMAvrg, batchTrAccAvrg)
            logger.info("ValAcc %f, Val pre %f, Val rec %f , Val Fsc %f", valAcc, valPre, valRec, valFsc)
            logger.info("TestAcc %f, Test pre %f, Test rec %f, Test Fsc %f", teAcc, tePre, teRec, teFsc)
            if self.tensorboard:
                self.writer.add_scalar('train loss', batchTrLossSPAvrg + self.lambda_mlm * batchTrLossMLMAvrg, global_step=epoch)
                self.writer.add_scalar('val acc', valAcc, global_step=epoch)
                self.writer.add_scalar('test acc', teAcc, global_step=epoch)

            # early stopping
            if not self.validation:
                valAcc = -1
            if valAcc >= valBestAcc:  # better validation result
                print("[INFO] Found a better model. Val acc: %f -> %f" % (valBestAcc, valAcc))
                valBestAcc = valAcc
                accumulateStep = 0

                # cache the current model, used for evaluation later
                self.bestModelStateDict = copy.deepcopy(model.state_dict())
            else:
                accumulateStep += 1
                if accumulateStep > self.patience / 2:
                    print('[INFO] accumulateStep: ', accumulateStep)
                if accumulateStep == self.patience:  # early stop
                    logger.info('Early stop.')
                    logger.debug("Overall training time %f", durationOverallTrain)
                    logger.debug("Overall validation time %f", durationOverallVal)
                    logger.debug("best_val_acc: %f", valBestAcc)
                    break

        logger.info("best_val_acc: %f", valBestAcc)
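
# The joint objective above is the paper's IntentBERT+MLM scheme: supervised
# cross-entropy on labeled source-domain data plus a weighted MLM loss on
# unlabeled target-domain utterances, lossTOT = lossSP + lambda_mlm * lossMLM.
# A hedged wiring sketch (the AdamW choice is illustrative, not prescribed here):
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
#   trainer = TransferTrainer(trainingParam, optimizer, labeledData,
#                             unlabeledData, valEvaluator, testEvaluator)
#   trainer.train(model, tokenizer, mode='multi-class')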

##
# @brief MLMOnlyTrainer trains with the masked-language-modeling objective alone; no intent labels are used.
class MLMOnlyTrainer(TrainerBase):
    def __init__(self,
                 trainingParam: dict,
                 optimizer,
                 dataset: IntentDataset,
                 unlabeled: IntentDataset,
                 testEvaluator: EvaluatorBase):
        super(MLMOnlyTrainer, self).__init__()
        self.epoch = trainingParam['epoch']
        self.batch_size = trainingParam['batch']
        self.tensorboard = trainingParam['tensorboard']

        self.dataset = dataset
        self.unlabeled = unlabeled
        self.optimizer = optimizer
        self.testEvaluator = testEvaluator

        if self.tensorboard:
            self.writer = SummaryWriter()

    def train(self, model, tokenizer):
        durationOverallTrain = 0.0

        # evaluate before training
        teAcc, tePre, teRec, teFsc = self.testEvaluator.evaluate(model, tokenizer, 'multi-class')
        logger.info('---- Before training ----')
        logger.info("TestAcc %f, Test pre %f, Test rec %f, Test Fsc %f", teAcc, tePre, teRec, teFsc)

        labTensorData = makeTrainExamples(self.dataset.getTokList(), tokenizer, mode='unlabel')
        dataloader = DataLoader(labTensorData, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)

        for epoch in range(self.epoch):  # one full pass over the unlabeled data
            model.train()
            batchTrLossSum = 0.0
            timeEpochStart = time.time()

            for batch in dataloader:
                # task data
                ids, types, masks = batch
                X = {'input_ids': ids.to(model.device),
                     'token_type_ids': types.to(model.device),
                     'attention_mask': masks.to(model.device)}

                # forward with random token masking
                mask_ids, mask_lb = mask_tokens(X['input_ids'].cpu(), tokenizer)
                X = {'input_ids': mask_ids.to(model.device),
                     'token_type_ids': X['token_type_ids'],
                     'attention_mask': X['attention_mask']}
                loss = model.mlmForward(X, mask_lb.to(model.device))
                batchTrLossSum += loss.item()  # accumulate so the epoch average below is meaningful

                # backward
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                self.optimizer.step()

            durationTrain = self.round(time.time() - timeEpochStart)
            durationOverallTrain += durationTrain
            batchTrLossAvrg = batchTrLossSum / len(dataloader)

            teAcc, tePre, teRec, teFsc = self.testEvaluator.evaluate(model, tokenizer, 'multi-class')

            # display the current epoch's info
            logger.info("---- epoch: %d/%d, train_time %f ----", epoch, self.epoch, durationTrain)
            logger.info("TrainLoss %f", batchTrLossAvrg)
            logger.info("TestAcc %f, Test pre %f, Test rec %f, Test Fsc %f", teAcc, tePre, teRec, teFsc)
            if self.tensorboard:
                self.writer.add_scalar('train loss', batchTrLossAvrg, global_step=epoch)
                self.writer.add_scalar('test acc', teAcc, global_step=epoch)
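
# Hedged wiring sketch for MLM-only training (MLMOnlyTrainer reads only
# 'epoch', 'batch' and 'tensorboard' from trainingParam; passing the unlabeled
# dataset in both slots is an illustrative choice):
#
#   trainer = MLMOnlyTrainer(trainingParam, optimizer, unlabeledData,
#                            unlabeledData, testEvaluator)
#   trainer.train(model, tokenizer)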

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/utils/__init__.py
--------------------------------------------------------------------------------
/utils/commonVar.py:
--------------------------------------------------------------------------------
# embedding names
LM_NAME_FASTTEXT = "fasttext"
LM_NAME_BERT_BASE_UNCASED = 'bert-base-uncased'

# logging levels
LOGGING_LEVEL_CRITICAL = "CRITICAL"
LOGGING_LEVEL_ERROR = "ERROR"
LOGGING_LEVEL_WARNING = "WARNING"
LOGGING_LEVEL_INFO = "INFO"
LOGGING_LEVEL_DEBUG = "DEBUG"
LOGGING_LEVEL_NOTSET = "NOTSET"

# dir and file names
FILE_NAME_DATASET = "dataset.json"

# meta-task information keys
META_TASK_GLB_LABID = "META_TASK_GLB_LABID"
META_TASK_SHOT_GLB_LABID = "META_TASK_SHOT_GLB_LABID"
META_TASK_SHOT_LOC_LABID = "META_TASK_SHOT_LOC_LABID"
META_TASK_SHOT_DATAIND = "META_TASK_SHOT_DATAIND"
META_TASK_SHOT_TOKEN = "META_TASK_SHOT_TOKEN"
META_TASK_SHOT_LAB = "META_TASK_SHOT_LAB"
META_TASK_QUERY_GLB_LABID = "META_TASK_QUERY_GLB_LABID"
META_TASK_QUERY_LOC_LABID = "META_TASK_QUERY_LOC_LABID"
META_TASK_QUERY_DATAIND = "META_TASK_QUERY_DATAIND"
META_TASK_QUERY_TOKEN = "META_TASK_QUERY_TOKEN"
META_TASK_QUERY_LAB = "META_TASK_QUERY_LAB"

# classifier names for validation and (meta-)evaluation
CLSFIER_LINEAR_REGRESSION = "Linear"
CLSFIER_SVM = "SVM"
CLSFIER_NN = "NN"
CLSFIER_COSINE = "Cosine"
CLSFIER_MULTI_LABEL = "MultiLabel"

# optimizer names
OPTER_ADAM = "Adam"
OPTER_SGD = "SGD"

# paths
SAVE_PATH = './saved_models'
DATA_PATH = './data'
--------------------------------------------------------------------------------
/utils/models.py:
--------------------------------------------------------------------------------
#coding=utf-8
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForMaskedLM
from utils.commonVar import *

from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

class IntentBERT(nn.Module):
    def __init__(self, config):
        super(IntentBERT, self).__init__()
        self.device = config['device']
        self.LMName = config['LMName']
        self.clsNum = config['clsNumber']
        try:
            # first try the name as a Hugging Face model id
            self.word_embedding = AutoModelForMaskedLM.from_pretrained(self.LMName)
        except Exception:
            # fall back to a locally saved checkpoint
            self.word_embedding = AutoModelForMaskedLM.from_pretrained(os.path.join(SAVE_PATH, self.LMName))
        self.linearClsfier = nn.Linear(768, self.clsNum)
        self.dropout = nn.Dropout(0.1)  # follow the default in the BERT model
        # self.word_embedding = nn.DataParallel(self.word_embedding)
        self.word_embedding.to(self.device)
        self.linearClsfier.to(self.device)

    def loss_ce(self, logits, Y):
        loss = nn.CrossEntropyLoss()
        output = loss(logits, Y)
        return output

    def loss_mse(self, logits, Y):
        loss = nn.MSELoss()
        output = loss(torch.sigmoid(logits).squeeze(), Y)
        return output

    def loss_kl(self, logits, label):
        # KL-divergence loss
        probs = F.log_softmax(logits, dim=1)
        # label_probs = F.log_softmax(label, dim=1)
        loss = F.kl_div(probs, label, reduction='batchmean')
        return loss

    def forward(self, X):
        # BERT forward
        outputs = self.word_embedding(**X, output_hidden_states=True)

        # extract [CLS] as the utterance representation
        CLSEmbedding = outputs.hidden_states[-1][:, 0]

        # linear classifier
        CLSEmbedding = self.dropout(CLSEmbedding)
        logits = self.linearClsfier(CLSEmbedding)

        return logits

    def mlmForward(self, X, Y):
        # BERT forward with MLM labels
        outputs = self.word_embedding(**X, labels=Y)

        return outputs.loss

    def fewShotPredict(self, supportX, supportY, queryX, clsFierName, mode='multi-class'):
        # calculate word embeddings with a BERT forward pass
        s_embedding = self.word_embedding(**supportX, output_hidden_states=True).hidden_states[-1]
        q_embedding = self.word_embedding(**queryX, output_hidden_states=True).hidden_states[-1]

        # extract [CLS] as the utterance representation
        supportEmbedding = s_embedding[:, 0]
        queryEmbedding = q_embedding[:, 0]
        support_features = self.normalize(supportEmbedding).cpu()
        query_features = self.normalize(queryEmbedding).cpu()

        # select a classifier and fit it on the support set
        clf = None
        if clsFierName == CLSFIER_LINEAR_REGRESSION:
            clf = LogisticRegression(penalty='l2',
                                     random_state=0,
                                     C=1.0,
                                     solver='lbfgs',
                                     max_iter=1000,
                                     multi_class='multinomial')
            clf.fit(support_features, supportY)
        elif clsFierName == CLSFIER_SVM:
            clf = make_pipeline(StandardScaler(),
                                SVC(gamma='auto', C=1,
                                    kernel='linear',
                                    decision_function_shape='ovr'))
            clf.fit(support_features, supportY)
        elif clsFierName == CLSFIER_MULTI_LABEL:
            clf = MultiOutputClassifier(LogisticRegression(penalty='l2',
                                                           random_state=0,
                                                           C=1.0,
                                                           solver='liblinear',
                                                           max_iter=1000,
                                                           multi_class='ovr',
                                                           class_weight='balanced'))
            clf.fit(support_features, supportY)
        else:
            raise NotImplementedError("Unsupported classifier name %s" % clsFierName)

        if mode == 'multi-class':
            query_pred = clf.predict(query_features)
        else:
            raise ValueError("Invalid mode %s" % mode)

        return query_pred
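
    # Hedged inference sketch: pad a sampled task with utils.tools.makeEvalExamples,
    # move the token tensors to the model device, then classify queries from support:
    #
    #   sX, sY, qX, qY = makeEvalExamples(task[META_TASK_SHOT_TOKEN],
    #                                     task[META_TASK_SHOT_LOC_LABID],
    #                                     task[META_TASK_QUERY_TOKEN],
    #                                     task[META_TASK_QUERY_LOC_LABID], tokenizer)
    #   pred = model.fewShotPredict(sX.to(model.device), sY,
    #                               qX.to(model.device), CLSFIER_LINEAR_REGRESSION)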

    def reinit_clsfier(self):
        self.linearClsfier.weight.data.normal_(mean=0.0, std=0.02)
        self.linearClsfier.bias.data.zero_()

    def set_dropout_layer(self, dropout_rate):
        self.dropout = nn.Dropout(dropout_rate)

    def set_linear_layer(self, clsNum):
        self.linearClsfier = nn.Linear(768, clsNum)

    def normalize(self, x):
        # L2-normalize each row
        norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2)
        out = x.div(norm)
        return out

    def NN(self, support, support_ys, query):
        """nearest-neighbor classifier"""
        support = np.expand_dims(support.transpose(), 0)
        query = np.expand_dims(query, 2)

        diff = np.multiply(query - support, query - support)
        distance = diff.sum(1)
        min_idx = np.argmin(distance, axis=1)
        pred = [support_ys[idx] for idx in min_idx]
        return pred

    def CosineClsfier(self, support, support_ys, query):
        """cosine-similarity classifier"""
        support_norm = np.linalg.norm(support, axis=1, keepdims=True)
        support = support / support_norm
        query_norm = np.linalg.norm(query, axis=1, keepdims=True)
        query = query / query_norm

        cosine_distance = query @ support.transpose()
        max_idx = np.argmax(cosine_distance, axis=1)
        pred = [support_ys[idx] for idx in max_idx]
        return pred

    def save(self, path):
        self.word_embedding.save_pretrained(path)

--------------------------------------------------------------------------------
/utils/printHelper.py:
--------------------------------------------------------------------------------
from utils.Logger import logger
import logging

##
# @brief print mean values, std values and item names
#
# @param meanList: example [1.1, 2.2, 1.2]
# @param stdList: example [0.1, 0.15, 0.001]
# @param itemList: example ['acc', 'pre', 'recall', 'fsc']
# @param debugLevel: example logging.INFO
#
# @return
def printMeanStd(meanList, stdList, itemList, debugLevel=logging.INFO):
    # select the logging function
    loggingFunc = None
    if debugLevel == logging.INFO:
        loggingFunc = logger.info
    elif debugLevel == logging.DEBUG:
        loggingFunc = logger.debug
    else:
        raise NotImplementedError("Unsupported logging level.")

    lengthSet = set()
    lengthSet.add(len(meanList))
    lengthSet.add(len(stdList))
    lengthSet.add(len(itemList))
    if not len(lengthSet) == 1:
        logger.error("Inconsistent list lengths when printing statistics.")
        exit(1)

    for mean, std, item in zip(meanList, stdList, itemList):
        loggingFunc("%-6s: %f +- %f", item, mean, std)


##
# @brief print mean values and item names
#
# @param meanList: example [1.1, 2.2, 1.2]
# @param itemList: example ['acc', 'pre', 'recall', 'fsc']
# @param debugLevel: example logging.INFO
#
# @return
def printMean(meanList, itemList, debugLevel=logging.INFO):
    # select the logging function
    loggingFunc = None
    if debugLevel == logging.INFO:
        loggingFunc = logger.info
    elif debugLevel == logging.DEBUG:
        loggingFunc = logger.debug
    else:
        raise NotImplementedError("Unsupported logging level.")

    lengthSet = set()
    lengthSet.add(len(meanList))
    lengthSet.add(len(itemList))
    if not len(lengthSet) == 1:
        logger.error("Inconsistent list lengths when printing statistics.")
        exit(1)

    for mean, item in zip(meanList, itemList):
        loggingFunc("%-6s: %f", item, mean)
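
# Example (the numbers are illustrative):
#   printMeanStd([0.872, 0.861], [0.013, 0.017], ['acc', 'fsc'])
# logs lines like "[INFO] acc   : 0.872000 +- 0.013000".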

--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data import TensorDataset
import numpy as np
from random import sample
import random
random.seed(0)

entail_label_map = {'entail': 1, 'nonentail': 0}

def getDomainName(name):
    # map domain identifiers to human-readable names
    if name == "auto_commute":
        return "auto commute"
    elif name == "credit_cards":
        return "credit cards"
    elif name == "kitchen_dining":
        return "kitchen dining"
    elif name == "small_talk":
        return "small talk"
    elif ' ' not in name:
        return name
    else:
        raise NotImplementedError("Unsupported domain name %s" % name)

def splitName(dom):
    domList = []
    for name in dom.split(','):
        domList.append(getDomainName(name))
    return domList

def makeTrainExamples(data: list, tokenizer, label=None, mode='unlabel'):
    """
    unlabel: simply pad the data and convert it into tensors
    multi-class: pad the data and compose a tensor dataset together with labels
    """
    if mode != "unlabel":
        assert label is not None, f"A label must be provided for the setting {mode}"
        if mode == "multi-class":
            examples = tokenizer.pad(data, padding='longest', return_tensors='pt')
            if not isinstance(label, torch.Tensor):
                label = torch.tensor(label)
            examples = TensorDataset(label,
                                     examples['input_ids'],
                                     examples['token_type_ids'],
                                     examples['attention_mask'])
        else:
            raise ValueError(f"Undefined setting {mode}")
    else:
        examples = tokenizer.pad(data, padding='longest', return_tensors='pt')
        examples = TensorDataset(examples['input_ids'],
                                 examples['token_type_ids'],
                                 examples['attention_mask'])
    return examples

def makeEvalExamples(supportX, supportY, queryX, queryY, tokenizer, mode='multi-class'):
    """
    multi-class: simply pad the data
    """
    if mode == "multi-class":
        supportX = tokenizer.pad(supportX, padding='longest', return_tensors='pt')
        queryX = tokenizer.pad(queryX, padding='longest', return_tensors='pt')
    else:
        raise ValueError("Invalid mode %s." % mode)
    return supportX, supportY, queryX, queryY
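
# Hedged batching example (DataLoader is from torch.utils.data; the batch size
# is illustrative): wrap tokenized data into tensors for supervised training.
#
#   examples = makeTrainExamples(ds.getTokList(), tokenizer, ds.getLabID(),
#                                mode='multi-class')
#   loader = DataLoader(examples, batch_size=32, shuffle=True)
#   Y, ids, types, masks = next(iter(loader))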

# https://github.com/huggingface/transformers/blob/master/src/transformers/data/data_collator.py#L70
def mask_tokens(inputs, tokenizer,
                special_tokens_mask=None, mlm_probability=0.15):
    """
    Prepare masked token inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
    """
    labels = inputs.clone()
    # sample a few tokens in each sequence for MLM training (with probability `mlm_probability`)
    probability_matrix = torch.full(labels.shape, mlm_probability)
    if special_tokens_mask is None:
        special_tokens_mask = [
            tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
    else:
        special_tokens_mask = special_tokens_mask.bool()

    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    probability_matrix[torch.where(inputs == 0)] = 0.0  # never mask padding (token id 0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100  # we only compute loss on masked tokens

    # 80% of the time, replace masked input tokens with tokenizer.mask_token ([MASK])
    indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
    inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

    # 10% of the time, replace masked input tokens with a random word
    indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
    random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # the rest of the time (10%) the masked input tokens are kept unchanged
    return inputs, labels
--------------------------------------------------------------------------------
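# Hedged demo of mask_tokens from utils/tools.py above: roughly 15% of
# non-special, non-padding positions are selected; `labels` keeps the original
# ids there and -100 everywhere else, so the MLM loss only covers masked slots.
# Note mask_tokens modifies `inputs` in place as well as returning it.
#
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained('bert-base-uncased')
#   batch = tok(['book a table for two'], return_tensors='pt')
#   inputs, labels = mask_tokens(batch['input_ids'], tok)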