├── .DS_Store
├── README.md
├── data
│   ├── .DS_Store
│   ├── bank77
│   │   ├── dataset.json
│   │   ├── original
│   │   │   ├── LICENSE
│   │   │   ├── README.md
│   │   │   ├── banking_data
│   │   │   │   ├── categories.json
│   │   │   │   ├── test.csv
│   │   │   │   └── train.csv
│   │   │   ├── polyai-logo.png
│   │   │   └── span_extraction
│   │   │       ├── dstc8
│   │   │       │   ├── Buses_1
│   │   │       │   │   ├── test.json
│   │   │       │   │   ├── train_0.json
│   │   │       │   │   ├── train_1.json
│   │   │       │   │   ├── train_2.json
│   │   │       │   │   ├── train_3.json
│   │   │       │   │   ├── train_4.json
│   │   │       │   │   └── train_5.json
│   │   │       │   ├── Events_1
│   │   │       │   │   ├── test.json
│   │   │       │   │   ├── train_0.json
│   │   │       │   │   ├── train_1.json
│   │   │       │   │   ├── train_2.json
│   │   │       │   │   ├── train_3.json
│   │   │       │   │   ├── train_4.json
│   │   │       │   │   └── train_5.json
│   │   │       │   ├── Homes_1
│   │   │       │   │   ├── test.json
│   │   │       │   │   ├── train_0.json
│   │   │       │   │   ├── train_1.json
│   │   │       │   │   ├── train_2.json
│   │   │       │   │   ├── train_3.json
│   │   │       │   │   ├── train_4.json
│   │   │       │   │   └── train_5.json
│   │   │       │   ├── RentalCars_1
│   │   │       │   │   ├── test.json
│   │   │       │   │   ├── train_0.json
│   │   │       │   │   ├── train_1.json
│   │   │       │   │   ├── train_2.json
│   │   │       │   │   ├── train_3.json
│   │   │       │   │   ├── train_4.json
│   │   │       │   │   └── train_5.json
│   │   │       │   └── stats.csv
│   │   │       └── restaurant8k
│   │   │           ├── test.json
│   │   │           ├── train_0.json
│   │   │           ├── train_1.json
│   │   │           ├── train_2.json
│   │   │           ├── train_3.json
│   │   │           ├── train_4.json
│   │   │           ├── train_5.json
│   │   │           ├── train_6.json
│   │   │           ├── train_7.json
│   │   │           └── train_8.json
│   │   └── showDataset.py
│   ├── hint3
│   │   ├── dataset.json
│   │   ├── original
│   │   │   ├── readme.rd
│   │   │   └── v2
│   │   │       ├── test
│   │   │       │   ├── curekart_test.csv
│   │   │       │   ├── powerplay11_test.csv
│   │   │       │   └── sofmattress_test.csv
│   │   │       └── train
│   │   │           ├── curekart_train.csv
│   │   │           ├── powerplay11_train.csv
│   │   │           └── sofmattress_train.csv
│   │   └── showDataset.py
│   ├── hwu64
│   │   ├── dataset.json
│   │   ├── original
│   │   │   └── NLU-Data-Home-Domain-Annotated-All.csv
│   │   └── showDataset.py
│   ├── mcid
│   │   ├── dataset.json
│   │   ├── original
│   │   │   ├── README
│   │   │   ├── de
│   │   │   │   ├── eval.tsv
│   │   │   │   ├── test.tsv
│   │   │   │   └── train.tsv
│   │   │   ├── en
│   │   │   │   ├── eval.tsv
│   │   │   │   ├── test.tsv
│   │   │   │   └── train.tsv
│   │   │   ├── es
│   │   │   │   ├── eval.tsv
│   │   │   │   ├── test.tsv
│   │   │   │   └── train.tsv
│   │   │   ├── fr
│   │   │   │   ├── eval.tsv
│   │   │   │   ├── test.tsv
│   │   │   │   └── train.tsv
│   │   │   └── spanglish
│   │   │       └── test.tsv
│   │   └── showDataset.py
│   └── oos
│       ├── dataset.json
│       ├── original
│       │   ├── domain_intent.txt
│       │   └── oos-eval-master
│       │       ├── .gitignore
│       │       ├── LICENSE
│       │       ├── README.md
│       │       ├── clinc_logo.png
│       │       ├── data
│       │       │   ├── all_wiki_sents.txt
│       │       │   ├── binary_undersample.json
│       │       │   ├── binary_wiki_aug.json
│       │       │   ├── data_full.json
│       │       │   ├── data_imbalanced.json
│       │       │   ├── data_oos_plus.json
│       │       │   └── data_small.json
│       │       ├── hyperparameters.csv
│       │       ├── paper.pdf
│       │       ├── poster.pdf
│       │       └── supplementary.pdf
│       └── showDataset.py
├── eval.py
├── images
│   ├── combined.png
│   └── main.png
├── mlm.py
├── scripts
│   ├── eval.sh
│   ├── mlm.sh
│   └── transfer.sh
├── transfer.py
└── utils
    ├── Evaluator.py
    ├── IntentDataset.py
    ├── Logger.py
    ├── TaskSampler.py
    ├── Trainer.py
    ├── __init__.py
    ├── commonVar.py
    ├── models.py
    ├── printHelper.py
    └── tools.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## IntentBERT: Effectiveness of Pre-training for Few-shot Intent Classification
2 |
3 | This repository contains the code and pre-trained models for our *Findings of EMNLP* paper: [Effectiveness of Pre-training for Few-shot Intent Classification](https://arxiv.org/abs/2109.05782). This readme is adapted from this [repo](https://github.com/princeton-nlp/SimCSE).
4 |
5 | ## Quick Links
6 |
7 | - [Overview](#overview)
8 | - [Train IntentBERT](#train-intentbert)
9 | - [Requirements](#requirements)
10 | - [Dataset](#dataset)
11 | - [Before running](#before-running)
12 | - [Evaluation](#evaluation)
13 | - [Training](#training)
14 | - [Bugs or Questions?](#bugs-or-questions)
15 | - [Citation](#citation)
16 |
17 | ## Overview
18 |
19 | * Is it possible to learn transferable task-specific knowledge that generalizes across different domains for intent detection?
20 |
21 | In this paper, we offer a free-lunch solution for few-shot intent detection by pre-training on a large, publicly available dataset. Our experiments show significant improvements over previous pre-trained models on drastically different target domains, which indicates that IntentBERT has high generalizability and is a ready-to-use model that needs no further fine-tuning. We also propose a joint pre-training scheme (IntentBERT+MLM) to leverage unlabeled data in the target domain.
22 |
23 |
31 | ## Train IntentBERT
32 |
33 | In the following section, we describe how to train an IntentBERT model using our code.
34 |
35 | ### Requirements
36 |
37 | Run the following command to install the dependencies,
38 |
39 | ```bash
40 | pip install -r requirements.txt
41 | ```
42 |
43 | ### Dataset
44 |
45 | We provide the datasets required for training and evaluation in the `data` folder of this repo. Specifically, "oos" and "hwu64" are used for training and validation, while "bank77", "mcid" and "hint3" are used as target datasets. Each dataset folder contains a `showDataset.py`; you can `cd` into the folder and run it to display statistics and examples of the dataset.
46 |
47 | ```bash
48 | python showDataset.py
49 | ```
50 |
51 | ### Before running
52 |
53 | Set up the path for data and models in `./utils/commonVar.py` as you wish. For example,
54 |
55 | ```python
56 | SAVE_PATH = './saved_models'
57 | DATA_PATH = './data'
58 | ```
59 |
60 | Download the pre-trained IntentBERT model [here](https://1drv.ms/u/s!AsY5oOBeNeY-hCRMnhQQPojqdK8R?e=Ixz4ke) and save it under `SAVE_PATH`. The scripts for running experiments are kept in `./scripts`. Run a script with the argument `debug` for debug mode or `normal` for experiment mode. All outputs are saved to a log file under `./log`.
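For example, to launch the evaluation script in debug mode on GPU 0 (assuming the scripts follow the `<mode> <cuda_id>` convention shown below):

```bash
./scripts/eval.sh debug 0
```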
61 |
62 | ### Evaluation
63 | Code for few-shot evaluation is kept in `eval.py` with a corresponding bash script in `./scripts`.
64 |
65 | Run with the default parameters as,
66 | ```bash
67 | ./scripts/eval.sh normal ${cuda_id}
68 | ```
69 |
70 | Necessary arguments for the evaluation script are as follows (a sample invocation is shown after the list),
71 |
72 | * `--dataDir`: Directory for evaluation data
73 | * `--targetDomain`: Target domain name for evaluation
74 | * `--shot`: Shot number for each class
75 | * `--LMName`: Language model name to be evaluated. Can be a model name on the Hugging Face hub or a directory under `SAVE_PATH`
76 |
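For reference, a direct call to `eval.py` might look like the following; the directory, domain and shot values here are placeholders, so check `./scripts/eval.sh` for the exact settings:

```bash
python eval.py \
    --dataDir bank77 \
    --targetDomain BANKING77 \
    --shot 2 \
    --LMName intent-bert-base-uncased
```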
77 | Change `LMName` to evaluate our provided pre-trained models. We provide four trained models under `./saved_models`:
78 | 1. intent-bert-base-uncased
79 | 2. joint-intent-bert-base-uncased-hint3
80 | 3. joint-intent-bert-base-uncased-bank77
81 | 4. joint-intent-bert-base-uncased-mcid
82 |
83 | They correspond to 'IntentBERT (OOS)', 'IntentBERT (OOS)+MLM' (on hint3), 'IntentBERT (OOS)+MLM' (on bank77) and 'IntentBERT (OOS)+MLM' (on mcid) in the paper.
84 |
85 |
86 | ### Training
87 |
88 | Both supervised pre-training and joint pre-training are run via `transfer.py` with a corresponding script in `./scripts`. Important arguments are shown here (a sample run follows the list),
89 | * `--dataDir`: Directories of training, validation and test data, concatenated with ","
90 | * `--sourceDomain`: Source domain names for training, concatenated with ","
91 | * `--valDomain`: Validation domain names, concatenated with ","
92 | * `--targetDomain`: Target domain names for evaluation, concatenated with ","
93 | * `--shot`: Shot number for each class
94 | * `--tensorboard`: Enable tensorboard
95 | * `--saveModel`: Enable saving the model
96 | * `--saveName`: The name for the saved model, or "none" to use the default name
97 | * `--validation`: Enable validation; it is turned off automatically when using joint pre-training
98 | * `--mlm`: Enable MLM; turn it on when using joint pre-training
99 | * `--LMName`: Language model name used as initialization. Can be a model name on the Hugging Face hub or a directory under `SAVE_PATH`
100 |
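As with evaluation, training is normally launched through its script, which follows the same `<mode> <cuda_id>` convention:

```bash
./scripts/transfer.sh normal ${cuda_id}
```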
101 | Note that the results might differ from the reported ones by 1~3% when training with different seeds.
102 |
103 | **Supervised Pre-training**
104 |
105 | Turn off `mlm` and turn on `validation`. Change the datasets and domain names for different settings.
106 |
107 | **Joint Pre-training**
108 |
109 | Turn on `mlm`; `validation` will be turned off automatically. Change the datasets and domain names for different settings.
110 |
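To summarize the two modes, here is a minimal sketch of direct `transfer.py` calls; the dataset, domain and model values are placeholders:

```bash
# supervised pre-training: validation on, MLM off
python transfer.py --validation --saveModel \
    --dataDir oos --sourceDomain "..." --valDomain "..." --targetDomain "..." \
    --LMName bert-base-uncased

# joint pre-training: MLM on, validation is disabled automatically
python transfer.py --mlm --saveModel \
    --dataDir oos,bank77 --sourceDomain "..." --targetDomain "..." \
    --LMName bert-base-uncased
```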
111 | ## Bugs or questions?
112 |
113 | If you have any questions related to the code or the paper, feel free to email Haode (`haode.zhang@connect.polyu.hk`) or Yuwei (`zhangyuwei.work@gmail.com`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please describe the problem in detail so that we can help you better and faster!
114 |
115 | ## Citation
116 |
117 | Please cite our paper if you use IntentBERT in your work:
118 |
119 | ```bibtex
120 | @article{zhang2021effectiveness,
121 | title={Effectiveness of Pre-training for Few-shot Intent Classification},
122 | author={Haode Zhang and Yuwei Zhang and Li-Ming Zhan and Jiaxin Chen and Guangyuan Shi and Xiao-Ming Wu and Albert Y. S. Lam},
123 | journal={arXiv preprint arXiv:2109.05782},
124 | year={2021}
125 | }
126 | ```
127 |
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/.DS_Store
--------------------------------------------------------------------------------
/data/bank77/original/README.md:
--------------------------------------------------------------------------------
1 | [![PolyAI](polyai-logo.png)](https://poly-ai.com/)
2 |
3 | # task-specific-datasets
4 |
5 | *A collection of NLU datasets in constrained domains.*
6 |
7 | ## Datasets
8 |
9 | ## Banking
10 |
11 | Dataset composed of online banking queries annotated with their corresponding intents.
12 |
13 | | Dataset statistics | |
14 | | --- | --- |
15 | | Train examples | 10003 |
16 | | Test examples | 3080 |
17 | | Number of intents | 77 |
18 |
19 | | Example Query | Intent |
20 | | --- | --- |
21 | | Is there a way to know when my card will arrive?| card_arrival |
22 | | I think my card is broken | card_not_working |
23 | | I made a mistake and need to cancel a transaction | cancel_transfer |
24 | | Is my card usable anywhere? | card_acceptance |
25 |
26 |
27 | ### Citations
28 |
29 | When using the banking dataset in your work, please cite [Efficient Intent Detection with Dual Sentence Encoders](https://arxiv.org/abs/2003.04807).
30 |
31 | ```bibtex
32 | @inproceedings{Casanueva2020,
33 | author = {I{\~{n}}igo Casanueva and Tadas Temcinas and Daniela Gerz and Matthew Henderson and Ivan Vulic},
34 | title = {Efficient Intent Detection with Dual Sentence Encoders},
35 | year = {2020},
36 | month = {mar},
37 | note = {Data available at https://github.com/PolyAI-LDN/task-specific-datasets},
38 | url = {https://arxiv.org/abs/2003.04807},
39 | booktitle = {Proceedings of the 2nd Workshop on NLP for ConvAI - ACL 2020}
40 | }
41 |
42 | ```
43 |
44 | ## Span Extraction
45 | The directory `span_extraction` contains the data used for the Span-ConveRT paper.
46 |
47 | A training example looks like:
48 | ```
49 | {
50 | "userInput": {
51 | "text": "I would like a table for one person"
52 | },
53 | "labels": [
54 | {
55 | "slot": "people",
56 | "valueSpan": {
57 | "startIndex": 25,
58 | "endIndex": 35
59 | }
60 | }
61 | ]
62 | }
63 | ```
64 |
65 | In the above example, the span "one person" is the value for the `people` slot.
66 |
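To make the indexing concrete, here is a minimal Python sketch (the file path is illustrative) that loads one split and prints every labelled span; note that `startIndex` is omitted in the data whenever a span starts at offset 0:

```python
import json

with open("span_extraction/restaurant8k/train_5.json") as f:
    examples = json.load(f)

for ex in examples:
    text = ex["userInput"]["text"]
    for label in ex.get("labels", []):      # utterances without labels have no "labels" key
        span = label["valueSpan"]
        start = span.get("startIndex", 0)   # omitted startIndex means the span starts at 0
        print(label["slot"], "->", text[start:span["endIndex"]])
```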
67 | The datasets have a structure like this:
68 | ```
69 | ls span_extraction/restaurant8k
70 |
71 | test.json
72 | train_0.json
73 | train_1.json
74 | train_2.json
75 | ...
76 | ```
77 | Where:
78 | * `test.json` contains the examples for evaluation
79 | * `train_0.json` contains all of the training examples
80 | * `train_{i}.json` contains `1/(2^i)`th of the training data.
81 |
82 |
83 | #### Exploring the Span Extraction Datasets
84 | Here's a quick command line demo to explore some of the datasets (requires `jq` and `parallel`)
85 | ```
86 |
87 | # Calculate the number of examples in each json file.
88 | cd span_extraction
89 |
90 | ls -d restaurant8k/*.json | parallel -k 'echo -n "{}," && cat {} | jq length'
91 |
92 | restaurant8k/test.json,3731
93 | restaurant8k/train_0.json,8198
94 | restaurant8k/train_1.json,4099
95 | restaurant8k/train_2.json,2049
96 | restaurant8k/train_3.json,1024
97 | restaurant8k/train_4.json,512
98 | restaurant8k/train_5.json,256
99 | restaurant8k/train_6.json,128
100 | restaurant8k/train_7.json,64
101 | restaurant8k/train_8.json,32
102 |
103 | ls -d dstc8/*/*.json | parallel -k 'echo -n "{}," && cat {} | jq length'
104 | dstc8/Buses_1/test.json,377
105 | dstc8/Buses_1/train_0.json,1133
106 | dstc8/Buses_1/train_1.json,566
107 | dstc8/Buses_1/train_2.json,283
108 | dstc8/Buses_1/train_3.json,141
109 | dstc8/Buses_1/train_4.json,70
110 | dstc8/Events_1/test.json,521
111 | dstc8/Events_1/train_0.json,1498
112 | dstc8/Events_1/train_1.json,749
113 | dstc8/Events_1/train_2.json,374
114 | dstc8/Events_1/train_3.json,187
115 | dstc8/Events_1/train_4.json,93
116 | dstc8/Homes_1/test.json,587
117 | dstc8/Homes_1/train_0.json,2064
118 | dstc8/Homes_1/train_1.json,1032
119 | dstc8/Homes_1/train_2.json,516
120 | dstc8/Homes_1/train_3.json,258
121 | dstc8/Homes_1/train_4.json,129
122 | dstc8/RentalCars_1/test.json,328
123 | dstc8/RentalCars_1/train_0.json,874
124 | dstc8/RentalCars_1/train_1.json,437
125 | dstc8/RentalCars_1/train_2.json,218
126 | dstc8/RentalCars_1/train_3.json,109
127 | dstc8/RentalCars_1/train_4.json,54
128 | ```
129 | ### Citations
130 |
131 | When using the datasets in your work, please cite [the Span-ConveRT paper](https://arxiv.org/abs/2005.08866).
132 |
133 | ```bibtex
134 | @inproceedings{CoopeFarghly2020,
135 | Author = {Sam Coope and Tyler Farghly and Daniela Gerz and Ivan Vulić and Matthew Henderson},
136 | Title = {Span-ConveRT: Few-shot Span Extraction for Dialog with Pretrained Conversational Representations},
137 | Year = {2020},
138 | url = {https://arxiv.org/abs/2005.08866},
139 | publisher = {ACL},
140 | }
141 |
142 | ```
143 |
144 |
145 | ## License
146 | The datasets shared in this repository are licensed under the license found in the LICENSE file.
147 |
--------------------------------------------------------------------------------
/data/bank77/original/banking_data/categories.json:
--------------------------------------------------------------------------------
1 | [
2 | "card_arrival",
3 | "card_linking",
4 | "exchange_rate",
5 | "card_payment_wrong_exchange_rate",
6 | "extra_charge_on_statement",
7 | "pending_cash_withdrawal",
8 | "fiat_currency_support",
9 | "card_delivery_estimate",
10 | "automatic_top_up",
11 | "card_not_working",
12 | "exchange_via_app",
13 | "lost_or_stolen_card",
14 | "age_limit",
15 | "pin_blocked",
16 | "contactless_not_working",
17 | "top_up_by_bank_transfer_charge",
18 | "pending_top_up",
19 | "cancel_transfer",
20 | "top_up_limits",
21 | "wrong_amount_of_cash_received",
22 | "card_payment_fee_charged",
23 | "transfer_not_received_by_recipient",
24 | "supported_cards_and_currencies",
25 | "getting_virtual_card",
26 | "card_acceptance",
27 | "top_up_reverted",
28 | "balance_not_updated_after_cheque_or_cash_deposit",
29 | "card_payment_not_recognised",
30 | "edit_personal_details",
31 | "why_verify_identity",
32 | "unable_to_verify_identity",
33 | "get_physical_card",
34 | "visa_or_mastercard",
35 | "topping_up_by_card",
36 | "disposable_card_limits",
37 | "compromised_card",
38 | "atm_support",
39 | "direct_debit_payment_not_recognised",
40 | "passcode_forgotten",
41 | "declined_cash_withdrawal",
42 | "pending_card_payment",
43 | "lost_or_stolen_phone",
44 | "request_refund",
45 | "declined_transfer",
46 | "Refund_not_showing_up",
47 | "declined_card_payment",
48 | "pending_transfer",
49 | "terminate_account",
50 | "card_swallowed",
51 | "transaction_charged_twice",
52 | "verify_source_of_funds",
53 | "transfer_timing",
54 | "reverted_card_payment?",
55 | "change_pin",
56 | "beneficiary_not_allowed",
57 | "transfer_fee_charged",
58 | "receiving_money",
59 | "failed_transfer",
60 | "transfer_into_account",
61 | "verify_top_up",
62 | "getting_spare_card",
63 | "top_up_by_cash_or_cheque",
64 | "order_physical_card",
65 | "virtual_card_not_working",
66 | "wrong_exchange_rate_for_cash_withdrawal",
67 | "get_disposable_virtual_card",
68 | "top_up_failed",
69 | "balance_not_updated_after_bank_transfer",
70 | "cash_withdrawal_not_recognised",
71 | "exchange_charge",
72 | "top_up_by_card_charge",
73 | "activate_my_card",
74 | "cash_withdrawal_charge",
75 | "card_about_to_expire",
76 | "apple_pay_or_google_pay",
77 | "verify_my_identity",
78 | "country_support"
79 | ]
--------------------------------------------------------------------------------
/data/bank77/original/polyai-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/bank77/original/polyai-logo.png
--------------------------------------------------------------------------------
/data/bank77/original/span_extraction/dstc8/Buses_1/train_5.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "userInput": {
4 | "text": "Can you help me out and look for a bus for me?"
5 | },
6 | "context": {},
7 | "id": "27_001080",
8 | "splitKey": 1.0
9 | },
10 | {
11 | "userInput": {
12 | "text": "I want to visit New York."
13 | },
14 | "context": {
15 | "requestedSlots": [
16 | "to_location"
17 | ]
18 | },
19 | "labels": [
20 | {
21 | "slot": "to_location",
22 | "valueSpan": {
23 | "startIndex": 16,
24 | "endIndex": 24
25 | }
26 | }
27 | ],
28 | "id": "27_001082",
29 | "splitKey": 1.0
30 | },
31 | {
32 | "userInput": {
33 | "text": "I am leaving out from Washington on the 11th of March."
34 | },
35 | "context": {
36 | "requestedSlots": [
37 | "leaving_date",
38 | "from_location"
39 | ]
40 | },
41 | "labels": [
42 | {
43 | "slot": "leaving_date",
44 | "valueSpan": {
45 | "startIndex": 40,
46 | "endIndex": 53
47 | }
48 | },
49 | {
50 | "slot": "from_location",
51 | "valueSpan": {
52 | "startIndex": 22,
53 | "endIndex": 32
54 | }
55 | }
56 | ],
57 | "id": "27_001084",
58 | "splitKey": 1.0
59 | },
60 | {
61 | "userInput": {
62 | "text": "What is the bus station that I will leave from and which station will I arrrive at?"
63 | },
64 | "context": {},
65 | "id": "27_001086",
66 | "splitKey": 1.0
67 | },
68 | {
69 | "userInput": {
70 | "text": "Sounds fine to me."
71 | },
72 | "context": {},
73 | "id": "27_001088",
74 | "splitKey": 1.0
75 | },
76 | {
77 | "userInput": {
78 | "text": "Yep, I would like to get tickets."
79 | },
80 | "context": {},
81 | "id": "27_0010810",
82 | "splitKey": 1.0
83 | },
84 | {
85 | "userInput": {
86 | "text": "I need it for two people."
87 | },
88 | "context": {
89 | "requestedSlots": [
90 | "travelers"
91 | ]
92 | },
93 | "id": "27_0010812",
94 | "splitKey": 1.0
95 | },
96 | {
97 | "userInput": {
98 | "text": "That's not right, it's for 4 people."
99 | },
100 | "context": {},
101 | "id": "27_0010814",
102 | "splitKey": 1.0
103 | },
104 | {
105 | "userInput": {
106 | "text": "Yep, that sounds good to me."
107 | },
108 | "context": {},
109 | "id": "27_0010816",
110 | "splitKey": 1.0
111 | },
112 | {
113 | "userInput": {
114 | "text": "Thanks so much for your help."
115 | },
116 | "context": {},
117 | "id": "27_0010818",
118 | "splitKey": 1.0
119 | },
120 | {
121 | "userInput": {
122 | "text": "Nope, thanks again for helping me."
123 | },
124 | "context": {},
125 | "id": "27_0010820",
126 | "splitKey": 1.0
127 | },
128 | {
129 | "userInput": {
130 | "text": "Is there a bus that leaves from Anaheim to SD?"
131 | },
132 | "context": {},
133 | "labels": [
134 | {
135 | "slot": "to_location",
136 | "valueSpan": {
137 | "startIndex": 43,
138 | "endIndex": 45
139 | }
140 | },
141 | {
142 | "slot": "from_location",
143 | "valueSpan": {
144 | "startIndex": 32,
145 | "endIndex": 39
146 | }
147 | }
148 | ],
149 | "id": "27_001090",
150 | "splitKey": 1.0
151 | },
152 | {
153 | "userInput": {
154 | "text": "Thursday next week for four people."
155 | },
156 | "context": {
157 | "requestedSlots": [
158 | "leaving_date"
159 | ]
160 | },
161 | "labels": [
162 | {
163 | "slot": "leaving_date",
164 | "valueSpan": {
165 | "endIndex": 18
166 | }
167 | }
168 | ],
169 | "id": "27_001092",
170 | "splitKey": 1.0
171 | },
172 | {
173 | "userInput": {
174 | "text": "Anything else for three people from LAX?"
175 | },
176 | "context": {},
177 | "labels": [
178 | {
179 | "slot": "from_location",
180 | "valueSpan": {
181 | "startIndex": 36,
182 | "endIndex": 39
183 | }
184 | }
185 | ],
186 | "id": "27_001094",
187 | "splitKey": 1.0
188 | },
189 | {
190 | "userInput": {
191 | "text": "Sounds great. I would like to reserve it."
192 | },
193 | "context": {},
194 | "id": "27_001096",
195 | "splitKey": 1.0
196 | },
197 | {
198 | "userInput": {
199 | "text": "Sounds great."
200 | },
201 | "context": {},
202 | "id": "27_001098",
203 | "splitKey": 1.0
204 | },
205 | {
206 | "userInput": {
207 | "text": "Thanks, that's all for now."
208 | },
209 | "context": {},
210 | "id": "27_0010910",
211 | "splitKey": 1.0
212 | },
213 | {
214 | "userInput": {
215 | "text": "I'm thinking of going out of town. Could you please help me find a bus?"
216 | },
217 | "context": {},
218 | "id": "27_001100",
219 | "splitKey": 1.0
220 | },
221 | {
222 | "userInput": {
223 | "text": "I'd like to leave on 10th of March."
224 | },
225 | "context": {
226 | "requestedSlots": [
227 | "leaving_date"
228 | ]
229 | },
230 | "labels": [
231 | {
232 | "slot": "leaving_date",
233 | "valueSpan": {
234 | "startIndex": 21,
235 | "endIndex": 34
236 | }
237 | }
238 | ],
239 | "id": "27_001102",
240 | "splitKey": 1.0
241 | },
242 | {
243 | "userInput": {
244 | "text": "I want to travel from Washington to Philly."
245 | },
246 | "context": {
247 | "requestedSlots": [
248 | "to_location",
249 | "from_location"
250 | ]
251 | },
252 | "labels": [
253 | {
254 | "slot": "to_location",
255 | "valueSpan": {
256 | "startIndex": 36,
257 | "endIndex": 42
258 | }
259 | },
260 | {
261 | "slot": "from_location",
262 | "valueSpan": {
263 | "startIndex": 22,
264 | "endIndex": 32
265 | }
266 | }
267 | ],
268 | "id": "27_001104",
269 | "splitKey": 1.0
270 | },
271 | {
272 | "userInput": {
273 | "text": "What's my destination station?"
274 | },
275 | "context": {},
276 | "id": "27_001106",
277 | "splitKey": 1.0
278 | },
279 | {
280 | "userInput": {
281 | "text": "Please find out whether there are other available buses, for a group of three."
282 | },
283 | "context": {},
284 | "id": "27_001108",
285 | "splitKey": 1.0
286 | },
287 | {
288 | "userInput": {
289 | "text": "Please let me know my departure and destination stations."
290 | },
291 | "context": {},
292 | "id": "27_0011010",
293 | "splitKey": 1.0
294 | },
295 | {
296 | "userInput": {
297 | "text": "Okay, that sounds perfect. Please make a reservation on the bus."
298 | },
299 | "context": {},
300 | "id": "27_0011012",
301 | "splitKey": 1.0
302 | },
303 | {
304 | "userInput": {
305 | "text": "Yes, sounds good."
306 | },
307 | "context": {},
308 | "id": "27_0011014",
309 | "splitKey": 1.0
310 | },
311 | {
312 | "userInput": {
313 | "text": "I'm really grateful for your help. That will be all."
314 | },
315 | "context": {},
316 | "id": "27_0011016",
317 | "splitKey": 1.0
318 | },
319 | {
320 | "userInput": {
321 | "text": "Can you find me a bus?"
322 | },
323 | "context": {},
324 | "id": "27_001110",
325 | "splitKey": 1.0
326 | },
327 | {
328 | "userInput": {
329 | "text": "I want to go from San Francisco to Los Angeles on the 9th"
330 | },
331 | "context": {
332 | "requestedSlots": [
333 | "from_location",
334 | "leaving_date",
335 | "to_location"
336 | ]
337 | },
338 | "labels": [
339 | {
340 | "slot": "to_location",
341 | "valueSpan": {
342 | "startIndex": 35,
343 | "endIndex": 46
344 | }
345 | },
346 | {
347 | "slot": "from_location",
348 | "valueSpan": {
349 | "startIndex": 18,
350 | "endIndex": 31
351 | }
352 | },
353 | {
354 | "slot": "leaving_date",
355 | "valueSpan": {
356 | "startIndex": 50,
357 | "endIndex": 57
358 | }
359 | }
360 | ],
361 | "id": "27_001112",
362 | "splitKey": 1.0
363 | },
364 | {
365 | "userInput": {
366 | "text": "Which station am I going from?"
367 | },
368 | "context": {},
369 | "id": "27_001114",
370 | "splitKey": 1.0
371 | },
372 | {
373 | "userInput": {
374 | "text": "That's good thanks."
375 | },
376 | "context": {},
377 | "id": "27_001116",
378 | "splitKey": 1.0
379 | },
380 | {
381 | "userInput": {
382 | "text": "Yes please reserve them."
383 | },
384 | "context": {},
385 | "id": "27_001118",
386 | "splitKey": 1.0
387 | },
388 | {
389 | "userInput": {
390 | "text": "I need 2 please."
391 | },
392 | "context": {
393 | "requestedSlots": [
394 | "travelers"
395 | ]
396 | },
397 | "id": "27_0011110",
398 | "splitKey": 1.0
399 | },
400 | {
401 | "userInput": {
402 | "text": "No I need 3 seats."
403 | },
404 | "context": {},
405 | "id": "27_0011112",
406 | "splitKey": 1.0
407 | },
408 | {
409 | "userInput": {
410 | "text": "Yes that's good, which bus station does it go to?"
411 | },
412 | "context": {},
413 | "id": "27_0011114",
414 | "splitKey": 1.0
415 | },
416 | {
417 | "userInput": {
418 | "text": "Thanks a lot"
419 | },
420 | "context": {},
421 | "id": "27_0011116",
422 | "splitKey": 1.0
423 | }
424 | ]
--------------------------------------------------------------------------------
/data/bank77/original/span_extraction/dstc8/RentalCars_1/train_5.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "userInput": {
4 | "text": "Book me a rental car."
5 | },
6 | "context": {},
7 | "id": "23_000170",
8 | "splitKey": 1.0
9 | },
10 | {
11 | "userInput": {
12 | "text": "I'll pick it up around 10:30 in the morning, keeping it until next Thursday"
13 | },
14 | "context": {
15 | "requestedSlots": [
16 | "pickup_time",
17 | "dropoff_date"
18 | ]
19 | },
20 | "labels": [
21 | {
22 | "slot": "dropoff_date",
23 | "valueSpan": {
24 | "startIndex": 62,
25 | "endIndex": 75
26 | }
27 | },
28 | {
29 | "slot": "pickup_time",
30 | "valueSpan": {
31 | "startIndex": 23,
32 | "endIndex": 43
33 | }
34 | }
35 | ],
36 | "id": "23_000172",
37 | "splitKey": 1.0
38 | },
39 | {
40 | "userInput": {
41 | "text": "I'll pick it up in New York on March 4th"
42 | },
43 | "context": {
44 | "requestedSlots": [
45 | "pickup_city",
46 | "pickup_date"
47 | ]
48 | },
49 | "labels": [
50 | {
51 | "slot": "pickup_city",
52 | "valueSpan": {
53 | "startIndex": 19,
54 | "endIndex": 27
55 | }
56 | },
57 | {
58 | "slot": "pickup_date",
59 | "valueSpan": {
60 | "startIndex": 31,
61 | "endIndex": 40
62 | }
63 | }
64 | ],
65 | "id": "23_000174",
66 | "splitKey": 1.0
67 | },
68 | {
69 | "userInput": {
70 | "text": "Sounds good"
71 | },
72 | "context": {},
73 | "id": "23_000176",
74 | "splitKey": 1.0
75 | },
76 | {
77 | "userInput": {
78 | "text": "No not now"
79 | },
80 | "context": {},
81 | "id": "23_000178",
82 | "splitKey": 1.0
83 | },
84 | {
85 | "userInput": {
86 | "text": "No thanks so much"
87 | },
88 | "context": {},
89 | "id": "23_0001710",
90 | "splitKey": 1.0
91 | },
92 | {
93 | "userInput": {
94 | "text": "I want to find a rental car for March 6th, pick up in New York."
95 | },
96 | "context": {},
97 | "labels": [
98 | {
99 | "slot": "pickup_city",
100 | "valueSpan": {
101 | "startIndex": 54,
102 | "endIndex": 62
103 | }
104 | },
105 | {
106 | "slot": "pickup_date",
107 | "valueSpan": {
108 | "startIndex": 32,
109 | "endIndex": 41
110 | }
111 | }
112 | ],
113 | "id": "23_000180",
114 | "splitKey": 1.0
115 | },
116 | {
117 | "userInput": {
118 | "text": "I need a pickup time around 18:30 and I will need the car until the 7th."
119 | },
120 | "context": {
121 | "requestedSlots": [
122 | "dropoff_date"
123 | ]
124 | },
125 | "labels": [
126 | {
127 | "slot": "pickup_time",
128 | "valueSpan": {
129 | "startIndex": 28,
130 | "endIndex": 33
131 | }
132 | },
133 | {
134 | "slot": "dropoff_date",
135 | "valueSpan": {
136 | "startIndex": 64,
137 | "endIndex": 71
138 | }
139 | }
140 | ],
141 | "id": "23_000182",
142 | "splitKey": 1.0
143 | },
144 | {
145 | "userInput": {
146 | "text": "Can you find anything else?"
147 | },
148 | "context": {},
149 | "id": "23_000184",
150 | "splitKey": 1.0
151 | },
152 | {
153 | "userInput": {
154 | "text": "That would be perfect."
155 | },
156 | "context": {},
157 | "id": "23_000186",
158 | "splitKey": 1.0
159 | },
160 | {
161 | "userInput": {
162 | "text": "No, I don't want to reserve it yet. That's all I need."
163 | },
164 | "context": {},
165 | "id": "23_000188",
166 | "splitKey": 1.0
167 | },
168 | {
169 | "userInput": {
170 | "text": "My usual car is being repaired right now. I'm in need of a rental car for the interim."
171 | },
172 | "context": {},
173 | "id": "23_000190",
174 | "splitKey": 1.0
175 | },
176 | {
177 | "userInput": {
178 | "text": "I'll be needing it from March 9th to March 14th."
179 | },
180 | "context": {
181 | "requestedSlots": [
182 | "dropoff_date",
183 | "pickup_date"
184 | ]
185 | },
186 | "labels": [
187 | {
188 | "slot": "pickup_date",
189 | "valueSpan": {
190 | "startIndex": 24,
191 | "endIndex": 33
192 | }
193 | },
194 | {
195 | "slot": "dropoff_date",
196 | "valueSpan": {
197 | "startIndex": 37,
198 | "endIndex": 47
199 | }
200 | }
201 | ],
202 | "id": "23_000192",
203 | "splitKey": 1.0
204 | },
205 | {
206 | "userInput": {
207 | "text": "I'd like to pick it up from Phoenix."
208 | },
209 | "context": {
210 | "requestedSlots": [
211 | "pickup_city"
212 | ]
213 | },
214 | "labels": [
215 | {
216 | "slot": "pickup_city",
217 | "valueSpan": {
218 | "startIndex": 28,
219 | "endIndex": 35
220 | }
221 | }
222 | ],
223 | "id": "23_000194",
224 | "splitKey": 1.0
225 | },
226 | {
227 | "userInput": {
228 | "text": "I'll have to pick it up around 4:30 pm."
229 | },
230 | "context": {
231 | "requestedSlots": [
232 | "pickup_time"
233 | ]
234 | },
235 | "labels": [
236 | {
237 | "slot": "pickup_time",
238 | "valueSpan": {
239 | "startIndex": 31,
240 | "endIndex": 38
241 | }
242 | }
243 | ],
244 | "id": "23_000196",
245 | "splitKey": 1.0
246 | },
247 | {
248 | "userInput": {
249 | "text": "Sounds great to me."
250 | },
251 | "context": {},
252 | "id": "23_000198",
253 | "splitKey": 1.0
254 | },
255 | {
256 | "userInput": {
257 | "text": "Yes, I want to rent this Accord."
258 | },
259 | "context": {},
260 | "id": "23_0001910",
261 | "splitKey": 1.0
262 | },
263 | {
264 | "userInput": {
265 | "text": "That's all correct. What will it cost me?"
266 | },
267 | "context": {},
268 | "id": "23_0001912",
269 | "splitKey": 1.0
270 | },
271 | {
272 | "userInput": {
273 | "text": "Thanks for the help."
274 | },
275 | "context": {},
276 | "id": "23_0001914",
277 | "splitKey": 1.0
278 | },
279 | {
280 | "userInput": {
281 | "text": "No, that's all. Thanks for helping."
282 | },
283 | "context": {},
284 | "id": "23_0001916",
285 | "splitKey": 1.0
286 | },
287 | {
288 | "userInput": {
289 | "text": "I need a car for next Thursday"
290 | },
291 | "context": {},
292 | "labels": [
293 | {
294 | "slot": "dropoff_date",
295 | "valueSpan": {
296 | "startIndex": 17,
297 | "endIndex": 30
298 | }
299 | }
300 | ],
301 | "id": "23_000200",
302 | "splitKey": 1.0
303 | },
304 | {
305 | "userInput": {
306 | "text": "6 in the evening please."
307 | },
308 | "context": {
309 | "requestedSlots": [
310 | "pickup_time"
311 | ]
312 | },
313 | "labels": [
314 | {
315 | "slot": "pickup_time",
316 | "valueSpan": {
317 | "endIndex": 16
318 | }
319 | }
320 | ],
321 | "id": "23_000202",
322 | "splitKey": 1.0
323 | },
324 | {
325 | "userInput": {
326 | "text": "the 5th of March"
327 | },
328 | "context": {
329 | "requestedSlots": [
330 | "pickup_date"
331 | ]
332 | },
333 | "labels": [
334 | {
335 | "slot": "pickup_date",
336 | "valueSpan": {
337 | "startIndex": 4,
338 | "endIndex": 16
339 | }
340 | }
341 | ],
342 | "id": "23_000204",
343 | "splitKey": 1.0
344 | },
345 | {
346 | "userInput": {
347 | "text": "Yes, from Sacramento, Ca on the 4th of March"
348 | },
349 | "context": {
350 | "requestedSlots": [
351 | "pickup_city"
352 | ]
353 | },
354 | "labels": [
355 | {
356 | "slot": "pickup_city",
357 | "valueSpan": {
358 | "startIndex": 10,
359 | "endIndex": 24
360 | }
361 | },
362 | {
363 | "slot": "pickup_date",
364 | "valueSpan": {
365 | "startIndex": 32,
366 | "endIndex": 44
367 | }
368 | }
369 | ],
370 | "id": "23_000206",
371 | "splitKey": 1.0
372 | },
373 | {
374 | "userInput": {
375 | "text": "How much?"
376 | },
377 | "context": {},
378 | "id": "23_000208",
379 | "splitKey": 1.0
380 | },
381 | {
382 | "userInput": {
383 | "text": "What else is there, I'd like arrange for it to be picked up in New York."
384 | },
385 | "context": {},
386 | "labels": [
387 | {
388 | "slot": "pickup_city",
389 | "valueSpan": {
390 | "startIndex": 63,
391 | "endIndex": 71
392 | }
393 | }
394 | ],
395 | "id": "23_0002010",
396 | "splitKey": 1.0
397 | },
398 | {
399 | "userInput": {
400 | "text": "How much?"
401 | },
402 | "context": {},
403 | "id": "23_0002012",
404 | "splitKey": 1.0
405 | }
406 | ]
--------------------------------------------------------------------------------
/data/bank77/original/span_extraction/dstc8/stats.csv:
--------------------------------------------------------------------------------
1 | Buses_1/test.json,377
2 | Buses_1/train_0.json,1133
3 | Buses_1/train_1.json,566
4 | Buses_1/train_2.json,283
5 | Buses_1/train_3.json,141
6 | Buses_1/train_4.json,70
7 |
8 | Events_1/test.json,521
9 | Events_1/train_0.json,1498
10 | Events_1/train_1.json,749
11 | Events_1/train_2.json,374
12 | Events_1/train_3.json,187
13 | Events_1/train_4.json,93
14 |
15 | Homes_1/test.json,587
16 | Homes_1/train_0.json,2064
17 | Homes_1/train_1.json,1032
18 | Homes_1/train_2.json,516
19 | Homes_1/train_3.json,258
20 | Homes_1/train_4.json,129
21 |
22 | RentalCars_1/test.json,328
23 | RentalCars_1/train_0.json,874
24 | RentalCars_1/train_1.json,437
25 | RentalCars_1/train_2.json,218
26 | RentalCars_1/train_3.json,109
27 | RentalCars_1/train_4.json,54
28 |
--------------------------------------------------------------------------------
/data/bank77/original/span_extraction/restaurant8k/train_8.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "userInput": {
4 | "text": "There will be 5 adults and 1 child."
5 | },
6 | "context": {
7 | "requestedSlots": [
8 | "people"
9 | ]
10 | },
11 | "labels": [
12 | {
13 | "slot": "people",
14 | "valueSpan": {
15 | "startIndex": 14,
16 | "endIndex": 34
17 | }
18 | }
19 | ]
20 | },
21 | {
22 | "userInput": {
23 | "text": "We will require and outside table to seat 9 people on August 23rd"
24 | },
25 | "labels": [
26 | {
27 | "slot": "people",
28 | "valueSpan": {
29 | "startIndex": 42,
30 | "endIndex": 50
31 | }
32 | },
33 | {
34 | "slot": "date",
35 | "valueSpan": {
36 | "startIndex": 54,
37 | "endIndex": 65
38 | }
39 | }
40 | ]
41 | },
42 | {
43 | "userInput": {
44 | "text": "Do you have room for 11 of us?"
45 | },
46 | "labels": [
47 | {
48 | "slot": "people",
49 | "valueSpan": {
50 | "startIndex": 21,
51 | "endIndex": 23
52 | }
53 | }
54 | ]
55 | },
56 | {
57 | "userInput": {
58 | "text": "We are 13 and have a table."
59 | },
60 | "labels": [
61 | {
62 | "slot": "people",
63 | "valueSpan": {
64 | "startIndex": 7,
65 | "endIndex": 9
66 | }
67 | }
68 | ]
69 | },
70 | {
71 | "userInput": {
72 | "text": "6 a.m."
73 | },
74 | "context": {
75 | "requestedSlots": [
76 | "time"
77 | ]
78 | },
79 | "labels": [
80 | {
81 | "slot": "time",
82 | "valueSpan": {
83 | "endIndex": 6
84 | }
85 | }
86 | ]
87 | },
88 | {
89 | "userInput": {
90 | "text": "Can i chnage my booking from 18:00 for 3 people to 17:45 for 4 people?"
91 | },
92 | "labels": [
93 | {
94 | "slot": "time",
95 | "valueSpan": {
96 | "startIndex": 51,
97 | "endIndex": 56
98 | }
99 | },
100 | {
101 | "slot": "people",
102 | "valueSpan": {
103 | "startIndex": 61,
104 | "endIndex": 69
105 | }
106 | }
107 | ]
108 | },
109 | {
110 | "userInput": {
111 | "text": "I do not want a restaurant located in Crystal Palace, but in West Central."
112 | }
113 | },
114 | {
115 | "userInput": {
116 | "text": "Can I book for in 14 days?"
117 | },
118 | "labels": [
119 | {
120 | "slot": "date",
121 | "valueSpan": {
122 | "startIndex": 15,
123 | "endIndex": 25
124 | }
125 | }
126 | ]
127 | },
128 | {
129 | "userInput": {
130 | "text": "Please book this hotel for 10 people and on Sun, 05 Aug 2018"
131 | },
132 | "labels": [
133 | {
134 | "slot": "people",
135 | "valueSpan": {
136 | "startIndex": 27,
137 | "endIndex": 36
138 | }
139 | },
140 | {
141 | "slot": "date",
142 | "valueSpan": {
143 | "startIndex": 44,
144 | "endIndex": 60
145 | }
146 | }
147 | ]
148 | },
149 | {
150 | "userInput": {
151 | "text": "30 minutes before 12pm."
152 | },
153 | "context": {
154 | "requestedSlots": [
155 | "time"
156 | ]
157 | },
158 | "labels": [
159 | {
160 | "slot": "time",
161 | "valueSpan": {
162 | "endIndex": 23
163 | }
164 | }
165 | ]
166 | },
167 | {
168 | "userInput": {
169 | "text": "At 11:45AM please"
170 | },
171 | "labels": [
172 | {
173 | "slot": "time",
174 | "valueSpan": {
175 | "startIndex": 3,
176 | "endIndex": 10
177 | }
178 | }
179 | ]
180 | },
181 | {
182 | "userInput": {
183 | "text": "I am looking for Polish cuisine."
184 | }
185 | },
186 | {
187 | "userInput": {
188 | "text": "Yes. Please place your reservation."
189 | }
190 | },
191 | {
192 | "userInput": {
193 | "text": "I booked a table for this evening at 08:30PM for Vashti Sampaia. We're going to reschedule because a couple of the party have pulled out."
194 | },
195 | "context": {},
196 | "labels": [
197 | {
198 | "slot": "date",
199 | "valueSpan": {
200 | "startIndex": 21,
201 | "endIndex": 33
202 | }
203 | },
204 | {
205 | "slot": "time",
206 | "valueSpan": {
207 | "startIndex": 37,
208 | "endIndex": 44
209 | }
210 | },
211 | {
212 | "slot": "first_name",
213 | "valueSpan": {
214 | "startIndex": 49,
215 | "endIndex": 55
216 | }
217 | },
218 | {
219 | "slot": "last_name",
220 | "valueSpan": {
221 | "startIndex": 56,
222 | "endIndex": 63
223 | }
224 | }
225 | ]
226 | },
227 | {
228 | "userInput": {
229 | "text": "8:30 is also fine"
230 | },
231 | "context": {
232 | "requestedSlots": [
233 | "time"
234 | ]
235 | },
236 | "labels": [
237 | {
238 | "slot": "time",
239 | "valueSpan": {
240 | "endIndex": 4
241 | }
242 | }
243 | ]
244 | },
245 | {
246 | "userInput": {
247 | "text": "the 31st"
248 | },
249 | "context": {
250 | "requestedSlots": [
251 | "date"
252 | ]
253 | },
254 | "labels": [
255 | {
256 | "slot": "date",
257 | "valueSpan": {
258 | "startIndex": 4,
259 | "endIndex": 8
260 | }
261 | }
262 | ]
263 | },
264 | {
265 | "userInput": {
266 | "text": "I am looking for an outside table."
267 | }
268 | },
269 | {
270 | "userInput": {
271 | "text": "What is the address for the Hardy's in St James?"
272 | }
273 | },
274 | {
275 | "userInput": {
276 | "text": "I would like to sit outside in 4 days at 10:45AM."
277 | },
278 | "labels": [
279 | {
280 | "slot": "date",
281 | "valueSpan": {
282 | "startIndex": 28,
283 | "endIndex": 37
284 | }
285 | },
286 | {
287 | "slot": "time",
288 | "valueSpan": {
289 | "startIndex": 41,
290 | "endIndex": 49
291 | }
292 | }
293 | ]
294 | },
295 | {
296 | "userInput": {
297 | "text": "I would like to book for 4 people."
298 | },
299 | "context": {
300 | "requestedSlots": [
301 | "people"
302 | ]
303 | },
304 | "labels": [
305 | {
306 | "slot": "people",
307 | "valueSpan": {
308 | "startIndex": 25,
309 | "endIndex": 33
310 | }
311 | }
312 | ]
313 | },
314 | {
315 | "userInput": {
316 | "text": "11 pm"
317 | },
318 | "context": {
319 | "requestedSlots": [
320 | "time"
321 | ]
322 | },
323 | "labels": [
324 | {
325 | "slot": "time",
326 | "valueSpan": {
327 | "endIndex": 5
328 | }
329 | }
330 | ]
331 | },
332 | {
333 | "userInput": {
334 | "text": "my friend is lactose intolerant"
335 | }
336 | },
337 | {
338 | "userInput": {
339 | "text": "Do you have any outside tables?"
340 | }
341 | },
342 | {
343 | "userInput": {
344 | "text": "19th of November"
345 | },
346 | "context": {
347 | "requestedSlots": [
348 | "date"
349 | ]
350 | },
351 | "labels": [
352 | {
353 | "slot": "date",
354 | "valueSpan": {
355 | "endIndex": 16
356 | }
357 | }
358 | ]
359 | },
360 | {
361 | "userInput": {
362 | "text": "Can I have a table inside?"
363 | }
364 | },
365 | {
366 | "userInput": {
367 | "text": "I would like to find out information about the restaurant called Dulwich Wood House."
368 | }
369 | },
370 | {
371 | "userInput": {
372 | "text": "I would like to change my booking to include 9 people."
373 | },
374 | "labels": [
375 | {
376 | "slot": "people",
377 | "valueSpan": {
378 | "startIndex": 45,
379 | "endIndex": 53
380 | }
381 | }
382 | ]
383 | },
384 | {
385 | "userInput": {
386 | "text": "I would like information on the Hush Brasserie restaurant please."
387 | }
388 | },
389 | {
390 | "userInput": {
391 | "text": "Our party changed number; it's 10 now."
392 | },
393 | "labels": [
394 | {
395 | "slot": "people",
396 | "valueSpan": {
397 | "startIndex": 31,
398 | "endIndex": 33
399 | }
400 | }
401 | ]
402 | },
403 | {
404 | "userInput": {
405 | "text": "2018/04/24."
406 | },
407 | "context": {
408 | "requestedSlots": [
409 | "date"
410 | ]
411 | },
412 | "labels": [
413 | {
414 | "slot": "date",
415 | "valueSpan": {
416 | "endIndex": 10
417 | }
418 | }
419 | ]
420 | },
421 | {
422 | "userInput": {
423 | "text": "On 12:30PM"
424 | },
425 | "context": {
426 | "requestedSlots": [
427 | "time"
428 | ]
429 | },
430 | "labels": [
431 | {
432 | "slot": "time",
433 | "valueSpan": {
434 | "startIndex": 3,
435 | "endIndex": 10
436 | }
437 | }
438 | ]
439 | },
440 | {
441 | "userInput": {
442 | "text": "I want a table at for 2 at noon."
443 | },
444 | "labels": [
445 | {
446 | "slot": "time",
447 | "valueSpan": {
448 | "startIndex": 27,
449 | "endIndex": 31
450 | }
451 | },
452 | {
453 | "slot": "people",
454 | "valueSpan": {
455 | "startIndex": 22,
456 | "endIndex": 23
457 | }
458 | }
459 | ]
460 | }
461 | ]
--------------------------------------------------------------------------------
/data/bank77/showDataset.py:
--------------------------------------------------------------------------------
1 | # this script loads dataset.json and displays its statistics and a few random examples
2 | from gensim.models.keyedvectors import KeyedVectors
3 | import numpy as np
4 | import scipy.io as sio
5 | import nltk
6 | import pdb
7 | import random
8 | import csv
9 | import string
10 | import contractions
11 | import json
13 |
14 | def getDomainIntent(domainLabFile):
15 | domain2lab = {}
16 | lab2domain = {}
17 | currentDomain = None
18 | with open(domainLabFile,'r') as f:
19 | for line in f:
20 | if ':' in line and currentDomain == None:
21 | currentDomain = cleanUpSentence(line)
22 | domain2lab[currentDomain] = []
23 | elif line == "\n":
24 | currentDomain = None
25 | else:
26 | intent = cleanUpSentence(line)
27 | domain2lab[currentDomain].append(intent)
28 |
29 | for key in domain2lab:
30 | domain = key
31 | labList = domain2lab[key]
32 | for lab in labList:
33 | lab2domain[lab] = domain
34 |
35 | return domain2lab, lab2domain
36 |
37 | def cleanUpSentence(sentence):
38 | # sentence: a string, like " Hello, do you like apple? I hate it!! "
39 |
40 | # strip
41 | sentence = sentence.strip()
42 |
43 | # lower case
44 | sentence = sentence.lower()
45 |
46 | # fix contractions
47 | sentence = contractions.fix(sentence)
48 |
49 | # remove '_' and '-'
50 | sentence = sentence.replace('-',' ')
51 | sentence = sentence.replace('_',' ')
52 |
53 | # remove all punctuations
54 | sentence = ''.join(ch for ch in sentence if ch not in string.punctuation)
55 |
56 | return sentence
57 |
58 |
59 | def check_data_format(file_path):
60 | for line in open(file_path,'rb'):
61 | arr =str(line.strip(),'utf-8')
62 | arr = arr.split('\t')
63 | label = [w for w in arr[0].split(' ')]
64 | question = [w for w in arr[1].split(' ')]
65 |
66 | if len(label) == 0 or len(question) == 0:
67 | print("[ERROR] Find empty data: ", label, question)
68 | return False
69 |
70 | return True
71 |
72 |
73 | def save_data(data, file_path):
74 | # save data to disk
75 | with open(file_path, 'w') as f:
76 | json.dump(data, f)
77 | return
78 |
79 | def save_domain_intent(data, file_path):
80 | domain2intent = {}
81 | for line in data:
82 | domain = line[0]
83 | intent = line[1]
84 |
85 | if not domain in domain2intent:
86 | domain2intent[domain] = set()
87 |
88 | domain2intent[domain].add(intent)
89 |
90 | # save data to disk
91 | print("Saving domain intent out ... format: domain \t intent")
92 | with open(file_path,"w") as f:
93 | for domain in domain2intent:
94 | intentSet = domain2intent[domain]
95 | for intent in intentSet:
96 | f.write("%s\t%s\n" % (domain, intent))
97 | return
98 |
99 | def display_data(data):
100 | # dataset count
101 | print("[INFO] We have %d dataset."%(len(data)))
102 |
103 | datasetName = 'BANKING77'
104 | data = data[datasetName]
105 |
106 | # domain count
107 | domainName = set()
108 | for domain in data:
109 | domainName.add(domain)
110 | print("[INFO] There are %d domains."%(len(domainName)))
111 | print(domainName)
112 |
113 | # intent count
114 | intentName = set()
115 | for domain in data:
116 | for d in data[domain]:
117 | lab = d[1][0]
118 | intentName.add(lab)
119 | intentName = list(intentName)
120 | intentName.sort()
121 | print("[INFO] There are %d intent."%(len(intentName)))
122 | print(intentName)
123 |
124 | # data count
125 | count = 0
126 | for domain in data:
127 | for d in data[domain]:
128 | count = count+1
129 | print("[INFO] Data count: %d"%(count))
130 |
131 | # intent for each domain
132 | domain2intentDict = {}
133 | for domain in data:
134 | if not domain in domain2intentDict:
135 | domain2intentDict[domain] = set()
136 |
137 | for d in data[domain]:
138 | lab = d[1][0]
139 | domain2intentDict[domain].add(lab)
140 | print("[INFO] Intent for each domain.")
141 | print(domain2intentDict)
142 |
143 | # data for each intent
144 | intent2count = {}
145 | for domain in data:
146 | for d in data[domain]:
147 | lab = d[1][0]
148 | if not lab in intent2count:
149 | intent2count[lab] = 0
150 | intent2count[lab] = intent2count[lab]+1
151 | print("[INFO] Intent count")
152 | print(intent2count)
153 |
154 | # examples of data
155 | exampleNum = 3
156 | while not exampleNum == 0:
157 | for domain in data:
158 | for d in data[domain]:
159 | lab = d[1]
160 | utt = d[0]
161 | if random.random() < 0.001:
162 | print("[INFO] Example:--%s, %s, %s, %s"%(datasetName, domain, lab, utt))
163 | exampleNum = exampleNum-1
164 | break
165 | if (exampleNum==0):
166 | break
167 |
168 | return None
169 |
170 |
171 | ##
172 | # @brief clean up data, including intent and utterance
173 | #
174 | # @param data a list of data
175 | #
176 | # @return
177 | def cleanData(data):
178 | newData = []
179 | for d in data:
180 | utt = d[0]
181 | lab = d[1]
182 |
183 | uttClr = cleanUpSentence(utt)
184 | labClr = cleanUpSentence(lab)
185 | newData.append([labClr, uttClr])
186 |
187 | return newData
188 |
189 | def constructData(data, intent2domain):
190 | dataset2domain = {}
191 | datasetName = 'CLINC150'
192 | dataset2domain[datasetName] = {}
193 | for d in data:
194 | lab = d[0]
195 | utt = d[1]
196 | domain = intent2domain[lab]
197 | if not domain in dataset2domain[datasetName]:
198 | dataset2domain[datasetName][domain] = []
199 | dataField = [utt, [lab]]
200 | dataset2domain[datasetName][domain].append(dataField)
201 |
202 | return dataset2domain
203 |
204 |
205 | def read_data(file_path):
206 | with open(file_path) as json_file:
207 | data = json.load(json_file)
208 | return data
209 |
210 | # read in data
211 | #dataPath = "/data1/haode/projects/EMDIntentFewShot/SPIN_refactor/data/refactor_OOS/dataset.json"
212 | dataPath = "./dataset.json"
213 | print("Loading data ...", dataPath)
214 | # read lines, collect data count for different classes
215 | data = read_data(dataPath)
216 |
217 | display_data(data)
218 | print("Display.. done")
219 |
--------------------------------------------------------------------------------
/data/hint3/original/readme.rd:
--------------------------------------------------------------------------------
1 | The original HINT3 dataset has two versions, v1 and v2. V2 corrects some false labels in v1.
2 |
3 | Here, we use v2.
4 |
--------------------------------------------------------------------------------
/data/hint3/original/v2/train/sofmattress_train.csv:
--------------------------------------------------------------------------------
1 | sentence,label
2 | You guys provide EMI option?,EMI
3 | Do you offer Zero Percent EMI payment options?,EMI
4 | 0% EMI.,EMI
5 | EMI,EMI
6 | I want in installment,EMI
7 | I want it on 0% interest,EMI
8 | How to get in EMI,EMI
9 | what about emi options,EMI
10 | I need emi payment. ,EMI
11 | How to EMI,EMI
12 | I want to buy on EMI,EMI
13 | I want to buy this in installments,EMI
14 | 0% Emi,EMI
15 | Down payments,EMI
16 | Installments,EMI
17 | How about Paisa finance,EMI
18 | Paisa finance service available,EMI
19 | What is minimum down payment,EMI
20 | Paisa Finance is available,EMI
21 | Do you accept Paisa EMI card,EMI
22 | Can we buy through Paisa finance,EMI
23 | EMI Option,EMI
24 | Paisa Finance,EMI
25 | No cost EMI is available?,EMI
26 | Is EMI available,EMI
27 | COD option is availble?,COD
28 | Do you offer COD to my pincode?,COD
29 | Can I do COD?,COD
30 | COD,COD
31 | Is it possible to COD,COD
32 | Cash on delivery is acceptable?,COD
33 | Can pay later on delivery ,COD
34 | DO you have COD option,COD
35 | Can it deliver by COD,COD
36 | Is COD option available,COD
37 | Can I get COD option?,COD
38 | Cash on delivery is available,COD
39 | Features of Ortho mattress,ORTHO_FEATURES
40 | What are the key features of the SOF Ortho mattress,ORTHO_FEATURES
41 | SOF ortho,ORTHO_FEATURES
42 | Ortho mattress,ORTHO_FEATURES
43 | Ortho features,ORTHO_FEATURES
44 | ortho,ORTHO_FEATURES
45 | Tell me about SOF Ortho mattress,ORTHO_FEATURES
46 | Have a back problem,ORTHO_FEATURES
47 | I have back pain issue,ORTHO_FEATURES
48 | Back Pain,ORTHO_FEATURES
49 | Is it orthopaedic,ORTHO_FEATURES
50 | Neck pain and back pain,ORTHO_FEATURES
51 | back ache issue,ORTHO_FEATURES
52 | I m looking mattress for slip disc problem,ORTHO_FEATURES
53 | Do we have anything for backache,ORTHO_FEATURES
54 | I am cervical and Lombard section problem,ORTHO_FEATURES
55 | Is there orthopedic mattress available,ORTHO_FEATURES
56 | What are the key features of the SOF Ergo mattress,ERGO_FEATURES
57 | Features of Ergo mattress,ERGO_FEATURES
58 | SOF ergo,ERGO_FEATURES
59 | Ergo mattress,ERGO_FEATURES
60 | Ergo features,ERGO_FEATURES
61 | Ergo,ERGO_FEATURES
62 | Tell me about SOF Ergo mattress,ERGO_FEATURES
63 | What about ergo,ERGO_FEATURES
64 | What is responsive foam,ERGO_FEATURES
65 | SOF ergo features,ERGO_FEATURES
66 | Does this have ergonomic support?,ERGO_FEATURES
67 | What is the difference between the Ergo & Ortho variants,COMPARISON
68 | Difference between Ergo & Ortho Mattress,COMPARISON
69 | Difference between the products,COMPARISON
70 | Compare the 2 mattresses,COMPARISON
71 | Product comparison,COMPARISON
72 | Comparison,COMPARISON
73 | Which mattress to buy?,COMPARISON
74 | Is the mattress good for my back,COMPARISON
75 | Mattress comparison,COMPARISON
76 | Compare ergo & ortho variants,COMPARISON
77 | I wanna know the difference,COMPARISON
78 | What is the warranty period?,WARRANTY
79 | Warranty,WARRANTY
80 | Does mattress cover is included in warranty,WARRANTY
81 | How long is the warranty you offer on your mattresses and what does it cover,WARRANTY
82 | Do you offer warranty ,WARRANTY
83 | Tell me about the product warranty,WARRANTY
84 | Share warranty information,WARRANTY
85 | Need to know the warranty details,WARRANTY
86 | Want to know about warranty,WARRANTY
87 | would interested in warranty details,WARRANTY
88 | How does the 100 night trial work,100_NIGHT_TRIAL_OFFER
89 | What is the 100-night offer,100_NIGHT_TRIAL_OFFER
90 | Trial details,100_NIGHT_TRIAL_OFFER
91 | How to enroll for trial,100_NIGHT_TRIAL_OFFER
92 | 100 night trial,100_NIGHT_TRIAL_OFFER
93 | Is the 100 night return trial applicable for custom size as well,100_NIGHT_TRIAL_OFFER
94 | 100 night,100_NIGHT_TRIAL_OFFER
95 | Trial offer on customisation,100_NIGHT_TRIAL_OFFER
96 | do you provide exchange,100_NIGHT_TRIAL_OFFER
97 | Can I try a mattress first,100_NIGHT_TRIAL_OFFER
98 | I want to check offers,100_NIGHT_TRIAL_OFFER
99 | What is 100 Night trial offer,100_NIGHT_TRIAL_OFFER
100 | Need 100 days trial,100_NIGHT_TRIAL_OFFER
101 | 100 days trial,100_NIGHT_TRIAL_OFFER
102 | Can I get free trial,100_NIGHT_TRIAL_OFFER
103 | 100 Nights trial version,100_NIGHT_TRIAL_OFFER
104 | Can you give me 100 night trial ,100_NIGHT_TRIAL_OFFER
105 | 100 free Nights,100_NIGHT_TRIAL_OFFER
106 | I want to change the size of the mattress.,SIZE_CUSTOMIZATION
107 | Need some help in changing size of the mattress,SIZE_CUSTOMIZATION
108 | How can I order a custom sized mattress,SIZE_CUSTOMIZATION
109 | Custom size,SIZE_CUSTOMIZATION
110 | Customise size ,SIZE_CUSTOMIZATION
111 | Customisation is possible?,SIZE_CUSTOMIZATION
112 | Will I get an option to Customise the size,SIZE_CUSTOMIZATION
113 | Can mattress size be customised?,SIZE_CUSTOMIZATION
114 | Mattress size change,SIZE_CUSTOMIZATION
115 | Can you help with the size?,WHAT_SIZE_TO_ORDER
116 | How do I know what size to order?,WHAT_SIZE_TO_ORDER
117 | How do I know the size of my bed?,WHAT_SIZE_TO_ORDER
118 | Inches,WHAT_SIZE_TO_ORDER
119 | Length,WHAT_SIZE_TO_ORDER
120 | Feet,WHAT_SIZE_TO_ORDER
121 | 6*3,WHAT_SIZE_TO_ORDER
122 | Help me with the size chart,WHAT_SIZE_TO_ORDER
123 | What are the available sizes?,WHAT_SIZE_TO_ORDER
124 | Can I please have the size chart?,WHAT_SIZE_TO_ORDER
125 | Share the size structure,WHAT_SIZE_TO_ORDER
126 | What are the sizes available?,WHAT_SIZE_TO_ORDER
127 | Want to know the custom size chart,WHAT_SIZE_TO_ORDER
128 | Mattress size,WHAT_SIZE_TO_ORDER
129 | What size to order?,WHAT_SIZE_TO_ORDER
130 | Show me all available sizes,WHAT_SIZE_TO_ORDER
131 | What are the sizes,WHAT_SIZE_TO_ORDER
132 | What are the available mattress sizes,WHAT_SIZE_TO_ORDER
133 | King Size,WHAT_SIZE_TO_ORDER
134 | Inches,WHAT_SIZE_TO_ORDER
135 | Get in Touch,LEAD_GEN
136 | Want to talk to an live agent,LEAD_GEN
137 | Please call me,LEAD_GEN
138 | Do you have Live agent,LEAD_GEN
139 | Want to get in touch,LEAD_GEN
140 | Schedule a callback ,LEAD_GEN
141 | Arrange a call back ,LEAD_GEN
142 | How to get in touch?,LEAD_GEN
143 | Interested in buying,LEAD_GEN
144 | I want to buy this,LEAD_GEN
145 | How to buy?,LEAD_GEN
146 | How to order,LEAD_GEN
147 | I want to order,LEAD_GEN
148 | Connect to an agent,LEAD_GEN
149 | Get In Touch,LEAD_GEN
150 | I want a call back,LEAD_GEN
151 | I need a call back,LEAD_GEN
152 | Please call me back,LEAD_GEN
153 | Call me now,LEAD_GEN
154 | I want to buy,LEAD_GEN
155 | Need a call from your representative,LEAD_GEN
156 | Do you deliver to my pincode,CHECK_PINCODE
157 | Check pincode,CHECK_PINCODE
158 | Is delivery possible on this pincode,CHECK_PINCODE
159 | Will you be able to deliver here,CHECK_PINCODE
160 | Can you make delivery on this pin code?,CHECK_PINCODE
161 | Can you deliver on my pincode,CHECK_PINCODE
162 | Do you deliver to,CHECK_PINCODE
163 | Can you please deliver on my pincode,CHECK_PINCODE
164 | Need a delivery on this pincode,CHECK_PINCODE
165 | Can I get delivery on this pincode,CHECK_PINCODE
166 | Do you have any showrooms in Delhi state,DISTRIBUTORS
167 | Do you have any distributors in Mumbai city,DISTRIBUTORS
168 | Do you have any retailers in Pune city,DISTRIBUTORS
169 | Where can I see the product before I buy,DISTRIBUTORS
170 | What is the price for size (x ft x y ft)? What is the price for size (x inches x y inches)?,DISTRIBUTORS
171 | Distributors ,DISTRIBUTORS
172 | Distributors/Retailers/Showrooms,DISTRIBUTORS
173 | Where is your showroom,DISTRIBUTORS
174 | Do you have a showroom,DISTRIBUTORS
175 | Can I visit SOF mattress showroom,DISTRIBUTORS
176 | Where is your showroom,DISTRIBUTORS
177 | Nearby Show room,DISTRIBUTORS
178 | Do you have Showroom,DISTRIBUTORS
179 | Need store,DISTRIBUTORS
180 | Need Nearby Store,DISTRIBUTORS
181 | Nearest shop,DISTRIBUTORS
182 | Demo store,DISTRIBUTORS
183 | Is there any offline stores ,DISTRIBUTORS
184 | Head office,DISTRIBUTORS
185 | Shops nearby,DISTRIBUTORS
186 | Where was this shop,DISTRIBUTORS
187 | Offline stores,DISTRIBUTORS
188 | Do you have store,DISTRIBUTORS
189 | Where is the shop ,DISTRIBUTORS
190 | Is it available in shops,DISTRIBUTORS
191 | You have any branch,DISTRIBUTORS
192 | Store in,DISTRIBUTORS
193 | We want dealer ship,DISTRIBUTORS
194 | Any shop that I can visit,DISTRIBUTORS
195 | Dealership,DISTRIBUTORS
196 | Shop near by,DISTRIBUTORS
197 | Need dealership,DISTRIBUTORS
198 | Outlet,DISTRIBUTORS
199 | Store near me,DISTRIBUTORS
200 | Price of mattress,MATTRESS_COST
201 | Mattress cost,MATTRESS_COST
202 | Cost of mattress,MATTRESS_COST
203 | How much does a SOF mattress cost,MATTRESS_COST
204 | Cost,MATTRESS_COST
205 | Custom size cost,MATTRESS_COST
206 | What does the mattress cost,MATTRESS_COST
207 | I need price,MATTRESS_COST
208 | Price Range,MATTRESS_COST
209 | Mattress price,MATTRESS_COST
210 | Price of Mattress,MATTRESS_COST
211 | Want to know the price,MATTRESS_COST
212 | How Much Cost,MATTRESS_COST
213 | Price,MATTRESS_COST
214 | Rate,MATTRESS_COST
215 | MRP,MATTRESS_COST
216 | Low price,MATTRESS_COST
217 | What will be the price,MATTRESS_COST
218 | Cost of Bed,MATTRESS_COST
219 | Cost of Mattress,MATTRESS_COST
220 | Mattress cost,MATTRESS_COST
221 | What is the cost,MATTRESS_COST
222 | What are the product variants,PRODUCT_VARIANTS
223 | Product Variants,PRODUCT_VARIANTS
224 | Help me with different products,PRODUCT_VARIANTS
225 | What are the mattress variants,PRODUCT_VARIANTS
226 | I want to check products,PRODUCT_VARIANTS
227 | What are the SOF mattress products,PRODUCT_VARIANTS
228 | Show me products,PRODUCT_VARIANTS
229 | Which product is best,PRODUCT_VARIANTS
230 | Which mattress is best,PRODUCT_VARIANTS
231 | I want to buy a mattress,PRODUCT_VARIANTS
232 | Products,PRODUCT_VARIANTS
233 | View products,PRODUCT_VARIANTS
234 | Type of foam used,PRODUCT_VARIANTS
235 | Mattress variants,PRODUCT_VARIANTS
236 | Mattress Features,PRODUCT_VARIANTS
237 | I am looking the mattress,PRODUCT_VARIANTS
238 | Type of mattress,PRODUCT_VARIANTS
239 | Show more mattress,PRODUCT_VARIANTS
240 | What are the mattress features,PRODUCT_VARIANTS
241 | What are your products,PRODUCT_VARIANTS
242 | Tell me about SOF mattress features,PRODUCT_VARIANTS
243 | How is SOF different from other mattress brands,ABOUT_SOF_MATTRESS
244 | Why SOF mattress,ABOUT_SOF_MATTRESS
245 | About SOF Mattress,ABOUT_SOF_MATTRESS
246 | What is SOF mattress,ABOUT_SOF_MATTRESS
247 | Tell me about SOF mattresses,ABOUT_SOF_MATTRESS
248 | Who are SOF mattress,ABOUT_SOF_MATTRESS
249 | What is SOF,ABOUT_SOF_MATTRESS
250 | How is SOF mattress different from,ABOUT_SOF_MATTRESS
251 | Tell me about SOF Mattress,ABOUT_SOF_MATTRESS
252 | Tell me about company,ABOUT_SOF_MATTRESS
253 | Who is SOF mattress,ABOUT_SOF_MATTRESS
254 | It's been a month,DELAY_IN_DELIVERY
255 | Why so long?,DELAY_IN_DELIVERY
256 | I did not receive my order yet,DELAY_IN_DELIVERY
257 | Why so many days,DELAY_IN_DELIVERY
258 | It's delayed,DELAY_IN_DELIVERY
259 | Almost 1 month over,DELAY_IN_DELIVERY
260 | It's been 30 days my product haven't received,DELAY_IN_DELIVERY
261 | It's been so many days,DELAY_IN_DELIVERY
262 | It's already one month,DELAY_IN_DELIVERY
263 | Delivery is delayed,DELAY_IN_DELIVERY
264 | It's too late to get delivered,DELAY_IN_DELIVERY
265 | Order Status,ORDER_STATUS
266 | What is my order status?,ORDER_STATUS
267 | Order related,ORDER_STATUS
268 | Status of my order,ORDER_STATUS
269 | What about my order,ORDER_STATUS
270 | My order Number,ORDER_STATUS
271 | Where is my order,ORDER_STATUS
272 | Order #,ORDER_STATUS
273 | Track order,ORDER_STATUS
274 | I want updates of my order,ORDER_STATUS
275 | I need my order status,ORDER_STATUS
276 | Status of my order,ORDER_STATUS
277 | State of this order,ORDER_STATUS
278 | Order status,ORDER_STATUS
279 | Present status,ORDER_STATUS
280 | What is the status?,ORDER_STATUS
281 | Current state of my order,ORDER_STATUS
282 | What is the order status?,ORDER_STATUS
283 | Where is my product,ORDER_STATUS
284 | When will the order be delivered to me?,ORDER_STATUS
285 | When can we expect,ORDER_STATUS
286 | Need my money back,RETURN_EXCHANGE
287 | I want refund,RETURN_EXCHANGE
288 | Refund,RETURN_EXCHANGE
289 | Not happy with the product please help me to return,RETURN_EXCHANGE
290 | Help me with exchange process,RETURN_EXCHANGE
291 | How do I return it,RETURN_EXCHANGE
292 | I want to return my mattress,RETURN_EXCHANGE
293 | Exchange,RETURN_EXCHANGE
294 | Looking to exchange,RETURN_EXCHANGE
295 | How can I replace the mattress.,RETURN_EXCHANGE
296 | How do I return It,RETURN_EXCHANGE
297 | Return,RETURN_EXCHANGE
298 | Return my product,RETURN_EXCHANGE
299 | Replacement policy,RETURN_EXCHANGE
300 | I want to cancel my order,CANCEL_ORDER
301 | How can I cancel my order,CANCEL_ORDER
302 | Cancel order,CANCEL_ORDER
303 | Cancellation status,CANCEL_ORDER
304 | I want cancel my order,CANCEL_ORDER
305 | Cancel the order,CANCEL_ORDER
306 | Cancel my order,CANCEL_ORDER
307 | Need to cancel my order,CANCEL_ORDER
308 | Process of cancelling order,CANCEL_ORDER
309 | Can I cancel my order here,CANCEL_ORDER
310 | Can I get pillows?,PILLOWS
311 | Do you sell pillows?,PILLOWS
312 | Pillows,PILLOWS
313 | I want to buy pillows,PILLOWS
314 | Are Pillows available,PILLOWS
315 | Need pair of Pillows,PILLOWS
316 | Can I buy pillows from here,PILLOWS
317 | Do you have cushions,PILLOWS
318 | Can I also have pillows,PILLOWS
319 | Is pillows available,PILLOWS
320 | Offers,OFFERS
321 | What are the available offers,OFFERS
322 | Give me some discount,OFFERS
323 | Any discounts,OFFERS
324 | Discount,OFFERS
325 | May I please know about the offers,OFFERS
326 | Available offers,OFFERS
327 | Is offer available,OFFERS
328 | Want to know the discount ,OFFERS
329 | Tell me about the latest offers,OFFERS
330 |
--------------------------------------------------------------------------------
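
The sofmattress_train.csv block above is a plain two-column CSV: an utterance followed by its intent label. A minimal loading sketch in Python; the file name and the presence of a `sentence,label` header row are assumptions, not confirmed by this excerpt:

    import csv
    from collections import Counter

    # tally utterances per intent; the last field of each row is treated as the label
    counts = Counter()
    with open("sofmattress_train.csv", newline="") as f:
        for row in csv.reader(f):
            if len(row) < 2 or row[-1] == "label":  # skip blank rows and a possible header
                continue
            counts[row[-1].strip()] += 1

    for intent, n in counts.most_common():
        print("%-22s %d" % (intent, n))
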
/data/hint3/showDataset.py:
--------------------------------------------------------------------------------
1 | # this script loads the processed dataset.json and prints summary statistics (domains, intents, counts, examples)
2 | from gensim.models.keyedvectors import KeyedVectors
3 | import numpy as np
4 | import scipy.io as sio
5 | import nltk
6 | import pdb
7 | import random
8 | import csv
9 | import string
10 | import contractions
11 | import json
12 | import random
13 |
14 | def getDomainIntent(domainLabFile):
15 | domain2lab = {}
16 | lab2domain = {}
17 | currentDomain = None
18 | with open(domainLabFile,'r') as f:
19 | for line in f:
20 | if ':' in line and currentDomain == None:
21 | currentDomain = cleanUpSentence(line)
22 | domain2lab[currentDomain] = []
23 | elif line == "\n":
24 | currentDomain = None
25 | else:
26 | intent = cleanUpSentence(line)
27 | domain2lab[currentDomain].append(intent)
28 |
29 | for key in domain2lab:
30 | domain = key
31 | labList = domain2lab[key]
32 | for lab in labList:
33 | lab2domain[lab] = domain
34 |
35 | return domain2lab, lab2domain
36 |
37 | def cleanUpSentence(sentence):
38 | # sentence: a string, like " Hello, do you like apple? I hate it!! "
39 |
40 | # strip
41 | sentence = sentence.strip()
42 |
43 | # lower case
44 | sentence = sentence.lower()
45 |
46 | # fix contractions
47 | sentence = contractions.fix(sentence)
48 |
49 | # remove '_' and '-'
50 | sentence = sentence.replace('-',' ')
51 | sentence = sentence.replace('_',' ')
52 |
53 | # remove all punctuations
54 | sentence = ''.join(ch for ch in sentence if ch not in string.punctuation)
55 |
56 | return sentence
57 |
58 |
59 | def check_data_format(file_path):
60 | for line in open(file_path,'rb'):
61 | arr =str(line.strip(),'utf-8')
62 | arr = arr.split('\t')
63 | label = [w for w in arr[0].split(' ')]
64 | question = [w for w in arr[1].split(' ')]
65 |
66 | if len(label) == 0 or len(question) == 0:
67 | print("[ERROR] Find empty data: ", label, question)
68 | return False
69 |
70 | return True
71 |
72 |
73 | def save_data(data, file_path):
74 | # save data to disk
75 | with open(file_path, 'w') as f:
76 | json.dump(data, f)
77 | return
78 |
79 | def save_domain_intent(data, file_path):
80 | domain2intent = {}
81 | for line in data:
82 | domain = line[0]
83 | intent = line[1]
84 |
85 | if not domain in domain2intent:
86 | domain2intent[domain] = set()
87 |
88 | domain2intent[domain].add(intent)
89 |
90 | # save data to disk
91 | print("Saving domain intent out ... format: domain \t intent")
92 | with open(file_path,"w") as f:
93 | for domain in domain2intent:
94 | intentSet = domain2intent[domain]
95 | for intent in intentSet:
96 | f.write("%s\t%s\n" % (domain, intent))
97 | return
98 |
99 | def display_data(data):
100 | # dataset count
101 | print("[INFO] We have %d dataset."%(len(data)))
102 |
103 | datasetName = 'HINT3'
104 | data = data[datasetName]
105 |
106 | # domain count
107 | domainName = set()
108 | for domain in data:
109 | domainName.add(domain)
110 | print("[INFO] There are %d domains."%(len(domainName)))
111 | print(domainName)
112 |
113 | # intent count
114 | intentName = set()
115 | for domain in data:
116 | for d in data[domain]:
117 | lab = d[1][0]
118 | intentName.add(lab)
119 | intentName = list(intentName)
120 | intentName.sort()
121 | print("[INFO] There are %d intent."%(len(intentName)))
122 | print(intentName)
123 |
124 | # data count
125 | count = 0
126 | for domain in data:
127 | for d in data[domain]:
128 | count = count+1
129 | print("[INFO] Data count: %d"%(count))
130 |
131 | # intent for each domain
132 | domain2intentDict = {}
133 | for domain in data:
134 | if not domain in domain2intentDict:
135 | domain2intentDict[domain] = set()
136 |
137 | for d in data[domain]:
138 | lab = d[1][0]
139 | domain2intentDict[domain].add(lab)
140 | print("[INFO] Intent for each domain.")
141 | print(domain2intentDict)
142 |
143 | # data for each intent
144 | intent2count = {}
145 | for domain in data:
146 | for d in data[domain]:
147 | lab = d[1][0]
148 | if not lab in intent2count:
149 | intent2count[lab] = 0
150 | intent2count[lab] = intent2count[lab]+1
151 | print("[INFO] Intent count")
152 | print(intent2count)
153 |
154 | # examples of data
155 | exampleNum = 3
156 | while not exampleNum == 0:
157 | for domain in data:
158 | for d in data[domain]:
159 | lab = d[1]
160 | utt = d[0]
161 | if random.random() < 0.001:
162 | print("[INFO] Example:--%s, %s, %s, %s"%(datasetName, domain, lab, utt))
163 | exampleNum = exampleNum-1
164 | break
165 | if (exampleNum==0):
166 | break
167 |
168 | return None
169 |
170 |
171 | ##
172 | # @brief clean up data, including intent and utterance
173 | #
174 | # @param data a list of [utterance, label] pairs
175 | #
176 | # @return a new list of cleaned [label, utterance] pairs
177 | def cleanData(data):
178 | newData = []
179 | for d in data:
180 | utt = d[0]
181 | lab = d[1]
182 |
183 | uttClr = cleanUpSentence(utt)
184 | labClr = cleanUpSentence(lab)
185 | newData.append([labClr, uttClr])
186 |
187 | return newData
188 |
189 | def constructData(data, intent2domain):
190 | dataset2domain = {}
191 | datasetName = 'HINT3'
192 | dataset2domain[datasetName] = {}
193 | for d in data:
194 | lab = d[0]
195 | utt = d[1]
196 | domain = intent2domain[lab]
197 | if not domain in dataset2domain[datasetName]:
198 | dataset2domain[datasetName][domain] = []
199 | dataField = [utt, [lab]]
200 | dataset2domain[datasetName][domain].append(dataField)
201 |
202 | return dataset2domain
203 |
204 |
205 | def read_data(file_path):
206 | with open(file_path) as json_file:
207 | data = json.load(json_file)
208 | return data
209 |
210 | # read in data
211 | #dataPath = "/data1/haode/projects/EMDIntentFewShot/SPIN_refactor/data/refactor_OOS/dataset.json"
212 | dataPath = "./dataset.json"
213 | print("Loading data ...", dataPath)
214 | # read lines, collect data count for different classes
215 | data = read_data(dataPath)
216 |
217 | display_data(data)
218 | print("Display.. done")
219 |
--------------------------------------------------------------------------------
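
The script above reads "./dataset.json" from its own directory, so it is meant to be run from within data/hint3. Judging from constructData() and display_data(), the file appears to be nested as {dataset name: {domain: [[utterance, [intent]], ...]}}. A toy sketch of that layout (all names and utterances below are illustrative placeholders):

    import json

    dataset = {
        "HINT3": {                                   # dataset name
            "sofmattress": [                         # domain
                ["any discounts", ["offers"]],       # [utterance, [intent]]
                ["cancel my order", ["cancel order"]],
            ],
        }
    }

    with open("dataset.json", "w") as f:
        json.dump(dataset, f)
    # showDataset.py run in the same directory would then print the domain,
    # intent and count statistics for this toy file.
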
/data/hwu64/showDataset.py:
--------------------------------------------------------------------------------
1 | # this script loads the processed dataset.json and prints summary statistics (domains, intents, counts, examples)
2 | from gensim.models.keyedvectors import KeyedVectors
3 | import numpy as np
4 | import scipy.io as sio
5 | import nltk
6 | import pdb
7 | import random
8 | import csv
9 | import string
10 | import contractions
11 | import json
12 | import random
13 |
14 | def getDomainIntent(domainLabFile):
15 | domain2lab = {}
16 | lab2domain = {}
17 | currentDomain = None
18 | with open(domainLabFile,'r') as f:
19 | for line in f:
20 | if ':' in line and currentDomain == None:
21 | currentDomain = cleanUpSentence(line)
22 | domain2lab[currentDomain] = []
23 | elif line == "\n":
24 | currentDomain = None
25 | else:
26 | intent = cleanUpSentence(line)
27 | domain2lab[currentDomain].append(intent)
28 |
29 | for key in domain2lab:
30 | domain = key
31 | labList = domain2lab[key]
32 | for lab in labList:
33 | lab2domain[lab] = domain
34 |
35 | return domain2lab, lab2domain
36 |
37 | def cleanUpSentence(sentence):
38 | # sentence: a string, like " Hello, do you like apple? I hate it!! "
39 |
40 | # strip
41 | sentence = sentence.strip()
42 |
43 | # lower case
44 | sentence = sentence.lower()
45 |
46 | # fix contractions
47 | sentence = contractions.fix(sentence)
48 |
49 | # remove '_' and '-'
50 | sentence = sentence.replace('-',' ')
51 | sentence = sentence.replace('_',' ')
52 |
53 | # remove all punctuations
54 | sentence = ''.join(ch for ch in sentence if ch not in string.punctuation)
55 |
56 | return sentence
57 |
58 |
59 | def check_data_format(file_path):
60 | for line in open(file_path,'rb'):
61 | arr =str(line.strip(),'utf-8')
62 | arr = arr.split('\t')
63 | label = [w for w in arr[0].split(' ')]
64 | question = [w for w in arr[1].split(' ')]
65 |
66 | if len(label) == 0 or len(question) == 0:
67 | print("[ERROR] Find empty data: ", label, question)
68 | return False
69 |
70 | return True
71 |
72 |
73 | def save_data(data, file_path):
74 | # save data to disk
75 | with open(file_path, 'w') as f:
76 | json.dump(data, f)
77 | return
78 |
79 | def save_domain_intent(data, file_path):
80 | domain2intent = {}
81 | for line in data:
82 | domain = line[0]
83 | intent = line[1]
84 |
85 | if not domain in domain2intent:
86 | domain2intent[domain] = set()
87 |
88 | domain2intent[domain].add(intent)
89 |
90 | # save data to disk
91 | print("Saving domain intent out ... format: domain \t intent")
92 | with open(file_path,"w") as f:
93 | for domain in domain2intent:
94 | intentSet = domain2intent[domain]
95 | for intent in intentSet:
96 | f.write("%s\t%s\n" % (domain, intent))
97 | return
98 |
99 | def display_data(data):
100 | # dataset count
101 | print("[INFO] We have %d dataset."%(len(data)))
102 |
103 | datasetName = 'HWU64'
104 | data = data[datasetName]
105 |
106 | # domain count
107 | domainName = set()
108 | for domain in data:
109 | domainName.add(domain)
110 | print("[INFO] There are %d domains."%(len(domainName)))
111 | print(domainName)
112 |
113 | # intent count
114 | intentName = set()
115 | for domain in data:
116 | for d in data[domain]:
117 | lab = d[1][0]
118 | intentName.add(lab)
119 | intentName = list(intentName)
120 | intentName.sort()
121 | print("[INFO] There are %d intent."%(len(intentName)))
122 | print(intentName)
123 |
124 | # data count
125 | count = 0
126 | for domain in data:
127 | for d in data[domain]:
128 | count = count+1
129 | print("[INFO] Data count: %d"%(count))
130 |
131 | # intent for each domain
132 | domain2intentDict = {}
133 | for domain in data:
134 | if not domain in domain2intentDict:
135 | domain2intentDict[domain] = set()
136 |
137 | for d in data[domain]:
138 | lab = d[1][0]
139 | domain2intentDict[domain].add(lab)
140 | print("[INFO] Intent for each domain.")
141 | print(domain2intentDict)
142 |
143 | # data for each intent
144 | intent2count = {}
145 | for domain in data:
146 | for d in data[domain]:
147 | lab = d[1][0]
148 | if not lab in intent2count:
149 | intent2count[lab] = 0
150 | intent2count[lab] = intent2count[lab]+1
151 | print("[INFO] Intent count")
152 | print(intent2count)
153 |
154 | # examples of data
155 | exampleNum = 3
156 | while not exampleNum == 0:
157 | for domain in data:
158 | for d in data[domain]:
159 | lab = d[1]
160 | utt = d[0]
161 | if random.random() < 0.001:
162 | print("[INFO] Example:--%s, %s, %s, %s"%(datasetName, domain, lab, utt))
163 | exampleNum = exampleNum-1
164 | break
165 | if (exampleNum==0):
166 | break
167 |
168 | return None
169 |
170 |
171 | ##
172 | # @brief clean up data, including intent and utterance
173 | #
174 | # @param data a list of [utterance, label] pairs
175 | #
176 | # @return a new list of cleaned [label, utterance] pairs
177 | def cleanData(data):
178 | newData = []
179 | for d in data:
180 | utt = d[0]
181 | lab = d[1]
182 |
183 | uttClr = cleanUpSentence(utt)
184 | labClr = cleanUpSentence(lab)
185 | newData.append([labClr, uttClr])
186 |
187 | return newData
188 |
189 | def constructData(data, intent2domain):
190 | dataset2domain = {}
191 |     datasetName = 'HWU64'
192 | dataset2domain[datasetName] = {}
193 | for d in data:
194 | lab = d[0]
195 | utt = d[1]
196 | domain = intent2domain[lab]
197 | if not domain in dataset2domain[datasetName]:
198 | dataset2domain[datasetName][domain] = []
199 | dataField = [utt, [lab]]
200 | dataset2domain[datasetName][domain].append(dataField)
201 |
202 | return dataset2domain
203 |
204 |
205 | def read_data(file_path):
206 | with open(file_path) as json_file:
207 | data = json.load(json_file)
208 | return data
209 |
210 | # read in data
211 | #dataPath = "/data1/haode/projects/EMDIntentFewShot/SPIN_refactor/data/refactor_OOS/dataset.json"
212 | dataPath = "./dataset.json"
213 | print("Loading data ...", dataPath)
214 | # read lines, collect data count for different classes
215 | data = read_data(dataPath)
216 |
217 | display_data(data)
218 | print("Display.. done")
219 |
--------------------------------------------------------------------------------
/data/mcid/original/README:
--------------------------------------------------------------------------------
1 | Data Description
2 | The following directory contains intent classification data (train, test and eval splits) for Covid-19 queries in 4 languages:
3 | 1. English
4 | 2. Spanish
5 | 3. French
6 | 4. German
7 | Additionally, a spanglish directory contains a code-switched test set for Spanglish queries.
8 |
9 | Data Format
10 |
11 | The data is provided as tab-separated (TSV) files containing two columns:
12 | 1. Column 1 contains the query utterance
13 | 2. Column 2 contains the intent/label for the query
14 |
--------------------------------------------------------------------------------
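
The split files below follow this layout, with the label column written as intent:<name> (see en/eval.tsv). A minimal parsing sketch; that the prefix is consistent across all files is an assumption:

    def read_mcid_split(path):
        # returns a list of (utterance, intent) pairs from one TSV split
        pairs = []
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.rstrip("\n")
                if not line:
                    continue
                utt, label = line.rsplit("\t", 1)   # column 1: query, column 2: label
                pairs.append((utt, label.replace("intent:", "", 1)))
        return pairs

    # e.g. pairs = read_mcid_split("en/eval.tsv")
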
/data/mcid/original/en/eval.tsv:
--------------------------------------------------------------------------------
1 | is there a cure for the coronavirus intent:what_are_treatment_options
2 | flu treatment options viable to fight corona intent:what_are_treatment_options
3 | i want treatment options for covid-19 intent:what_are_treatment_options
4 | is there a cure for the virus intent:what_are_treatment_options
5 | what are some updates on the cure for covid-19 intent:what_are_treatment_options
6 | does acetaminophen cure coronavirus intent:what_are_treatment_options
7 | where can i get the covid vaccine intent:what_are_treatment_options
8 | who offers a vaccine for corona virus intent:what_are_treatment_options
9 | any homeopathic remedies for covid 19 intent:what_are_treatment_options
10 | please tell me if it's a myth that soup gives you corona intent:myths
11 | covid debunked myths intent:myths
12 | myths about desinfect your shoes every day intent:myths
13 | do 5g antennas spread the coronavirus intent:myths
14 | is it true that i can catch covid-19 from animals intent:myths
15 | covid-19 origin myths intent:myths
16 | myth about half the population dying from covid 19 intent:myths
17 | is it a myth that i can get coronavirus from my cat intent:myths
18 | can i travel to go hiking intent:travel
19 | can i visit china now intent:travel
20 | advice for traveling in dc intent:travel
21 | is air travel safe yet intent:travel
22 | will be the hotels reopen in 40 days intent:travel
23 | i wanna see the latest news about covid in my city intent:news_and_press
24 | tell me the latest corona news for europe intent:news_and_press
25 | recent updates in bahrain intent:news_and_press
26 | anything new about new york intent:news_and_press
27 | latest coronavirus news intent:news_and_press
28 | any news on coronavirus from today intent:news_and_press
29 | are there any updates in texas intent:news_and_press
30 | covid updates from the latest press release intent:news_and_press
31 | share this intent:share
32 | share this bot intent:share
33 | please share this intent:share
34 | this would be great to share intent:share
35 | i want to share this assistant with my contacts intent:share
36 | share this article intent:share
37 | share this info intent:share
38 | share this report intent:share
39 | send link to sister intent:share
40 | how are you today bot intent:hi
41 | hello covid bot intent:hi
42 | are you still sleepy bot intent:hi
43 | hey there intent:hi
44 | whaddup doctor intent:hi
45 | hey there doc intent:hi
46 | hey doc intent:hi
47 | what's up doc intent:hi
48 | how's it going intent:hi
49 | great info, thanks intent:okay_thanks
50 | got it, thanks bot intent:okay_thanks
51 | thanks for the info intent:okay_thanks
52 | i appreciate your help, dear corona intent:okay_thanks
53 | i appreciate it doc intent:okay_thanks
54 | i appreciate your help with travel safety tips intent:okay_thanks
55 | thanks for letting me know about high risk areas intent:okay_thanks
56 | thanks so much, corona doc intent:okay_thanks
57 | thank you for everything, doctor intent:okay_thanks
58 | thanks for all the info intent:okay_thanks
59 | thanks for the information doctor intent:okay_thanks
60 | thanks for the details intent:okay_thanks
61 | thank you for your help corona bot intent:okay_thanks
62 | how can i donate to new york city intent:donate
63 | how can i donate to charity intent:donate
64 | what's the best way to make a donation in my community intent:donate
65 | i want to fund nyc intent:donate
66 | donate to china intent:donate
67 | donate for african relief intent:donate
68 | donate to a local charity intent:donate
69 | what's a good cause to donate to intent:donate
70 | funding for covid 19 response chicago intent:donate
71 | how many covid19 cases are there now intent:latest_numbers
72 | how many confirmed cases are there in new york intent:latest_numbers
73 | can you show me the ranking by confirmed cases within the u.s intent:latest_numbers
74 | what are the statistics for elderly people infected by the virus intent:latest_numbers
75 | how many us states have a shelter in place order intent:latest_numbers
76 | total recoveries from covid intent:latest_numbers
77 | how many tests has the us done intent:latest_numbers
78 | does staying home protect my family from getting infected with the virus intent:protect_yourself
79 | do gloves help intent:protect_yourself
80 | should i sanitize my hands after grocery shopping intent:protect_yourself
81 | what kind of hand sanitizer should i use intent:protect_yourself
82 | how many times a day should i wash my hands intent:protect_yourself
83 | what should essential workers do for prevention of the virus intent:protect_yourself
84 | how long is the virus alive on surfaces intent:protect_yourself
85 | how to kill the spread of corona virus on surfaces intent:protect_yourself
86 | how can i protect myself from coronavirus intent:protect_yourself
87 | do face masks help stop coronavirus intent:protect_yourself
88 | are there any medications i should take if i have the virus intent:protect_yourself
89 | do pets need protection if going outdoors intent:protect_yourself
90 | should i wear gloves intent:protect_yourself
91 | what gloves are best for protection intent:protect_yourself
92 | enlighten me about the virus intent:what_is_corona
93 | brief me about the corona intent:what_is_corona
94 | describe corona intent:what_is_corona
95 | give me more info on the pandemic intent:what_is_corona
96 | tell me what i should know about covid intent:what_is_corona
97 | what does covid-19 stand for intent:what_is_corona
98 | is diarrhea a symptom intent:what_are_symptoms
99 | what are the symptoms of covid-19 intent:what_are_symptoms
100 | tell me all the possible symptoms of covid-19 intent:what_are_symptoms
101 | how can i get better intent:what_are_symptoms
102 | how can i tell the difference between having allergies or coronavirus intent:what_are_symptoms
103 | what are possible danger signs for it intent:what_are_symptoms
104 | what are some of the symptoms associated with corona virus intent:what_are_symptoms
105 | what are the late stage symptoms intent:what_are_symptoms
106 | how many different symptoms are there intent:what_are_symptoms
107 | can corona virus cause a headache intent:what_are_symptoms
108 | i have a sore throat, is this a symptom of the covid-19 intent:what_are_symptoms
109 | what main symptoms are clear signs of the disease intent:what_are_symptoms
110 | what symptom is the easiest to notice when someone has the disease intent:what_are_symptoms
111 | is covid-19 causing my shortness of breath intent:what_are_symptoms
112 | ive been sneezing a lot is that caused buy the corona intent:what_are_symptoms
113 | does corona virus spread through farts intent:how_does_corona_spread
114 | can i get the corona from my kids ipad even after 3 days intent:how_does_corona_spread
115 | can i get infected with covid if someone sneezes on me intent:how_does_corona_spread
116 | how likely does covid spread on amazon packages intent:how_does_corona_spread
117 | is coronavirus spread on doorknobs intent:how_does_corona_spread
118 | can i get infected with covid through indirect contact intent:how_does_corona_spread
119 | how can touching shared surfuces put you at risk for corona virus intent:how_does_corona_spread
120 | can the virus spread by skin to skin interaction intent:how_does_corona_spread
121 | can the virus travel by touch intent:how_does_corona_spread
122 | should i have contact with my pet if i have the coronavirus disease intent:can_i_get_from_feces_animal_pets
123 | can pets get the coronavirus, and can we catch it from them intent:can_i_get_from_feces_animal_pets
124 | how do i avoid getting covid from animals intent:can_i_get_from_feces_animal_pets
125 | is there a vaccination against the covid-19 coronavirus that my pet can receive intent:can_i_get_from_feces_animal_pets
126 | can i test my animal through a private veterinary laboratory intent:can_i_get_from_feces_animal_pets
127 | can i get corona virus from eating sushi intent:can_i_get_from_feces_animal_pets
128 | how can i tell if my cat has corona intent:can_i_get_from_feces_animal_pets
129 | can snakes get the virus intent:can_i_get_from_feces_animal_pets
130 | can horses get the virus intent:can_i_get_from_feces_animal_pets
131 | do mosquitos carry covid intent:can_i_get_from_feces_animal_pets
132 | how long does covid19 stay on public surfaces intent:can_i_get_from_packages_surfaces
133 | can i get coronavirus from touching groceries intent:can_i_get_from_packages_surfaces
134 | have people got covid from deliveries from china intent:can_i_get_from_packages_surfaces
135 | should i stop getting deliveries intent:can_i_get_from_packages_surfaces
136 | will bleach kill coronavirus on my stove intent:can_i_get_from_packages_surfaces
137 | is covid good at living on steel intent:can_i_get_from_packages_surfaces
138 | does covid-19 stay on furniture intent:can_i_get_from_packages_surfaces
139 | does covid-19 stay on ceilings intent:can_i_get_from_packages_surfaces
140 | which cities in the united states are at most risk for coronavirus right now intent:what_if_i_visited_high_risk_area
141 | is the hospital considered a high risk area intent:what_if_i_visited_high_risk_area
142 | which country is currently the highest risk intent:what_if_i_visited_high_risk_area
143 | should i be concerned since i just visited china intent:what_if_i_visited_high_risk_area
144 | what city in california has the highest risk intent:what_if_i_visited_high_risk_area
145 | should i travel within the u.s during the pandemic intent:what_if_i_visited_high_risk_area
146 | which countries are safe for nonessential international travel intent:what_if_i_visited_high_risk_area
147 | will i get in trouble visiting a high risk area intent:what_if_i_visited_high_risk_area
148 | which essential business has the highest risk of infection intent:what_if_i_visited_high_risk_area
149 |
--------------------------------------------------------------------------------
/data/mcid/original/es/eval.tsv:
--------------------------------------------------------------------------------
1 | me gustaría saber el número de muertes de corona virus en florida intent:latest_numbers
2 | dime cuántos casos confirmados de covid hay en méxico intent:latest_numbers
3 | estadísticas de pacientes recuperados a fecha de hoy en el mundo intent:latest_numbers
4 | me facilitás la cantidad de recuperados de corona virus en mi ciudad intent:latest_numbers
5 | cuántos casos confirmados de covid19 hay en el mundo intent:latest_numbers
6 | cuáles son los países con más muertos por covid19 intent:latest_numbers
7 | estadísticas mundiales sobre el coronavirus intent:latest_numbers
8 | dígame cuántos casos hay actualmente en cuba intent:latest_numbers
9 | cuántos contagiados hay en argentina intent:latest_numbers
10 | cómo puedo proteger a los menores frente al corona virus intent:protect_yourself
11 | me dices, por favor, qué debo hacer para no contagiarme de corona virus intent:protect_yourself
12 | si puedes dime la mejor forma de prevenir contagios en los menores intent:protect_yourself
13 | qué puedo usar si no tengo cubrebocas en mi casa intent:protect_yourself
14 | cómo puedo prevenir el contagio de corona virus intent:protect_yourself
15 | cómo evito que el virus entre a mi casa si salgo a comprar comida intent:protect_yourself
16 | cuánta lejía tengo que usar para limpiar los productos del covid19 intent:protect_yourself
17 | medidas de cuidado para que adultos mayores no se contagien de covid19 intent:protect_yourself
18 | cómo puedo hacer para protegerme del coronavirus intent:protect_yourself
19 | a cuántos metros puedo estar de otra persona sin que me contagie intent:protect_yourself
20 | qué prevenciones deben tomar los adultos mayores frente al covid-19 intent:protect_yourself
21 | cómo hago para no atrapar el corona virus intent:protect_yourself
22 | no me quiero enfermar. ¿qué puedo hacer intent:protect_yourself
23 | cómo evito contagiarme del covid 19 intent:protect_yourself
24 | es necesario desinfectar los calzados al volver de lugares públicos intent:protect_yourself
25 | quiero aprender más del coronavirus intent:what_is_corona
26 | dime todo del coronavirus intent:what_is_corona
27 | qué tan mortal es el covid-19 intent:what_is_corona
28 | muéstrame información del coronavirus, por favor intent:what_is_corona
29 | me dices la diferencia entre coronavirus y sars intent:what_is_corona
30 | qué enfermedades están relacionadas con el covid intent:what_is_corona
31 | hay otro virus que se llame corona virus intent:what_is_corona
32 | cuál es la diferencia entre el corona virus y el de la influenza intent:what_is_corona
33 | me explicás qué es exactamente el covid 19 intent:what_is_corona
34 | en qué consiste el virus del covid19 intent:what_is_corona
35 | dame datos básicos del coronavirus intent:what_is_corona
36 | quiero que me des información científica sobre el coronavirus intent:what_is_corona
37 | dame una definición de covid 19 intent:what_is_corona
38 | define corona virus intent:what_is_corona
39 | qué tal si me das un resumen sobre el corona virus intent:what_is_corona
40 | estoy interesada en saber cuáles son los síntomas del corona virus, por favor intent:what_are_symptoms
41 | tengo mucha necesidad por conocer cuáles son los principales sintomas del corona virus intent:what_are_symptoms
42 | crees que el dolor de garganta es un síntoma de corona virus intent:what_are_symptoms
43 | como puedo saber si tengo corona virus intent:what_are_symptoms
44 | cuáles son los principales síntomas del corona virus intent:what_are_symptoms
45 | cómo afecta la covid al cuerpo intent:what_are_symptoms
46 | qué ataca el covid intent:what_are_symptoms
47 | diferencias entre la influenza y el covid intent:what_are_symptoms
48 | cuáles son los síntomas de alguien que tiene coronavirus intent:what_are_symptoms
49 | es posible que sólo tenga síntomas estomacales y sea corona virus intent:what_are_symptoms
50 | cuál es el síntoma más evidente que indica corona virus intent:what_are_symptoms
51 | dime si tener diarrea es un síntoma suficiente para hacerse un examen intent:what_are_symptoms
52 | necesito una lista de síntomas frecuentes de corona virus intent:what_are_symptoms
53 | afectaciones de alguien con coronavirus intent:what_are_symptoms
54 | cómo saber si estoy infectada de coronavirus intent:what_are_symptoms
55 | lista de síntomas del coronavirus intent:what_are_symptoms
56 | cómo se contagia el corona virus intent:how_does_corona_spread
57 | búscame información de cómo se contagia el corona virus, por favor intent:how_does_corona_spread
58 | cuáles son los factores de riesgo del covid intent:how_does_corona_spread
59 | puedo contagiarme por caminar en la calle intent:how_does_corona_spread
60 | por tocar una superficie me puedo contagiar de covid19 intent:how_does_corona_spread
61 | el coronavirus, ¿es una enfermedad de transmición sexual intent:how_does_corona_spread
62 | el virus le puede dar a uno por contacto de piel intent:how_does_corona_spread
63 | quisiera informacion sobre si los animales pueden contagiar el corona virus intent:can_i_get_from_feces_animal_pets
64 | realmente necesito saber si mediante los animales podemos contagiarnos de corona virus intent:can_i_get_from_feces_animal_pets
65 | el coronavirus se transmite entre especies intent:can_i_get_from_feces_animal_pets
66 | le puede dar coronavirus a mi mascota intent:can_i_get_from_feces_animal_pets
67 | a mis gatos les puede dar coronavirus intent:can_i_get_from_feces_animal_pets
68 | si me como un murciélago tendré coronavirus intent:can_i_get_from_feces_animal_pets
69 | las palomas propagan el corona virus intent:can_i_get_from_feces_animal_pets
70 | por qué hay tigres con corona virus intent:can_i_get_from_feces_animal_pets
71 | necesito proteger a mi perro del virus intent:can_i_get_from_feces_animal_pets
72 | es cierto que el covid19 viene de los murciélagos intent:can_i_get_from_feces_animal_pets
73 | los cerdos se pueden enfermar del virus intent:can_i_get_from_feces_animal_pets
74 | puede un perro darle corona virus a un gato intent:can_i_get_from_feces_animal_pets
75 | montar un caballo puede pasarme el virus intent:can_i_get_from_feces_animal_pets
76 | tienes información de cuánto tiempo dura el corona virus en el cartón intent:can_i_get_from_packages_surfaces
77 | quiero que me digas si el corona virus sobrevive mucho tiempo en el papel intent:can_i_get_from_packages_surfaces
78 | el coronavirus aguanta el calor intent:can_i_get_from_packages_surfaces
79 | el coronavirus puede quedarse en los muebles intent:can_i_get_from_packages_surfaces
80 | cuánto tiempo vive el virus en la madera intent:can_i_get_from_packages_surfaces
81 | cuánto tiempo vive el virus en la mesada de mi cocina intent:can_i_get_from_packages_surfaces
82 | dime si las superficies del interior de un vehículo pueden propagar el virus intent:can_i_get_from_packages_surfaces
83 | qué puede pasarme si visito una zona de riesgo por corona virus intent:what_if_i_visited_high_risk_area
84 | reglas para visitar zonas de riesgo de covid intent:what_if_i_visited_high_risk_area
85 | pude contraer coronavirus si estuve en una zona de riesgo intent:what_if_i_visited_high_risk_area
86 | cuáles son las normativas que debo seguir si volví de una zona de riesgo intent:what_if_i_visited_high_risk_area
87 | que debe tener en cuenta alguien que estuvo hace poco en españa intent:what_if_i_visited_high_risk_area
88 | españa es zona de alto riesgo de covid 19 intent:what_if_i_visited_high_risk_area
89 | qué pasa si he estado de viaje en roma intent:what_if_i_visited_high_risk_area
90 | cuál es el protocolo si estuve en tránsito en zona de alto riesgo intent:what_if_i_visited_high_risk_area
91 | hay algún tratamiento eficaz contra el corona virus intent:what_are_treatment_options
92 | cómo dar tratamiento a un infectado de covid intent:what_are_treatment_options
93 | cómo tratan los hospitales a los enfermos del covid intent:what_are_treatment_options
94 | si me da coronavirus, ¿cómo me voy a curar intent:what_are_treatment_options
95 | cómo están curando a los pacientes intent:what_are_treatment_options
96 | puedo tomarme un té para elevar mis defensas contra el coronavirus intent:what_are_treatment_options
97 | la vacuna de la influenza sirve para el coronavirus intent:what_are_treatment_options
98 | sirve la hidroxicloroquina para matar el coronavirus intent:what_are_treatment_options
99 | es efectiva la vacuna contra la gripe para lidiar con el corona virus intent:what_are_treatment_options
100 | qué información sobre covid resultó ser mito intent:myths
101 | es cierto que el coronavirus es un invento del gobierno para matar a los viejitos intent:myths
102 | quiero saber cuáles son los mitos más importantes sobre el virus y su desmentida intent:myths
103 | decime el top tres de mitos sobre el covid-19 intent:myths
104 | los anteojos de sol, ¿ayudan a evitar el contagio del virus intent:myths
105 | necesito saber qué información es ficticia sobre el coronavirus intent:myths
106 | que noticias sobre el coronavirus son mitos intent:myths
107 | te es posible decirme si los servicios de tren están cancelados en italia intent:travel
108 | medidas de seguridad para viajeros intent:travel
109 | cuándo la gente va a poder viajar otra vez intent:travel
110 | en septiembre ya habrá viajes a europa intent:travel
111 | nombrar las recomendaciones para ir al aeropuerto intent:travel
112 | están funcionando los taxis estos días intent:travel
113 | qué tengo que hacer para viajar en avión de forma segura en una pandemia intent:travel
114 | están más baratos los vuelos a causa del covid19 intent:travel
115 | el subterráneo sigue funcionando intent:travel
116 | tienes consejos para viajar durante la epidemia intent:travel
117 | es buena idea pedir un lyft ahora que existe el virus intent:travel
118 | cómo va el coronavirus? ¿qué dicen las noticias intent:news_and_press
119 | cuáles son las noticias destacadas de hoy en el mundo sobre covid intent:news_and_press
120 | dame los titulares de las principales noticias sobre coronavirus intent:news_and_press
121 | cuáles fueron las noticias sobre el coronavirus hoy en argentina intent:news_and_press
122 | muéstreme las últimas noticias sobre el virus en colombia intent:news_and_press
123 | quiero tener actualizaciones diarias sobre el virus en canadá intent:news_and_press
124 | dame info sobre el virus en australia intent:news_and_press
125 | comparte en mi muro intent:share
126 | dale like a este comentario y compártelo intent:share
127 | podrías compartir esto que me acabás de decir intent:share
128 | compartamos esto en mis redes intent:share
129 | necesito que todos sepan lo que acabas de contarme, compartámoslo intent:share
130 | lo que me estás contando, compártelo intent:share
131 | usa mis redes sociales para compartir este servicio intent:share
132 | compartir la noticia con mis amigos intent:share
133 | puedes compartirle a mis familiares esta información intent:share
134 | cómo se hace para que todos en mis redes tengan este servicio intent:share
135 | compartir el vínculo intent:share
136 | compártale el asistente de corona virus a mi familia intent:share
137 | buenas noches, compadre intent:hi
138 | cómo está mi informante intent:hi
139 | qué más, coronin intent:hi
140 | cómo va eso, corona intent:hi
141 | hola, asistente corona virus intent:hi
142 | ¡qué tal intent:hi
143 | cómo dice que le va, doctor digital intent:hi
144 | ¡hola, corona intent:hi
145 | millones de gracias por la información intent:okay_thanks
146 | excelente, gracias intent:okay_thanks
147 | me queda claro, muchas gracias intent:okay_thanks
148 | gracias por toda la información intent:okay_thanks
149 | oye, doctor, que sepas que agradezco la información intent:okay_thanks
150 | sabes que siento agradecimiento hacia ti intent:okay_thanks
151 | estoy muy agradecida por los datos que me diste intent:okay_thanks
152 | gracias por decirme, muy útil intent:okay_thanks
153 | un profundo agradecimiento intent:okay_thanks
154 | dame información de cómo puedo hacer una donación intent:donate
155 | quisiera realizar una donación intent:donate
156 | por favor, aplica una donación de 1000 pesos a esta propuesta intent:donate
157 | quiero donar elementos de limpieza al pueblo de mi mamá intent:donate
158 | querría que hicieras una donación a la provincia de buenos aires, por favor intent:donate
159 | me ayudarías a donar parte de mis ahorros intent:donate
160 | ayúdame a enviar dinero a san juan intent:donate
161 | me gustaría regalarles dinero a los más afectados por el covid intent:donate
162 |
--------------------------------------------------------------------------------
/data/mcid/original/fr/eval.tsv:
--------------------------------------------------------------------------------
1 | combien d'enfants infectés par coronavirus en europe intent:latest_numbers
2 | combien de gens sont hospitalisés en russie intent:latest_numbers
3 | derniers chiffres coronavirus intent:latest_numbers
4 | pourcentage de personnes guéries ontario canada intent:latest_numbers
5 | statistiques covid-19 les plus récentes par pays intent:latest_numbers
6 | nouveaux cas de covid-19 cette semaine en chine intent:latest_numbers
7 | statistiques du covid-19 par état aux états-unis intent:latest_numbers
8 | évolution nombre d'hospitalisations france intent:latest_numbers
9 | dernières statistiques d'hospitalisations au canada intent:latest_numbers
10 | quel est le nombre de personnes infectées à travers le monde intent:latest_numbers
11 | masque pour enfant intent:protect_yourself
12 | combien de temps dois-je garder un masque intent:protect_yourself
13 | puis-je réutiliser un masque intent:protect_yourself
14 | les masques en tissus sont-ils efficaces intent:protect_yourself
15 | a combien de dégrés laver les habits pour tuer le virus intent:protect_yourself
16 | secourir une personne en période de distanciation sociale intent:protect_yourself
17 | traitements préventifs intent:protect_yourself
18 | méthodes de protection covid-19 intent:protect_yourself
19 | merci encore une fois intent:okay_thanks
20 | je te remercie énormement doc covid intent:okay_thanks
21 | merci beaucoup docteur bot intent:okay_thanks
22 | tu es très efficace, merci intent:okay_thanks
23 | merci pour les renseignements intent:okay_thanks
24 | merci pour l'explication intent:okay_thanks
25 | parfait merci bien intent:okay_thanks
26 | je te remercie pour toutes ces infos intent:okay_thanks
27 | mille mercis pour ces infos intent:okay_thanks
28 | merci énormément mon ami intent:okay_thanks
29 | cimer le doc intent:okay_thanks
30 | je te remercie intent:okay_thanks
31 | merci infiniment intent:okay_thanks
32 | dis-donc, t'as réponse à tout toi intent:okay_thanks
33 | merci de ton aide intent:okay_thanks
34 | merci je te revaudrai ça intent:okay_thanks
35 | don pour les victimes du coronavirus au kenya intent:donate
36 | faire un don pour lutter contre le covid-19 intent:donate
37 | comment faire un don pour les personnes touchées intent:donate
38 | urgence sanitaire, comment aider financièrement intent:donate
39 | faire une contribution pour la corée intent:donate
40 | je veux verser une somme pour soutenir la recherche sur le coronavirus intent:donate
41 | aide financière coronavirus intent:donate
42 | donne à la fondation louis pasteur intent:donate
43 | pays origine coronavirus intent:what_is_corona
44 | début du covid intent:what_is_corona
45 | c’est une maladie respiratoire intent:what_is_corona
46 | que veut dire covid-19 intent:what_is_corona
47 | c'est quoi le nouveau coronavirus intent:what_is_corona
48 | différence coronavirus grippe intent:what_is_corona
49 | déclaration pandémie oms covid-19 intent:what_is_corona
50 | peut-on guérir du coronavirus intent:what_is_corona
51 | test covid-19 france intent:what_is_corona
52 | quel traitement contre le coronavirus intent:what_is_corona
53 | puis-je avoir des renseignements sur le covid19 intent:what_is_corona
54 | le coronavirus est-il plus mortel que la grippe saisonnière intent:what_is_corona
55 | c'est quoi les premiers symptômes lors de l'infection intent:what_are_symptoms
56 | peut-on tomber malade du coronavirus une deuxième fois intent:what_are_symptoms
57 | quand apparaissent les symptômes intent:what_are_symptoms
58 | est-ce que j'ai le coronavirus intent:what_are_symptoms
59 | pas de symptômes corona intent:what_are_symptoms
60 | quelles différences entre les symptômes de la grippe et du coronavirus intent:what_are_symptoms
61 | symptômes les plus fréquents intent:what_are_symptoms
62 | puis-je m'autodiagnostiquer du covid19 intent:what_are_symptoms
63 | peut-on attraper le virus en jouant dans le sable intent:how_does_corona_spread
64 | comment les enfants partagent le virus intent:how_does_corona_spread
65 | dis-moi comment se transmet le covid-19 intent:how_does_corona_spread
66 | le coronavirus peut-il se propager via la nourriture intent:how_does_corona_spread
67 | qu'est-ce que la transmission communautaire intent:how_does_corona_spread
68 | informations sur la propagation du coronavirus intent:how_does_corona_spread
69 | propagation du virus dans le milieu médical intent:how_does_corona_spread
70 | que veut dire la charge virale du virus intent:how_does_corona_spread
71 | quels sont les moyens de propagation confirmés pour le covid-19 intent:how_does_corona_spread
72 | le covid-19 se transmet-il par la transpiration intent:how_does_corona_spread
73 | est-ce qu’il faut tester mon chien intent:can_i_get_from_feces_animal_pets
74 | chien et chat covid 19 intent:can_i_get_from_feces_animal_pets
75 | dois-je me laver les mains après avoir joué avec mon chien intent:can_i_get_from_feces_animal_pets
76 | contamination au covid par voie alimentaire intent:can_i_get_from_feces_animal_pets
77 | quels animaux peuvent contaminer un homme du coronavirus intent:can_i_get_from_feces_animal_pets
78 | est-ce que les lapins transmettent le coronavirus intent:can_i_get_from_feces_animal_pets
79 | quel est le risque d'attraper le covid-19 par un animal de compagnie intent:can_i_get_from_feces_animal_pets
80 | durée de vie coronavirus sur le béton intent:can_i_get_from_packages_surfaces
81 | durée de vie du coronavirus sur plastique intent:can_i_get_from_packages_surfaces
82 | coronavirus, durée vie, verre intent:can_i_get_from_packages_surfaces
83 | combien de temps le coronavirus reste sur les objets intent:can_i_get_from_packages_surfaces
84 | survie du corona sur les vêtements intent:can_i_get_from_packages_surfaces
85 | temps de survie corona boîtes en carton intent:can_i_get_from_packages_surfaces
86 | le virus survit-il sur les plantes intent:can_i_get_from_packages_surfaces
87 | le covid-19 survit-il au frigo intent:can_i_get_from_packages_surfaces
88 | est-il conseillé de désinfecter mes courses avant de les ranger intent:can_i_get_from_packages_surfaces
89 | combien de temps le coronavirus survit-il sur une matière non organique intent:can_i_get_from_packages_surfaces
90 | le coronavirus survit-il sur des surfaces très chaudes intent:can_i_get_from_packages_surfaces
91 | nombre de pays à haut risque intent:what_if_i_visited_high_risk_area
92 | coronavirus endroits à éviter intent:what_if_i_visited_high_risk_area
93 | la chine est-elle encore une zone à haut risque intent:what_if_i_visited_high_risk_area
94 | puis-je retourner sans risques de france à wuhan après le déconfinement là-bas intent:what_if_i_visited_high_risk_area
95 | quels sont les procédures de confinement pour les citoyens français rapatriés de zones à risque comme la ville de new-york intent:what_if_i_visited_high_risk_area
96 | où sont les zones covid 19 à haut risques intent:what_if_i_visited_high_risk_area
97 | quelles villes sont les plus risquées en france pour le covid 19 intent:what_if_i_visited_high_risk_area
98 | où est -il interdit d'aller car trop de risque à cause du coronavirus intent:what_if_i_visited_high_risk_area
99 | remède maison coronavirus intent:what_are_treatment_options
100 | aliments pouvant combattre le coronavirus intent:what_are_treatment_options
101 | faire du sport peut-il aider à guérir le coronavirus intent:what_are_treatment_options
102 | l'ibuprofène est-il recommandé pour traiter la fièvre si on est atteint par le covid-19 intent:what_are_treatment_options
103 | les remèdes naturels guérissent-ils le virus intent:what_are_treatment_options
104 | où en sont les recherches de vaccin contre le coronavirus intent:what_are_treatment_options
105 | est-ce la chloroquine utilisée pour traiter les le coronavirus intent:what_are_treatment_options
106 | exercices yoga toux douloureuse intent:what_are_treatment_options
107 | combien de temps de repos après infection au coronavirus intent:what_are_treatment_options
108 | quelles sont mes options pour traiter une infection virale intent:what_are_treatment_options
109 | immunostimulants à essayer pour se proteger du covid-19 intent:what_are_treatment_options
110 | covid-19 et aspirine intent:what_are_treatment_options
111 | quels médicaments peut-on prendre quand on a le coronavirus intent:what_are_treatment_options
112 | est-ce que je peux prendre un anti-migraineux alors que je suis contaminée intent:what_are_treatment_options
113 | quand peut-on espérer avoir un vaccin intent:what_are_treatment_options
114 | quels médicaments efficaces contre corona intent:what_are_treatment_options
115 | est-ce que les huiles essentielles sont efficaces contre le coronavirus intent:what_are_treatment_options
116 | antibiotiques les plus efficaces contre corona intent:what_are_treatment_options
117 | est-ce un mythe que le virus est faible dans les pays chaud intent:myths
118 | le coronavirus est-il réellement créé dans le labo p4 intent:myths
119 | manger de l'ail empêche la contamination intent:myths
120 | le pangolin est-il un mythe dans l'incubation du virus intent:myths
121 | quels sont les mythes sur la propagation du virus en inde intent:myths
122 | il faut boire beaucoup … d’alcool – il tue le coronavirus intent:myths
123 | le covid 19 provient des chauves-souris intent:myths
124 | des états totalitaires ont importé et ont fait circuler de manière délibérée le coronavirus dans leur pays pour se débarrasser de l'opposition intent:myths
125 | mythe coronavirus a été créé dans un laboratoire intent:myths
126 | corona pangolin ou chauve-souris intent:myths
127 | puis-je prendre le métro à toronto intent:travel
128 | comment faire rembourser mon trajet vers une région à risque ? paris intent:travel
129 | conseils de sécurité pour les trajets professionnels intent:travel
130 | quel est le moyen de transport à éviter durant l'épidémie de covid-19 intent:travel
131 | puis-je quitter ma ville pendant l'épidémie de covid 19 intent:travel
132 | risque contamination voyage métro intent:travel
133 | est-ce que je peux aller en chine demain intent:travel
134 | voyager en europe cet été intent:travel
135 | de combien dois-je repousser mon voyage intent:travel
136 | remboursement vol annulé corona intent:travel
137 | peux-tu me communiquer les dernières informations disponibles au sujet du covid19 intent:news_and_press
138 | dernières mises à jour corona en chine intent:news_and_press
139 | peux-tu me donner l'actualité sur le covid19 intent:news_and_press
140 | informations de dernière minute sur le coronavirus intent:news_and_press
141 | derniers communiqués de presse concernant le coronavirus intent:news_and_press
142 | donne-moi les infos les plus récentes sur l'épidémie de corona en asie intent:news_and_press
143 | peux-tu me trouver les dernières informations sur le coronavirus aux etats-unis intent:news_and_press
144 | fais-moi voir les infos du covid d’hier intent:news_and_press
145 | news d'aujourd'hui sur le corona virus intent:news_and_press
146 | est-ce qu’il y a du nouveau sur le virus intent:news_and_press
147 | que s’est-il passé ces dernières heures intent:news_and_press
148 | partage ce lien sur snap intent:share
149 | partager ce service sur fb intent:share
150 | partager ces infos sur insta intent:share
151 | peux-tu publier ta réponse sur mon facebook intent:share
152 | peux-tu transmettre cela à mon père s'il te plaît intent:share
153 | comment faire pour partager tes réponses avec mes amis sur twitter intent:share
154 | lance-toi sur les réseaux sociaux intent:share
155 | comment te partager avec mes amis intent:share
156 | partage ton lien sur instagram intent:share
157 | partage les services de ce bot avec mes contacts intent:share
158 | il faut mettre cet article corona sur mon mur facebook intent:share
159 | coucou docteur intent:hi
160 | bonjour médecin covid intent:hi
161 | coucou corona intent:hi
162 | salut, comment ça va bot intent:hi
163 | tu vas bien docteur covid intent:hi
164 | comment vas-tu doc intent:hi
165 | hey monsieur info corona intent:hi
166 | bonjour infocorona intent:hi
167 | hey virus info intent:hi
168 | y'a t'il un docteur ici intent:hi
169 | salut spécialiste du corona intent:hi
170 | salut mon pote le bot intent:hi
171 | ave, corona, ceux qui vont mourir te saluent intent:hi
172 | hey médecin coronavirus, comment vas tu intent:hi
173 | ça va doc intent:hi
174 |
--------------------------------------------------------------------------------
/data/mcid/showDataset.py:
--------------------------------------------------------------------------------
1 | # this script loads the processed dataset.json and prints summary statistics (domains, intents, counts, examples)
2 | from gensim.models.keyedvectors import KeyedVectors
3 | import numpy as np
4 | import scipy.io as sio
5 | import nltk
6 | import pdb
7 | import random
8 | import csv
9 | import string
10 | import contractions
11 | import json
12 | import random
13 |
14 | def getDomainIntent(domainLabFile):
15 | domain2lab = {}
16 | lab2domain = {}
17 | currentDomain = None
18 | with open(domainLabFile,'r') as f:
19 | for line in f:
20 | if ':' in line and currentDomain == None:
21 | currentDomain = cleanUpSentence(line)
22 | domain2lab[currentDomain] = []
23 | elif line == "\n":
24 | currentDomain = None
25 | else:
26 | intent = cleanUpSentence(line)
27 | domain2lab[currentDomain].append(intent)
28 |
29 | for key in domain2lab:
30 | domain = key
31 | labList = domain2lab[key]
32 | for lab in labList:
33 | lab2domain[lab] = domain
34 |
35 | return domain2lab, lab2domain
36 |
37 | def cleanUpSentence(sentence):
38 | # sentence: a string, like " Hello, do you like apple? I hate it!! "
39 |
40 | # strip
41 | sentence = sentence.strip()
42 |
43 | # lower case
44 | sentence = sentence.lower()
45 |
46 | # fix contractions
47 | sentence = contractions.fix(sentence)
48 |
49 | # remove '_' and '-'
50 | sentence = sentence.replace('-',' ')
51 | sentence = sentence.replace('_',' ')
52 |
53 | # remove all punctuations
54 | sentence = ''.join(ch for ch in sentence if ch not in string.punctuation)
55 |
56 | return sentence
57 |
58 |
59 | def check_data_format(file_path):
60 | for line in open(file_path,'rb'):
61 | arr =str(line.strip(),'utf-8')
62 | arr = arr.split('\t')
63 | label = [w for w in arr[0].split(' ')]
64 | question = [w for w in arr[1].split(' ')]
65 |
66 | if len(label) == 0 or len(question) == 0:
67 | print("[ERROR] Find empty data: ", label, question)
68 | return False
69 |
70 | return True
71 |
72 |
73 | def save_data(data, file_path):
74 | # save data to disk
75 | with open(file_path, 'w') as f:
76 | json.dump(data, f)
77 | return
78 |
79 | def save_domain_intent(data, file_path):
80 | domain2intent = {}
81 | for line in data:
82 | domain = line[0]
83 | intent = line[1]
84 |
85 | if not domain in domain2intent:
86 | domain2intent[domain] = set()
87 |
88 | domain2intent[domain].add(intent)
89 |
90 | # save data to disk
91 | print("Saving domain intent out ... format: domain \t intent")
92 | with open(file_path,"w") as f:
93 | for domain in domain2intent:
94 | intentSet = domain2intent[domain]
95 | for intent in intentSet:
96 | f.write("%s\t%s\n" % (domain, intent))
97 | return
98 |
99 | def display_data(data):
100 | # dataset count
101 | print("[INFO] We have %d dataset."%(len(data)))
102 |
103 | datasetName = 'MCID'
104 | data = data[datasetName]
105 |
106 | # domain count
107 | domainName = set()
108 | for domain in data:
109 | domainName.add(domain)
110 | print("[INFO] There are %d domains."%(len(domainName)))
111 | print(domainName)
112 |
113 | # intent count
114 | intentName = set()
115 | for domain in data:
116 | for d in data[domain]:
117 | lab = d[1][0]
118 | intentName.add(lab)
119 | intentName = list(intentName)
120 | intentName.sort()
121 |     print("[INFO] There are %d intents."%(len(intentName)))
122 | print(intentName)
123 |
124 | # data count
125 | count = 0
126 | for domain in data:
127 | for d in data[domain]:
128 | count = count+1
129 | print("[INFO] Data count: %d"%(count))
130 |
131 | # intent for each domain
132 | domain2intentDict = {}
133 | for domain in data:
134 | if not domain in domain2intentDict:
135 | domain2intentDict[domain] = set()
136 |
137 | for d in data[domain]:
138 | lab = d[1][0]
139 | domain2intentDict[domain].add(lab)
140 | print("[INFO] Intent for each domain.")
141 | print(domain2intentDict)
142 |
143 | # data for each intent
144 | intent2count = {}
145 | for domain in data:
146 | for d in data[domain]:
147 | lab = d[1][0]
148 | if not lab in intent2count:
149 | intent2count[lab] = 0
150 | intent2count[lab] = intent2count[lab]+1
151 | print("[INFO] Intent count")
152 | print(intent2count)
153 |
154 | # examples of data
155 | exampleNum = 3
156 | while not exampleNum == 0:
157 | for domain in data:
158 | for d in data[domain]:
159 | lab = d[1]
160 | utt = d[0]
161 | if random.random() < 0.001:
162 | print("[INFO] Example:--%s, %s, %s, %s"%(datasetName, domain, lab, utt))
163 | exampleNum = exampleNum-1
164 | break
165 | if (exampleNum==0):
166 | break
167 |
168 | return None
169 |
170 |
171 | ##
172 | # @brief clean up data, including intent and utterance
173 | #
174 | # @param data a list of data
175 | #
176 | # @return
177 | def cleanData(data):
178 | newData = []
179 | for d in data:
180 | utt = d[0]
181 | lab = d[1]
182 |
183 | uttClr = cleanUpSentence(utt)
184 | labClr = cleanUpSentence(lab)
185 | newData.append([labClr, uttClr])
186 |
187 | return newData
188 |
189 | def constructData(data, intent2domain):
190 | dataset2domain = {}
191 |     datasetName = 'MCID'
192 | dataset2domain[datasetName] = {}
193 | for d in data:
194 | lab = d[0]
195 | utt = d[1]
196 | domain = intent2domain[lab]
197 | if not domain in dataset2domain[datasetName]:
198 | dataset2domain[datasetName][domain] = []
199 | dataField = [utt, [lab]]
200 | dataset2domain[datasetName][domain].append(dataField)
201 |
202 | return dataset2domain
203 |
204 |
205 | def read_data(file_path):
206 | with open(file_path) as json_file:
207 | data = json.load(json_file)
208 | return data
209 |
210 | # read in data
211 | #dataPath = "/data1/haode/projects/EMDIntentFewShot/SPIN_refactor/data/refactor_OOS/dataset.json"
212 | dataPath = "./dataset.json"
213 | print("Loading data ...", dataPath)
214 | # read lines, collect data count for different classes
215 | data = read_data(dataPath)
216 |
217 | display_data(data)
218 | print("Display.. done")
219 |
--------------------------------------------------------------------------------
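Note on the data layout: from `display_data` and `constructData` in the script above, `dataset.json` appears to be a nested dict keyed by dataset name, then domain name, holding `[utterance, [label, ...]]` pairs. A minimal sketch of that layout (the `MCID`/`MEDICAL` keys match `display_data` and `scripts/eval.sh`; the example rows themselves are invented for illustration):

```python
# Layout assumed by showDataset.py (inferred from display_data/constructData):
# {dataset name: {domain name: [[utterance, [label, ...]], ...]}}
example = {
    "MCID": {
        "MEDICAL": [
            ["hello covid doctor", ["hi"]],                  # illustrative rows,
            ["share this bot with my contacts", ["share"]],  # not real dataset entries
        ]
    }
}

# The same access pattern display_data uses: domain -> [utterance, [labels]]
for domain, items in example["MCID"].items():
    for utt, labels in items:
        print(domain, labels[0], utt)
```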
/data/oos/original/domain_intent.txt:
--------------------------------------------------------------------------------
1 | banking:
2 | transfer
3 | transactions
4 | balance
5 | freeze account
6 | pay bill
7 | bill balance
8 | bill due
9 | interest rate
10 | routing
11 | min payment
12 | order checks
13 | pin change
14 | report fraud
15 | spending history
16 | account blocked
17 |
18 | credit cards:
19 | credit limit
20 | credit limit change
21 | international fees
22 | expiration date
23 | credit score
24 | replacement card duration
25 | card declined
26 | improve credit score
27 | report lost card
28 | damaged card
29 | redeem rewards
30 | apr
31 | application status
32 | new card
33 | rewards balance
34 |
35 | kitchen & dining:
36 | recipe
37 | restaurant reservation
38 | restaurant reviews
39 | nutrition info
40 | ingredients list
41 | cancel reservation
42 | how busy
43 | ingredient substitution
44 | confirm reservation
45 | meal suggestion
46 | accept reservations
47 | restaurant suggestion
48 | cook time
49 | calories
50 | food last
51 |
52 | home:
53 | shopping list
54 | shopping list update
55 | smart home
56 | reminder
57 | reminder update
58 | order
59 | order status
60 | what song
61 | todo list
62 | todo list update
63 | next song
64 | play music
65 | update playlist
66 | calendar
67 | calendar update
68 |
69 | auto & commute:
70 | traffic
71 | tire change
72 | tire pressure
73 | last maintenance
74 | schedule maintenance
75 | uber
76 | jump start
77 | oil change how
78 | oil change when
79 | mpg
80 | current location
81 | distance
82 | gas
83 | gas type
84 | directions
85 |
86 | travel:
87 | travel alert
88 | flight status
89 | travel notification
90 | travel suggestion
91 | exchange rate
92 | plug type
93 | lost luggage
94 | international visa
95 | translate
96 | timezone
97 | vaccines
98 | book flight
99 | book hotel
100 | car rental
101 | carry on
102 |
103 | utility:
104 | roll dice
105 | definition
106 | flip coin
107 | measurement conversion
108 | timer
109 | spelling
110 | date
111 | text
112 | weather
113 | make call
114 | time
115 | share location
116 | calculator
117 | alarm
118 | find phone
119 |
120 | work:
121 | direct deposit
122 | pto balance
123 | insurance
124 | insurance change
125 | next holiday
126 | w2
127 | payday
128 | taxes
129 | meeting schedule
130 | pto request
131 | income
132 | rollover 401k
133 | schedule meeting
134 | pto request status
135 | pto used
136 |
137 | small talk:
138 | goodbye
139 | where are you from
140 | greeting
141 | tell joke
142 | how old are you
143 | what is your name
144 | who made you
145 | thank you
146 | what can i ask you
147 | do you have pets
148 | what are your hobbies
149 | are you a bot
150 | meaning of life
151 | who do you work for
152 | fun fact
153 |
154 | meta:
155 | cancel
156 | change accent
157 | change ai name
158 | change language
159 | change speed
160 | change user name
161 | change volume
162 | maybe
163 | no
164 | repeat
165 | reset settings
166 | sync device
167 | user name
168 | whisper mode
169 | yes
170 |
--------------------------------------------------------------------------------
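Note: this file is the input that `getDomainIntent` in the showDataset.py scripts parses: a line containing `:` opens a domain block, the lines that follow are that domain's intents, and a blank line closes the block.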
/data/oos/original/oos-eval-master/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.json~
3 | *.csv~
4 | .DS_Store
5 |
--------------------------------------------------------------------------------
/data/oos/original/oos-eval-master/README.md:
--------------------------------------------------------------------------------
1 | [](https://clinc.com)
2 |
3 | # An Evaluation Dataset for Intent Classification and Out-of-Scope Prediction
4 | Repository that accompanies [An Evaluation Dataset for Intent Classification and Out-of-Scope Prediction](https://www.aclweb.org/anthology/D19-1131/).
5 |
6 |
7 | ## FAQs
8 | ### 1. What are the relevant files?
9 | See `data/data_full.json` for the "full" dataset. This is the dataset used in Table 1 (the "Full" columns). This file contains 150 "in-scope" intent classes, each with 100 train, 20 validation, and 30 test samples. There are 100 train and validation out-of-scope samples, and 1000 out-of-scope test samples.
10 |
11 | ### 2. What is the name of the dataset?
12 | The dataset was not given a name in the original paper, but [others](https://arxiv.org/pdf/2003.04807.pdf) have called it `CLINC150`.
13 |
14 | ### 3. What is this dataset for?
15 | This dataset is for evaluating the performance of intent classification systems in the presence of "out-of-scope" queries. By "out-of-scope", we mean queries that do not fall into any of the system-supported intent classes. Most datasets include only data that is "in-scope". Our dataset includes both in-scope and out-of-scope data. You might also know the term "out-of-scope" by other terms, including "out-of-domain" or "out-of-distribution".
16 |
17 | ### 4. What language is the dataset in?
18 | All queries are in English.
19 |
20 | ### 5. How does your dataset/evaluation handle multi-intent queries?
21 | All samples/queries in our dataset are single-intent samples. We consider the problem of multi-intent classification to be future work.
22 |
23 | ### 6. How did you gather the dataset?
24 | We used crowdsourcing to generate the dataset. We asked crowd workers to either paraphrase "seed" phrases, or respond to scenarios (e.g. "pretend you need to book a flight, what would you say?"). We used crowdsourcing to generate data for both in-scope and out-of-scope data.
25 |
26 | ## Citation
27 |
28 | If you find our dataset useful, please be sure to cite:
29 |
30 | ```
31 | @inproceedings{larson-etal-2019-evaluation,
32 | title = "An Evaluation Dataset for Intent Classification and Out-of-Scope Prediction",
33 | author = "Larson, Stefan and
34 | Mahendran, Anish and
35 | Peper, Joseph J. and
36 | Clarke, Christopher and
37 | Lee, Andrew and
38 | Hill, Parker and
39 | Kummerfeld, Jonathan K. and
40 | Leach, Kevin and
41 | Laurenzano, Michael A. and
42 | Tang, Lingjia and
43 | Mars, Jason",
44 | booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)",
45 | year = "2019",
46 | url = "https://www.aclweb.org/anthology/D19-1131"
47 | }
48 | ```
49 |
--------------------------------------------------------------------------------
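The README's FAQ 1 above describes `data/data_full.json` as 150 in-scope intents with fixed train/validation/test splits plus out-of-scope samples. A minimal loading sketch; the exact split keys (`train`, `val`, `test`, `oos_train`, `oos_val`, `oos_test`) are an assumption based on that description, hence the membership guard:

```python
import json

with open("data/data_full.json") as f:
    data = json.load(f)

# Each split is expected to be a list of [utterance, intent] pairs.
for split in ("train", "val", "test", "oos_train", "oos_val", "oos_test"):
    if split in data:
        utt, intent = data[split][0]
        print(f"{split}: {len(data[split])} samples, e.g. ({intent!r}) {utt!r}")
```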
/data/oos/original/oos-eval-master/clinc_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/oos/original/oos-eval-master/clinc_logo.png
--------------------------------------------------------------------------------
/data/oos/original/oos-eval-master/hyperparameters.csv:
--------------------------------------------------------------------------------
1 | ,,,,,,,,,,,
2 | SVM-OH Hyperparameters,,,,,,,,,,,
3 | ,Full,classifier,vectorizer,kernel,C,,,,,,
4 | ,,svm,onehot,linear,1,,,,,,
5 | ,,,,,,,,,,,
6 | ,Small,classifier,vectorizer,kernel,C,,,,,,
7 | ,,svm,onehot,linear,1,,,,,,
8 | ,,,,,,,,,,,
9 | ,Imbalanced,classifier,vectorizer,kernel,C,,,,,,
10 | ,,svm,onehot,linear,1,,,,,,
11 | ,,,,,,,,,,,
12 | ,OOS+,classifier,vectorizer,kernel,C,,,,,,
13 | ,,svm,onehot,linear,1,,,,,,
14 | ,,,,,,,,,,,
15 | ,UnderSample Binary,classifier,vectorizer,kernel,C,,,,,,
16 | ,,svm,onehot,linear,1,,,,,,
17 | ,,,,,,,,,,,
18 | ,Wiki Aug Binary,classifier,vectorizer,kernel,C,,,,,,
19 | ,,svm,onehot,linear,1,,,,,,
20 | ,,,,,,,,,,,
21 | ,,,,,,,,,,,
22 | FastText Hyperparameters,,,,,,,,,,,
23 | ,Full,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,,
24 | ,,fasttext,1,100,3,100,3,ns,,,
25 | ,,,,,,,,,,,
26 | ,Small,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,,
27 | ,,fasttext,0.1,200,1,200,7,softmax,,,
28 | ,,,,,,,,,,,
29 | ,Imbalanced,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,,
30 | ,,fasttext,1,400,2,200,7,ns,,,
31 | ,,,,,,,,,,,
32 | ,OOS+,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,,
33 | ,,fasttext,1,50,2,200,7,ns,,,
34 | ,,,,,,,,,,,
35 | ,UnderSample Binary,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,,
36 | ,,fasttext,0.1,200,4,100,3,ns,,,
37 | ,,,,,,,,,,,
38 | ,Wiki Aug Binary,classifier,learning_rate,dim,word_ngrams,lr_update_rate,ws,loss,,,
39 | ,,fasttext,1,400,4,100,3,ns,,,
40 | ,,,,,,,,,,,
41 | ,,,,,,,,,,,
42 | CNN Hyperparameters,,,,,,,,,,,
43 | ,Full,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size
44 | ,,cnn,200,softmax,relu,1,valid,300,0.6,softmax,16
45 | ,,,,,,,,,,,
46 | ,Small,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size
47 | ,,cnn,150,softmax,relu,1,valid,200,0.6,softmax,16
48 | ,,,,,,,,,,,
49 | ,Imbalanced,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size
50 | ,,cnn,150,softmax,relu,1,valid,300,0.6,softmax,16
51 | ,,,,,,,,,,,
52 | ,OOS+,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size
53 | ,,cnn,200,softmax,relu,1,valid,200,0.6,softmax,16
54 | ,,,,,,,,,,,
55 | ,UnderSample Binary,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size
56 | ,,cnn,150,softmax,relu,1,valid,100,0.2,softmax,16
57 | ,,,,,,,,,,,
58 | ,Wiki Aug Binary,classifier,num_filters,conv_activation,f_activation,strides,padding,dense_dim,dropout,s_activation,batch_size
59 | ,,cnn,200,softmax,relu,1,valid,300,0.2,softmax,16
60 | ,,,,,,,,,,,
61 | ,,,,,,,,,,,
62 | MLP Hyperparameters,,,,,,,,,,,
63 | ,Full,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
64 | ,,mlp,tanh,softmax,400,use,1,0,,,
65 | ,,,,,,,,,,,
66 | ,Small,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
67 | ,,mlp,tanh,softmax,200,use,1,0.1,,,
68 | ,,,,,,,,,,,
69 | ,Imbalanced,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
70 | ,,mlp,tanh,softmax,200,use,64,0,,,
71 | ,,,,,,,,,,,
72 | ,OOS+,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
73 | ,,mlp,tanh,softmax,200,use,16,0.1,,,
74 | ,,,,,,,,,,,
75 | ,UnderSample Binary,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
76 | ,,mlp,tanh,softmax,100,use,64,0,,,
77 | ,,,,,,,,,,,
78 | ,Wiki Aug Binary,classifier,f_hidden_activation,s_hidden_activation,hidden_dim,vectorizer,batch_size,dropout,,,
79 | ,,mlp,tanh,softmax,300,use,16,0,,,
80 | ,,,,,,,,,,,
81 | ,,,,,,,,,,,
82 | BERT Hyperparameters,,,,,,,,,,,
83 | ,Full,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
84 | ,,bert,4.00E-05,0.1,32,5,1,bert-large-uncased,,,
85 | ,,,,,,,,,,,
86 | ,Small,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
87 | ,,bert,4.00E-05,0.1,32,5,1,bert-large-uncased,,,
88 | ,,,,,,,,,,,
89 | ,Imbalanced,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
90 | ,,bert,4.00E-05,0.1,32,5,1,bert-large-uncased,,,
91 | ,,,,,,,,,,,
92 | ,OOS+,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
93 | ,,bert,4.00E-05,0.1,32,5,1,bert-large-uncased,,,
94 | ,,,,,,,,,,,
95 | ,UnderSample Binary,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
96 | ,,bert,2.00E-05,0.1,32,5,1,bert-large-uncased,,,
97 | ,,,,,,,,,,,
98 | ,Wiki Aug Binary,classifier,learning_rate,warmup_proportion,train_batch_size,num_train_epochs,gradient_accumulation_steps,bert_model,,,
99 | ,,bert,4.00E-05,0.1,32,5,1,bert-large-uncased,,,
--------------------------------------------------------------------------------
/data/oos/original/oos-eval-master/paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/oos/original/oos-eval-master/paper.pdf
--------------------------------------------------------------------------------
/data/oos/original/oos-eval-master/poster.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/oos/original/oos-eval-master/poster.pdf
--------------------------------------------------------------------------------
/data/oos/original/oos-eval-master/supplementary.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/data/oos/original/oos-eval-master/supplementary.pdf
--------------------------------------------------------------------------------
/data/oos/showDataset.py:
--------------------------------------------------------------------------------
1 | # this script loads dataset.json and prints summary statistics (domains, intents, counts, examples)
2 | from gensim.models.keyedvectors import KeyedVectors
3 | import numpy as np
4 | import scipy.io as sio
5 | import nltk
6 | import pdb
7 | import random
8 | import csv
9 | import string
10 | import contractions
11 | import json
12 |
13 |
14 | def getDomainIntent(domainLabFile):
15 | domain2lab = {}
16 | lab2domain = {}
17 | currentDomain = None
18 | with open(domainLabFile,'r') as f:
19 | for line in f:
20 | if ':' in line and currentDomain == None:
21 | currentDomain = cleanUpSentence(line)
22 | domain2lab[currentDomain] = []
23 | elif line == "\n":
24 | currentDomain = None
25 | else:
26 | intent = cleanUpSentence(line)
27 | domain2lab[currentDomain].append(intent)
28 |
29 | for key in domain2lab:
30 | domain = key
31 | labList = domain2lab[key]
32 | for lab in labList:
33 | lab2domain[lab] = domain
34 |
35 | return domain2lab, lab2domain
36 |
37 | def cleanUpSentence(sentence):
38 | # sentence: a string, like " Hello, do you like apple? I hate it!! "
39 |
40 | # strip
41 | sentence = sentence.strip()
42 |
43 | # lower case
44 | sentence = sentence.lower()
45 |
46 | # fix contractions
47 | sentence = contractions.fix(sentence)
48 |
49 | # remove '_' and '-'
50 | sentence = sentence.replace('-',' ')
51 | sentence = sentence.replace('_',' ')
52 |
53 | # remove all punctuations
54 | sentence = ''.join(ch for ch in sentence if ch not in string.punctuation)
55 |
56 | return sentence
57 |
58 |
59 | def check_data_format(file_path):
60 | for line in open(file_path,'rb'):
61 | arr =str(line.strip(),'utf-8')
62 | arr = arr.split('\t')
63 | label = [w for w in arr[0].split(' ')]
64 | question = [w for w in arr[1].split(' ')]
65 |
66 | if len(label) == 0 or len(question) == 0:
67 | print("[ERROR] Find empty data: ", label, question)
68 | return False
69 |
70 | return True
71 |
72 |
73 | def save_data(data, file_path):
74 | # save data to disk
75 | with open(file_path, 'w') as f:
76 | json.dump(data, f)
77 | return
78 |
79 | def save_domain_intent(data, file_path):
80 | domain2intent = {}
81 | for line in data:
82 | domain = line[0]
83 | intent = line[1]
84 |
85 | if not domain in domain2intent:
86 | domain2intent[domain] = set()
87 |
88 | domain2intent[domain].add(intent)
89 |
90 | # save data to disk
91 | print("Saving domain intent out ... format: domain \t intent")
92 | with open(file_path,"w") as f:
93 | for domain in domain2intent:
94 | intentSet = domain2intent[domain]
95 | for intent in intentSet:
96 | f.write("%s\t%s\n" % (domain, intent))
97 | return
98 |
99 | def display_data(data):
100 | # dataset count
101 |     print("[INFO] We have %d datasets."%(len(data)))
102 |
103 | datasetName = 'CLINC150'
104 | data = data[datasetName]
105 |
106 | # domain count
107 | domainName = set()
108 | for domain in data:
109 | domainName.add(domain)
110 | print("[INFO] There are %d domains."%(len(domainName)))
111 | print(domainName)
112 |
113 | # intent count
114 | intentName = set()
115 | for domain in data:
116 | for d in data[domain]:
117 | lab = d[1][0]
118 | intentName.add(lab)
119 | intentName = list(intentName)
120 | intentName.sort()
121 |     print("[INFO] There are %d intents."%(len(intentName)))
122 | print(intentName)
123 |
124 | # data count
125 | count = 0
126 | for domain in data:
127 | for d in data[domain]:
128 | count = count+1
129 | print("[INFO] Data count: %d"%(count))
130 |
131 | # intent for each domain
132 | domain2intentDict = {}
133 | for domain in data:
134 | if not domain in domain2intentDict:
135 | domain2intentDict[domain] = set()
136 |
137 | for d in data[domain]:
138 | lab = d[1][0]
139 | domain2intentDict[domain].add(lab)
140 | print("[INFO] Intent for each domain.")
141 | print(domain2intentDict)
142 |
143 | # data for each intent
144 | intent2count = {}
145 | for domain in data:
146 | for d in data[domain]:
147 | lab = d[1][0]
148 | if not lab in intent2count:
149 | intent2count[lab] = 0
150 | intent2count[lab] = intent2count[lab]+1
151 | print("[INFO] Intent count")
152 | print(intent2count)
153 |
154 | # examples of data
155 | exampleNum = 3
156 | while not exampleNum == 0:
157 | for domain in data:
158 | for d in data[domain]:
159 | lab = d[1]
160 | utt = d[0]
161 | if random.random() < 0.001:
162 | print("[INFO] Example:--%s, %s, %s, %s"%(datasetName, domain, lab, utt))
163 | exampleNum = exampleNum-1
164 | break
165 | if (exampleNum==0):
166 | break
167 |
168 | return None
169 |
170 |
171 | ##
172 | # @brief clean up data, including intent and utterance
173 | #
174 | # @param data a list of data
175 | #
176 | # @return
177 | def cleanData(data):
178 | newData = []
179 | for d in data:
180 | utt = d[0]
181 | lab = d[1]
182 |
183 | uttClr = cleanUpSentence(utt)
184 | labClr = cleanUpSentence(lab)
185 | newData.append([labClr, uttClr])
186 |
187 | return newData
188 |
189 | def constructData(data, intent2domain):
190 | dataset2domain = {}
191 | datasetName = 'CLINC150'
192 | dataset2domain[datasetName] = {}
193 | for d in data:
194 | lab = d[0]
195 | utt = d[1]
196 | domain = intent2domain[lab]
197 | if not domain in dataset2domain[datasetName]:
198 | dataset2domain[datasetName][domain] = []
199 | dataField = [utt, [lab]]
200 | dataset2domain[datasetName][domain].append(dataField)
201 |
202 | return dataset2domain
203 |
204 |
205 | def read_data(file_path):
206 | with open(file_path) as json_file:
207 | data = json.load(json_file)
208 | return data
209 |
210 | # read in data
211 | #dataPath = "/data1/haode/projects/EMDIntentFewShot/SPIN_refactor/data/refactor_OOS/dataset.json"
212 | dataPath = "./dataset.json"
213 | print("Loading data ...", dataPath)
214 | # read lines, collect data count for different classes
215 | data = read_data(dataPath)
216 |
217 | display_data(data)
218 | print("Display.. done")
219 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # This script evaluates few-shot intent classification performance of a pre-trained or saved model.
2 | # Episodes are sampled and scored by utils/Evaluator.FewShotEvaluator.
3 | # coding=utf-8
4 | import torch
5 | import argparse
6 | import time
7 | from transformers import AutoTokenizer
8 |
9 | from utils.models import IntentBERT
10 | from utils.IntentDataset import IntentDataset
11 | from utils.Evaluator import FewShotEvaluator
12 | from utils.commonVar import *
13 | from utils.printHelper import *
14 | from utils.tools import *
15 | from utils.Logger import logger
16 |
17 | def set_seed(seed):
18 | random.seed(seed)
19 | np.random.seed(seed)
20 | torch.manual_seed(seed)
21 |
22 | def parseArgument():
23 | # ==== parse argument ====
24 | parser = argparse.ArgumentParser(description='Evaluate few-shot performance')
25 |
26 | # ==== model ====
27 | parser.add_argument('--seed', default=1, type=int)
28 | parser.add_argument('--mode', default='multi-class',
29 | help='Choose from multi-class')
30 | parser.add_argument('--tokenizer', default='bert-base-uncased',
31 | help="Name of tokenizer")
32 | parser.add_argument('--LMName', default='bert-base-uncased',
33 | help='Name for models and path to saved model')
34 | parser.add_argument('--multi_label', action="store_true")
35 |
36 | # ==== dataset ====
37 | parser.add_argument('--dataDir',
38 | help="Dataset names included in this experiment and separated by comma. "
39 | "For example:'OOS,bank77,hwu64'")
40 | parser.add_argument('--targetDomain',
41 | help='Target domain names and separated by comma')
42 |
43 | # ==== evaluation task ====
44 | parser.add_argument('--way', type=int, default=5)
45 | parser.add_argument('--shot', type=int, default=2)
46 | parser.add_argument('--query', type=int, default=5)
47 | parser.add_argument('--clsFierName', default='Linear',
48 |                         help="Classifier name for few-shot evaluation. "
49 |                         "Choose from Linear|SVM|NN|Cosine|MultiLabel")
50 |
51 | # ==== training arguments ====
52 | parser.add_argument('--disableCuda', action="store_true")
53 | parser.add_argument('--taskNum', type=int, default=500)
54 |
55 | # ==== other things ====
56 | parser.add_argument('--loggingLevel', default='INFO',
57 | help="python logging level")
58 |
59 | args = parser.parse_args()
60 |
61 | return args
62 |
63 | def main():
64 | # ======= process arguments ======
65 | args = parseArgument()
66 | print(args)
67 |
68 | if args.multi_label:
69 | args.clsFierName = "MultiLabel"
70 |
71 | # ==== setup logger ====
72 | if args.loggingLevel == LOGGING_LEVEL_INFO:
73 | loggingLevel = logging.INFO
74 | elif args.loggingLevel == LOGGING_LEVEL_DEBUG:
75 | loggingLevel = logging.DEBUG
76 | else:
77 |         raise NotImplementedError("Unsupported logging level %s" % args.loggingLevel)
78 | logger.setLevel(loggingLevel)
79 |
80 | # ======= process data ======
81 | # tokenizer
82 | tok = AutoTokenizer.from_pretrained(args.tokenizer)
83 | # load raw dataset
84 | logger.info(f"Loading data from {args.dataDir}")
85 | dataset = IntentDataset(multi_label=args.multi_label)
86 | dataset.loadDataset(splitName(args.dataDir))
87 | dataset.tokenize(tok)
88 | logger.info("----- Testing Data -----")
89 | testData = dataset.splitDomain(splitName(args.targetDomain), multi_label=args.multi_label)
90 |
91 | # ======= prepare model ======
92 | # initialize model
93 | modelConfig = {}
94 | modelConfig['device'] = torch.device('cuda:0' if not args.disableCuda else 'cpu')
95 | modelConfig['clsNumber'] = args.shot
96 | modelConfig['LMName'] = args.LMName
97 | model = IntentBERT(modelConfig)
98 | logger.info("----- IntentBERT initialized -----")
99 |
100 | # setup evaluator
101 | valParam = {"evalTaskNum": args.taskNum, "clsFierName": args.clsFierName, 'multi_label':args.multi_label}
102 | valTaskParam = {"way":args.way, "shot":args.shot, "query":args.query}
103 | tester = FewShotEvaluator(valParam, valTaskParam, testData)
104 |
105 | # set up model
106 | logger.info("Evaluating model ...")
107 |     # run few-shot evaluation on the target domains
108 | tester.evaluate(model, tok, args.mode, logLevel='INFO')
109 | # print config
110 | logger.info(args)
111 | logger.info(time.asctime())
112 |
113 | if __name__ == "__main__":
114 | main()
115 | exit(0)
116 |
--------------------------------------------------------------------------------
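A usage note: `scripts/eval.sh` later in this dump wires up this script; stripped of its logging boilerplate, a typical direct invocation (assuming a model is available under the `--LMName` value, e.g. the `intent-bert-base-uncased` name that script uses) is `python eval.py --dataDir mcid --targetDomain MEDICAL --shot 2 --LMName intent-bert-base-uncased`.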
/images/combined.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/images/combined.png
--------------------------------------------------------------------------------
/images/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/images/main.png
--------------------------------------------------------------------------------
/mlm.py:
--------------------------------------------------------------------------------
1 | # This script further pre-trains a language model with a masked-language-modelling (MLM) loss only.
2 | # The training loop lives in utils/Trainer.MLMOnlyTrainer.
3 | # coding=utf-8
4 | import os
5 | import torch
6 | import torch.optim as optim
7 | import argparse
8 | import time
9 | from transformers import AutoTokenizer
10 |
11 | from utils.models import IntentBERT
12 | from utils.IntentDataset import IntentDataset
13 | from utils.Trainer import MLMOnlyTrainer
14 | from utils.Evaluator import FewShotEvaluator
15 | from utils.commonVar import *
16 | from utils.printHelper import *
17 | from utils.tools import *
18 | from utils.Logger import logger
19 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
20 |
21 | def set_seed(seed):
22 | random.seed(seed)
23 | np.random.seed(seed)
24 | torch.manual_seed(seed)
25 |
26 | def parseArgument():
27 | # ==== parse argument ====
28 | parser = argparse.ArgumentParser(description='Train with MLM loss')
29 |
30 | # ==== model ====
31 | parser.add_argument('--seed', default=1, type=int)
32 | parser.add_argument('--mode', default='multi-class',
33 | help='Choose from multi-class')
34 | parser.add_argument('--tokenizer', default='bert-base-uncased',
35 | help="Name of tokenizer")
36 | parser.add_argument('--LMName', default='bert-base-uncased',
37 | help='Name for models and path to saved model')
38 |
39 | # ==== dataset ====
40 | parser.add_argument('--dataDir',
41 | help="Dataset names included in this experiment and separated by comma. "
42 | "For example:'OOS,bank77,hwu64'")
43 | parser.add_argument('--targetDomain',
44 | help='Target domain names and separated by comma')
45 |
46 | # ==== evaluation task ====
47 | parser.add_argument('--way', type=int, default=5)
48 | parser.add_argument('--shot', type=int, default=2)
49 | parser.add_argument('--query', type=int, default=5)
50 | parser.add_argument('--clsFierName', default='Linear',
51 |                         help="Classifier name for few-shot evaluation. "
52 |                         "Choose from Linear|SVM|NN|Cosine")
53 |
54 | # ==== optimizer ====
55 | parser.add_argument('--optimizer', default='Adam',
56 | help='Choose from SGD|Adam')
57 | parser.add_argument('--learningRate', type=float, default=2e-5)
58 | parser.add_argument('--weightDecay', type=float, default=0)
59 |
60 | # ==== training arguments ====
61 | parser.add_argument('--disableCuda', action="store_true")
62 | parser.add_argument('--epochs', type=int, default=50)
63 | parser.add_argument('--batch_size', type=int, default=32)
64 | parser.add_argument('--taskNum', type=int, default=500)
65 |
66 | # ==== other things ====
67 | parser.add_argument('--loggingLevel', default='INFO',
68 | help="python logging level")
69 | parser.add_argument('--saveModel', action='store_true',
70 | help="Whether to save pretrained model")
71 | parser.add_argument('--saveName', default='none',
72 |                         help="Specify a unique name to save your model. "
73 |                         "If none, a name derived from how the model is trained is used")
74 | parser.add_argument('--tensorboard', action='store_true',
75 | help="Enable tensorboard to log training and validation accuracy")
76 |
77 | args = parser.parse_args()
78 |
79 | return args
80 |
81 | def main():
82 | # ======= process arguments ======
83 | args = parseArgument()
84 | print(args)
85 |
86 | if not args.saveModel:
87 | logger.info("The model will not be saved after training!")
88 |
89 | # ==== setup logger ====
90 | if args.loggingLevel == LOGGING_LEVEL_INFO:
91 | loggingLevel = logging.INFO
92 | elif args.loggingLevel == LOGGING_LEVEL_DEBUG:
93 | loggingLevel = logging.DEBUG
94 | else:
95 |         raise NotImplementedError("Unsupported logging level %s" % args.loggingLevel)
96 | logger.setLevel(loggingLevel)
97 |
98 | # ======= process data ======
99 | # tokenizer
100 | tok = AutoTokenizer.from_pretrained(args.tokenizer)
101 | # load raw dataset
102 | logger.info(f"Loading data from {args.dataDir}")
103 | dataset = IntentDataset()
104 | dataset.loadDataset(splitName(args.dataDir))
105 | dataset.tokenize(tok)
106 |     # split out the target-domain (testing) data
107 | logger.info("----- Testing Data -----")
108 | testData = dataset.splitDomain(splitName(args.targetDomain))
109 |
110 | # ======= prepare model ======
111 | # initialize model
112 | modelConfig = {}
113 | modelConfig['device'] = torch.device('cuda:0' if not args.disableCuda else 'cpu')
114 | modelConfig['clsNumber'] = testData.getLabNum()
115 | modelConfig['LMName'] = args.LMName
116 | model = IntentBERT(modelConfig)
117 | logger.info("----- IntentBERT initialized -----")
118 |
119 | # setup validator
120 | valParam = {"evalTaskNum": args.taskNum, "clsFierName": args.clsFierName, "multi_label": False}
121 | valTaskParam = {"way":args.way, "shot":args.shot, "query":args.query}
122 | tester = FewShotEvaluator(valParam, valTaskParam, testData)
123 |
124 | # setup trainer
125 | optimizer = None
126 | if args.optimizer == OPTER_ADAM:
127 | optimizer = optim.Adam(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay)
128 | elif args.optimizer == OPTER_SGD:
129 | optimizer = optim.SGD(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay)
130 | else:
131 | raise NotImplementedError("Not supported optimizer %s"%(args.optimizer))
132 |
133 | trainingParam = {"epoch" : args.epochs, \
134 | "batch" : args.batch_size, \
135 | "tensorboard": args.tensorboard}
136 | trainer = MLMOnlyTrainer(trainingParam, optimizer, testData, testData, tester)
137 |
138 | # train
139 | trainer.train(model, tok)
140 |
141 | # evaluate once more to show results
142 | tester.evaluate(model, tok, args.mode, logLevel='INFO')
143 |
144 | # save model into disk
145 | if args.saveModel:
146 | if args.saveName == 'none':
147 | prefix = "MLMOnly"
148 | save_path = os.path.join(SAVE_PATH, f'{prefix}_{args.targetDomain}')
149 | else:
150 | save_path = os.path.join(SAVE_PATH, args.saveName)
151 | logger.info("Saving model.pth into folder: %s", save_path)
152 | if not os.path.exists(save_path):
153 | os.mkdir(save_path)
154 | model.save(save_path)
155 |
156 | # print config
157 | logger.info(args)
158 | logger.info(time.asctime())
159 |
160 | if __name__ == "__main__":
161 | main()
162 | exit(0)
163 |
--------------------------------------------------------------------------------
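mlm.py delegates the actual training loop to `utils/Trainer.MLMOnlyTrainer` (not shown in this section). As a rough sketch of standard BERT-style masking — an assumption about the recipe, simplified to always substitute `[MASK]` rather than the usual 80/10/10 replacement split — 15% of non-special tokens are selected as prediction targets:

```python
import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tok(["how do i freeze my account"], return_tensors="pt")

input_ids = batch["input_ids"].clone()
labels = input_ids.clone()

# Mark [CLS]/[SEP] so they are never masked.
special = torch.tensor(
    tok.get_special_tokens_mask(input_ids[0].tolist(), already_has_special_tokens=True),
    dtype=torch.bool,
)
prob = torch.full(input_ids.shape, 0.15)
prob[0, special] = 0.0
masked = torch.bernoulli(prob).bool()

labels[~masked] = -100                 # cross-entropy is computed on masked positions only
input_ids[masked] = tok.mask_token_id  # simplified: always replace with [MASK]
```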
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo usage:
3 | echo scriptName.sh : run in normal mode
4 | echo scriptName.sh debug : run in debug mode
5 |
6 | # hardware
7 | cudaID=$2
8 |
9 | # debug mode
10 | if [[ $# != 0 ]] && [[ $1 == "debug" ]]
11 | then
12 | debug=true
13 | else
14 | debug=false
15 | fi
16 |
17 | seed=1
18 |
19 | # dataset
20 | # dataDir='bank77'
21 | # targetDomain="BANKING"
22 | # dataDir=mcid
23 | # targetDomain="MEDICAL"
24 | dataDir=hint3
25 | targetDomain='curekart,powerplay11,sofmattress'
26 |
27 | # setting
28 | shot=2
29 |
30 | # model initialization
31 | LMName=intent-bert-base-uncased
32 | # LMName=joint-intent-bert-base-uncased-bank77
33 | # LMName=joint-intent-bert-base-uncased-mcid
34 | # LMName=joint-intent-bert-base-uncased-hint3
35 |
36 | # modify arguments if it's debug mode
37 | RED='\033[0;31m'
38 | GRN='\033[0;32m'
39 | NC='\033[0m' # No Color
40 | if $debug
41 | then
42 | echo -e "Run in ${RED} debug ${NC} mode."
43 | epochs=1
44 | else
45 | echo -e "Run in ${GRN} normal ${NC} mode."
46 | fi
47 |
48 | echo "Start Experiment ..."
49 | logFolder=./log/
50 | mkdir -p ${logFolder}
51 | logFile=${logFolder}/eval_${dataDir}_${shot}shot_LMName${LMName}.log
52 | if $debug
53 | then
54 |     logFile=${logFolder}/logDebug.log
55 | fi
56 |
57 | export CUDA_VISIBLE_DEVICES=${cudaID}
58 | python eval.py \
59 | --seed ${seed} \
60 | --targetDomain ${targetDomain} \
61 | --dataDir ${dataDir} \
62 | --shot ${shot} \
63 | --LMName ${LMName} \
64 | | tee "${logFile}"
65 | echo "Experiment finished."
66 |
--------------------------------------------------------------------------------
/scripts/mlm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo usage:
3 | echo scriptName.sh : run in normal mode
4 | echo scriptName.sh debug : run in debug mode
5 |
6 | # hardware
7 | cudaID=$2
8 |
9 | # debug mode
10 | if [[ $# != 0 ]] && [[ $1 == "debug" ]]
11 | then
12 | debug=true
13 | else
14 | debug=false
15 | fi
16 |
17 | seed=1
18 |
19 | # dataset
20 | dataDir='bank77'
21 | # dataDir='mcid'
22 | # dataDir='hint3'
23 | targetDomain="BANKING"
24 | # targetDomain='MEDICAL'
25 | # targetDomain='curekart,powerplay11,sofmattress'
26 |
27 | # setting
28 | shot=2
29 |
30 | # training
31 | tensorboard=
32 | saveModel=--saveModel
33 | saveName=none
34 |
35 | # model setting
36 | # common
37 | LMName=bert-base-uncased
38 |
39 | # modify arguments if it's debug mode
40 | RED='\033[0;31m'
41 | GRN='\033[0;32m'
42 | NC='\033[0m' # No Color
43 | if $debug
44 | then
45 | echo -e "Run in ${RED} debug ${NC} mode."
46 | # validationTaskNum=10
47 | # testTaskNum=10
48 | epochs=1
49 | # repeatNum=1
50 | else
51 | echo -e "Run in ${GRN} normal ${NC} mode."
52 | fi
53 |
54 | echo "Start Experiment ..."
55 | logFolder=./log/
56 | mkdir -p ${logFolder}
57 | logFile=${logFolder}/mlm_${dataDir}_${targetDomain}_${shot}shot.log
58 | if $debug
59 | then
60 |     logFile=${logFolder}/logDebug.log
61 | fi
62 |
63 | export CUDA_VISIBLE_DEVICES=${cudaID}
64 | python mlm.py \
65 | --seed ${seed} \
66 | --targetDomain ${targetDomain} \
67 | ${tensorboard} \
68 | --dataDir ${dataDir} \
69 | --shot ${shot} \
70 | ${saveModel} \
71 | --LMName ${LMName} \
72 | --saveName ${saveName} \
73 | | tee "${logFile}"
74 | echo "Experiment finished."
75 |
--------------------------------------------------------------------------------
/scripts/transfer.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo usage:
3 | echo scriptName.sh : run in normal mode
4 | echo scriptName.sh debug : run in debug mode
5 |
6 | # hardware
7 | cudaID=$2
8 |
9 | # debug mode
10 | if [[ $# != 0 ]] && [[ $1 == "debug" ]]
11 | then
12 | debug=true
13 | else
14 | debug=false
15 | fi
16 |
17 | # seed=4321
18 | seed=1
19 |
20 | # dataset
21 | dataDir='oos,hwu64,bank77'
22 | # dataDir="oos,hwu64,mcid"
23 | # dataDir='oos,hwu64,hint3'
24 |
25 | sourceDomain="utility,auto_commute,kitchen_dining,work,home,meta,small_talk,travel"
26 | # valDomain="alarm,audio,iot,calendar,play,datetime,takeaway,news,music,weather,qa,social,recommendation,cooking,email,transport,lists"
27 | valDomain="iot,play,qa"
28 |
29 | # below is only for evaluation
30 | targetDomain="BANKING"
31 | # targetDomain="MEDICAL"
32 | # targetDomain="curekart,powerplay11,sofmattress"
33 |
34 | # setting
35 | shot=2
36 |
37 | # training
38 | tensorboard=
39 | saveModel=--saveModel
40 | saveName=none
41 | validation=--validation
42 | mlm=
43 | # learningRate=1e-5
44 | learningRate=5e-6
45 |
46 | # model setting
47 | # common
48 | LMName=bert-base-uncased
49 |
50 | # modify arguments if it's debug mode
51 | RED='\033[0;31m'
52 | GRN='\033[0;32m'
53 | NC='\033[0m' # No Color
54 | if $debug
55 | then
56 | echo -e "Run in ${RED} debug ${NC} mode."
57 | epochs=1
58 | else
59 | echo -e "Run in ${GRN} normal ${NC} mode."
60 | fi
61 |
62 | echo "Start Experiment ..."
63 | logFolder=./log/
64 | mkdir -p ${logFolder}
65 | logFile=${logFolder}/transfer_${sourceDomain}_to_${targetDomain}_${shot}shot.log
66 | # logFile=${logFolder}/transfer_mlm_${sourceDomainName}_to_${targetDomainName}_${way}way_${shot}shot.log
67 | if $debug
68 | then
69 |     logFile=${logFolder}/logDebug.log
70 | fi
71 |
72 | export CUDA_VISIBLE_DEVICES=${cudaID}
73 | python transfer.py \
74 | --seed ${seed} \
75 | --valDomain ${valDomain} \
76 | --sourceDomain ${sourceDomain} \
77 | --targetDomain ${targetDomain} \
78 | ${tensorboard} \
79 | --dataDir ${dataDir} \
80 | --shot ${shot} \
81 | ${saveModel} \
82 | ${validation} \
83 | ${mlm} \
84 | --learningRate ${learningRate} \
85 | --LMName ${LMName} \
86 | --saveName ${saveName} \
87 | | tee "${logFile}"
88 | echo "Experiment finished."
89 |
--------------------------------------------------------------------------------
/transfer.py:
--------------------------------------------------------------------------------
1 | # This script pre-trains IntentBERT on source-domain intent classification, optionally with an auxiliary MLM loss,
2 | # and evaluates few-shot transfer to the target domains.
3 | # coding=utf-8
4 | import os
5 | import torch
6 | import torch.optim as optim
7 | import argparse
8 | import time
9 | import copy
10 | from transformers import AutoTokenizer
11 | import random
12 |
13 | from utils.models import IntentBERT
14 | from utils.IntentDataset import IntentDataset
15 | from utils.Trainer import TransferTrainer
16 | from utils.Evaluator import FewShotEvaluator
17 | from utils.commonVar import *
18 | from utils.printHelper import *
19 | from utils.tools import *
20 | from utils.Logger import logger
21 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
22 |
23 | def set_seed(seed):
24 | random.seed(seed)
25 | np.random.seed(seed)
26 | torch.manual_seed(seed)
27 |
28 | def parseArgument():
29 | # ==== parse argument ====
30 | parser = argparse.ArgumentParser(description='Train IntentBERT')
31 |
32 | # ==== model ====
33 | parser.add_argument('--seed', default=1, type=int)
34 | parser.add_argument('--mode', default='multi-class',
35 | help='Choose from multi-class')
36 | parser.add_argument('--tokenizer', default='bert-base-uncased',
37 | help="Name of tokenizer")
38 | parser.add_argument('--LMName', default='bert-base-uncased',
39 | help='Name for models and path to saved model')
40 |
41 | # ==== dataset ====
42 | parser.add_argument('--dataDir',
43 | help="Dataset names included in this experiment and separated by comma. "
44 | "For example:'OOS,bank77,hwu64'")
45 | parser.add_argument('--sourceDomain',
46 | help="Source domain names and separated by comma. "
47 | "For example:'travel,banking,home'")
48 | parser.add_argument('--valDomain',
49 | help='Validation domain names and separated by comma')
50 | parser.add_argument('--targetDomain',
51 | help='Target domain names and separated by comma')
52 |
53 | # ==== evaluation task ====
54 | parser.add_argument('--way', type=int, default=5)
55 | parser.add_argument('--shot', type=int, default=2)
56 | parser.add_argument('--query', type=int, default=5)
57 | parser.add_argument('--clsFierName', default='Linear',
58 |                         help="Classifier name for few-shot evaluation. "
59 |                         "Choose from Linear|SVM|NN|Cosine|MultiLabel")
60 |
61 | # ==== optimizer ====
62 | parser.add_argument('--optimizer', default='Adam',
63 | help='Choose from SGD|Adam')
64 | parser.add_argument('--learningRate', type=float, default=2e-5)
65 | parser.add_argument('--weightDecay', type=float, default=0)
66 |
67 | # ==== training arguments ====
68 | parser.add_argument('--disableCuda', action="store_true")
69 | parser.add_argument('--validation', action="store_true")
70 | parser.add_argument('--epochs', type=int, default=10)
71 | parser.add_argument('--batch_size', type=int, default=32)
72 | parser.add_argument('--taskNum', type=int, default=500)
73 | parser.add_argument('--patience', type=int, default=3,
74 | help="Early stop when performance does not go better")
75 | parser.add_argument('--mlm', action='store_true',
76 | help="If use mlm as auxiliary loss")
77 | parser.add_argument('--lambda_mlm', type=float, default=1.0,
78 | help="The weight for mlm loss")
79 | parser.add_argument('--mlm_data', default='target', type=str,
80 | help="Data for mlm. Choose from target|source")
81 | parser.add_argument('--shuffle_mlm', action="store_true")
82 | parser.add_argument('--shuffle', action="store_true")
83 | parser.add_argument('--regression', action="store_true",
84 | help="If the pretrain task is a regression task")
85 |
86 | # ==== other things ====
87 | parser.add_argument('--loggingLevel', default='INFO',
88 | help="python logging level")
89 | parser.add_argument('--saveModel', action='store_true',
90 | help="Whether to save pretrained model")
91 | parser.add_argument('--saveName', default='none',
92 |                         help="Specify a unique name to save your model. "
93 |                         "If none, a name derived from how the model is trained is used")
94 | parser.add_argument('--tensorboard', action='store_true',
95 | help="Enable tensorboard to log training and validation accuracy")
96 |
97 | args = parser.parse_args()
98 |
99 | return args
100 |
101 | def main():
102 | # ======= process arguments ======
103 | args = parseArgument()
104 | print(args)
105 |
106 | if not args.saveModel:
107 | logger.info("The model will not be saved after training!")
108 |
109 | # ==== setup logger ====
110 | if args.loggingLevel == LOGGING_LEVEL_INFO:
111 | loggingLevel = logging.INFO
112 | elif args.loggingLevel == LOGGING_LEVEL_DEBUG:
113 | loggingLevel = logging.DEBUG
114 | else:
115 |         raise NotImplementedError("Unsupported logging level %s" % args.loggingLevel)
116 | logger.setLevel(loggingLevel)
117 |
118 | # ==== set seed ====
119 | set_seed(args.seed)
120 |
121 | # ======= process data ======
122 | # tokenizer
123 | tok = AutoTokenizer.from_pretrained(args.tokenizer)
124 | # load raw dataset
125 | logger.info(f"Loading data from {args.dataDir}")
126 | dataset = IntentDataset(regression=args.regression)
127 | dataset.loadDataset(splitName(args.dataDir))
128 | dataset.tokenize(tok)
129 |     # split data into training, validation and testing
130 | logger.info("----- Training Data -----")
131 | trainData = dataset.splitDomain(splitName(args.sourceDomain), regression=args.regression)
132 | logger.info("----- Validation Data -----")
133 | valData = dataset.splitDomain(splitName(args.valDomain))
134 | logger.info("----- Testing Data -----")
135 | testData = dataset.splitDomain(splitName(args.targetDomain))
136 | # shuffle word order
137 | if args.shuffle:
138 | trainData.shuffle_words()
139 |
140 | # ======= prepare model ======
141 | # initialize model
142 | modelConfig = {}
143 | modelConfig['device'] = torch.device('cuda:0' if not args.disableCuda else 'cpu')
144 | if args.regression:
145 | modelConfig['clsNumber'] = 1
146 | else:
147 | modelConfig['clsNumber'] = trainData.getLabNum()
148 | modelConfig['LMName'] = args.LMName
149 | model = IntentBERT(modelConfig)
150 | logger.info("----- IntentBERT initialized -----")
151 |
152 | # setup validator
153 | valParam = {"evalTaskNum": args.taskNum, "clsFierName": args.clsFierName, "multi_label": False}
154 | valTaskParam = {"way":args.way, "shot":args.shot, "query":args.query}
155 | validator = FewShotEvaluator(valParam, valTaskParam, valData)
156 | tester = FewShotEvaluator(valParam, valTaskParam, testData)
157 |
158 | # setup trainer
159 | optimizer = None
160 | if args.optimizer == OPTER_ADAM:
161 | optimizer = optim.Adam(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay)
162 | elif args.optimizer == OPTER_SGD:
163 | optimizer = optim.SGD(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay)
164 | else:
165 | raise NotImplementedError("Not supported optimizer %s"%(args.optimizer))
166 |
167 | if args.mlm and args.mlm_data == "target":
168 | args.validation = False
169 | trainingParam = {"epoch" : args.epochs, \
170 | "batch" : args.batch_size, \
171 | "validation" : args.validation, \
172 | "patience" : args.patience, \
173 | "tensorboard": args.tensorboard, \
174 | "mlm" : args.mlm, \
175 | "lambda mlm" : args.lambda_mlm, \
176 | "regression" : args.regression}
177 | unlabeledData = None
178 | if args.mlm_data == "source":
179 | unlabeledData = copy.deepcopy(trainData)
180 | elif args.mlm_data == "target":
181 | unlabeledData = copy.deepcopy(testData)
182 | if args.shuffle_mlm:
183 | unlabeledData.shuffle_words()
184 | trainer = TransferTrainer(trainingParam, optimizer, trainData, unlabeledData, validator, tester)
185 |
186 | # train
187 | trainer.train(model, tok, args.mode)
188 |
189 | # load best model
190 | bestModelStateDict = trainer.getBestModelStateDict()
191 | model.load_state_dict(bestModelStateDict)
192 |
193 | # evaluate once more to show results
194 | tester.evaluate(model, tok, args.mode, logLevel='INFO')
195 |
196 | # save model into disk
197 | if args.saveModel:
198 | # decide the save name
199 | if args.saveName == 'none':
200 | prefix = "STMLM" if args.mlm else "ST"
201 | if args.mlm:
202 | if args.mlm_data == 'target':
203 | prefix += f"_{args.targetDomain}"
204 | elif args.mlm_data == 'source':
205 | prefix += "_source"
206 | save_path = os.path.join(SAVE_PATH, f'{prefix}_{args.mode}_{args.sourceDomain}')
207 | if args.shuffle:
208 | save_path += "_shuffle"
209 | if args.shuffle_mlm:
210 | save_path += "_shuffle_mlm"
211 | else:
212 | save_path = os.path.join(SAVE_PATH, args.saveName)
213 | # save
214 | logger.info("Saving model.pth into folder: %s", save_path)
215 | if not os.path.exists(save_path):
216 | os.mkdir(save_path)
217 | model.save(save_path)
218 |
219 | # print config
220 | logger.info(args)
221 | logger.info(time.asctime())
222 |
223 | if __name__ == "__main__":
224 | main()
225 | exit(0)
226 |
--------------------------------------------------------------------------------
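When `--mlm` is set, `TransferTrainer` (in `utils/Trainer.py`, not shown in this section) receives `lambda mlm` as the weight for the MLM loss. Schematically — a sketch of the presumed combination suggested by the `--lambda_mlm` help string, not a quote of the trainer code — the objective is cross-entropy plus the weighted auxiliary term:

```python
def joint_loss(ce_loss: float, mlm_loss: float, lambda_mlm: float = 1.0) -> float:
    """Supervised cross-entropy plus a weighted MLM auxiliary loss (schematic)."""
    return ce_loss + lambda_mlm * mlm_loss

# e.g. with the default --lambda_mlm 1.0:
print(joint_loss(0.42, 1.70))  # 2.12
```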
/utils/Evaluator.py:
--------------------------------------------------------------------------------
1 | from utils.IntentDataset import IntentDataset
2 | from utils.TaskSampler import MultiLabTaskSampler, UniformTaskSampler
3 | from utils.tools import makeEvalExamples
4 | from utils.printHelper import *
5 | from utils.Logger import logger
6 | from utils.commonVar import *
7 | import logging
8 | import torch
9 | import numpy as np
10 | from tqdm import tqdm
11 |
12 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support
13 |
14 | ##
15 | # @brief base class of evaluator
16 | class EvaluatorBase():
17 | def __init__(self):
18 | self.roundN = 4
19 | pass
20 |
21 | def round(self, floatNum):
22 | return round(floatNum, self.roundN)
23 |
24 | def evaluate(self):
25 |         raise NotImplementedError("evaluate() is not implemented.")
26 |
27 | ##
28 | # @brief FewShotEvaluator does meta evaluation: tasks are sampled and the model is evaluated task by task.
29 | class FewShotEvaluator(EvaluatorBase):
30 | def __init__(self, evalParam, taskParam, dataset: IntentDataset):
31 | super(FewShotEvaluator, self).__init__()
32 | self.way = taskParam['way']
33 | self.shot = taskParam['shot']
34 | self.query = taskParam['query']
35 |
36 | self.dataset = dataset
37 |
38 | self.multi_label = evalParam['multi_label']
39 | self.clsFierName = evalParam['clsFierName']
40 | self.evalTaskNum = evalParam['evalTaskNum']
41 | logger.info("In evaluator classifier %s is used.", self.clsFierName)
42 |
43 | if self.multi_label:
44 | self.taskSampler = MultiLabTaskSampler(self.dataset, self.shot, self.query)
45 | else:
46 | self.taskSampler = UniformTaskSampler(self.dataset, self.way, self.shot, self.query)
47 |
48 | def evaluate(self, model, tokenizer, mode='multi-class', logLevel='DEBUG'):
49 | model.eval()
50 |
51 | performList = [] # acc, pre, rec, fsc
52 | with torch.no_grad():
53 |             for _ in range(self.evalTaskNum):
54 | # sample a task
55 | task = self.taskSampler.sampleOneTask()
56 |
57 | # collect data
58 | supportX = task[META_TASK_SHOT_TOKEN]
59 | queryX = task[META_TASK_QUERY_TOKEN]
60 | if mode == 'multi-class':
61 | supportY = task[META_TASK_SHOT_LOC_LABID]
62 | queryY = task[META_TASK_QUERY_LOC_LABID]
63 | else:
64 |                     logger.error("Invalid mode %s"%(mode))
65 |
66 | # padding
67 | supportX, supportY, queryX, queryY =\
68 | makeEvalExamples(supportX, supportY, queryX, queryY, tokenizer, mode=mode)
69 |
70 | # forward
71 | queryPrediction = model.fewShotPredict(supportX.to(model.device),
72 | supportY,
73 | queryX.to(model.device),
74 | self.clsFierName,
75 | mode=mode)
76 |
77 | # calculate acc
78 | acc = accuracy_score(queryY, queryPrediction) # acc
79 | if self.multi_label:
80 | performDetail = precision_recall_fscore_support(queryY, queryPrediction, average='micro', warn_for=tuple())
81 | else:
82 | performDetail = precision_recall_fscore_support(queryY, queryPrediction, average='macro', warn_for=tuple())
83 |
84 | performList.append([acc, performDetail[0], performDetail[1], performDetail[2]])
85 |
86 | # performance mean and std
87 | performMean = np.mean(np.stack(performList, 0), 0)
88 | performStd = np.std(np.stack(performList, 0), 0)
89 |
90 | if logLevel == 'DEBUG':
91 | itemList = ["acc", "pre", "rec", "fsc"]
92 | logger.debug("Evaluate statistics: ")
93 | printMeanStd(performMean, performStd, itemList, debugLevel=logging.DEBUG)
94 | else:
95 | itemList = ["acc", "pre", "rec", "fsc"]
96 | logger.info("Evaluate statistics: ")
97 | printMeanStd(performMean, performStd, itemList, debugLevel=logging.INFO)
98 |
99 | # acc, pre, rec, F1
100 | return performMean[0], performMean[1], performMean[2], performMean[3]
101 |
102 |
103 | ##
104 | # @brief FineTuneEvaluator fine-tunes the model on each sampled task's support set before evaluating on its query set.
105 | class FineTuneEvaluator(EvaluatorBase):
106 | def __init__(self, evalParam, taskParam, optimizer, dataset: IntentDataset):
107 | super(FineTuneEvaluator, self).__init__()
108 | self.way = taskParam['way']
109 | self.shot = taskParam['shot']
110 | self.query = taskParam['query']
111 |
112 | self.dataset = dataset
113 | self.optimizer = optimizer
114 |
115 | self.finetuneSteps = evalParam['finetuneSteps']
116 | self.evalTaskNum = evalParam['evalTaskNum']
117 |
118 | self.taskSampler = UniformTaskSampler(self.dataset, self.way, self.shot, self.query)
119 |
120 | def evaluate(self, model, tokenizer, mode='multi-class', logLevel='DEBUG'):
121 | performList = [] # acc, pre, rec, fsc
122 | initial_model = model.state_dict().copy()
123 | initial_optim = self.optimizer.state_dict().copy()
124 |
125 |         for _ in tqdm(range(self.evalTaskNum)):
126 | # sample a task
127 | task = self.taskSampler.sampleOneTask()
128 |
129 | # collect data
130 | supportX = task[META_TASK_SHOT_TOKEN]
131 | queryX = task[META_TASK_QUERY_TOKEN]
132 | if mode == 'multi-class':
133 | supportY = task[META_TASK_SHOT_LOC_LABID]
134 | queryY = task[META_TASK_QUERY_LOC_LABID]
135 | else:
136 |                 logger.error("Invalid mode %s"%(mode))
137 |
138 | # padding
139 | supportX, supportY, queryX, queryY =\
140 | makeEvalExamples(supportX, supportY, queryX, queryY, tokenizer, mode=mode)
141 |
142 | # finetune
143 | model.train()
144 | for _ in range(self.finetuneSteps):
145 | logits = model(supportX.to(model.device))
146 | loss = model.loss_ce(logits, torch.tensor(supportY).to(model.device))
147 | self.optimizer.zero_grad()
148 | loss.backward()
149 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
150 | self.optimizer.step()
151 |
152 | model.eval()
153 | with torch.no_grad():
154 | if mode == 'multi-class':
155 | queryPrediction = model(queryX.to(model.device)).argmax(-1)
156 | else:
157 |                     logger.error("Invalid mode %s"%(mode))
158 |
159 | queryPrediction = queryPrediction.cpu().numpy()
160 |
161 | # calculate acc
162 | acc = accuracy_score(queryY, queryPrediction) # acc
163 | performDetail = precision_recall_fscore_support(queryY, queryPrediction, average='macro', warn_for=tuple())
164 |
165 | performList.append([acc, performDetail[0], performDetail[1], performDetail[2]])
166 |
167 | model.load_state_dict(initial_model)
168 | self.optimizer.load_state_dict(initial_optim)
169 |
170 | # performance mean and std
171 | performMean = np.mean(np.stack(performList, 0), 0)
172 | performStd = np.std(np.stack(performList, 0), 0)
173 |
174 | if logLevel == 'DEBUG':
175 | itemList = ["acc", "pre", "rec", "fsc"]
176 | logger.debug("Evaluate statistics: ")
177 | printMeanStd(performMean, performStd, itemList, debugLevel=logging.DEBUG)
178 | else:
179 | itemList = ["acc", "pre", "rec", "fsc"]
180 | logger.info("Evaluate statistics: ")
181 | printMeanStd(performMean, performStd, itemList, debugLevel=logging.INFO)
182 |
183 | # acc, pre, rec, F1
184 | return performMean[0], performMean[1], performMean[2], performMean[3]
185 |
--------------------------------------------------------------------------------
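FewShotEvaluator hands classification off to `model.fewShotPredict` with `clsFierName='Linear'` by default. A minimal stand-alone sketch of one such episode, assuming — since `fewShotPredict` lives in `utils/models.py`, outside this section — that "Linear" means fitting a logistic-regression head on fixed support embeddings:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
way, shot, query, dim = 5, 2, 5, 768            # the script defaults; dim matches bert-base

# Stand-ins for utterance embeddings; the real pipeline would use IntentBERT features.
support_x = rng.normal(size=(way * shot, dim))
support_y = np.repeat(np.arange(way), shot)
query_x = rng.normal(size=(way * query, dim))
query_y = np.repeat(np.arange(way), query)

clf = LogisticRegression(max_iter=1000).fit(support_x, support_y)
acc = (clf.predict(query_x) == query_y).mean()  # one episode; the evaluator averages over taskNum episodes
print(f"episode accuracy: {acc:.4f}")
```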
/utils/IntentDataset.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import nltk
3 | from nltk.corpus import stopwords
4 | nltk.download('stopwords')
5 |
6 | from utils.commonVar import *
7 | import json
8 | from utils.Logger import logger
9 | import os
10 | import copy
11 | import random
12 |
13 |
14 | class IntentDataset():
15 | def __init__(self,
16 | domList=None,
17 | labList=None,
18 | uttList=None,
19 | tokList=None,
20 | regression=False,
21 | multi_label=False):
22 | self.regression = regression
23 | self.multi_label = multi_label
24 |
25 | self.domList = [] if domList is None else domList
26 | self.labList = [] if labList is None else labList
27 | self.uttList = [] if uttList is None else uttList
28 | self.tokList = [] if tokList is None else tokList
29 |
30 | if (self.labList is not None) and (not self.regression):
31 | self.createLabID()
32 | if self.regression:
33 | self.labIDList = self.labList
34 | if not self.multi_label:
35 | self.convertLabs()
36 | self.labID2DataInd = None
37 | self.dataInd2LabID = None
38 |
39 | def getDomList(self):
40 | return self.domList
41 |
42 | def getLabList(self):
43 | return self.labList
44 |
45 | def getUttList(self):
46 | return self.uttList
47 |
48 | def getTokList(self):
49 | return self.tokList
50 |
51 | def getAllData(self):
52 | return self.domList, self.labList, self.uttList, self.tokList
53 |
54 | def getLabNum(self):
55 | labSet = set()
56 | for lab in self.labList:
57 | if self.multi_label:
58 | for l in lab:
59 | labSet.add(l)
60 | else:
61 | labSet.add(lab)
62 | return len(labSet)
63 |
64 | def getLabID(self):
65 | return self.labIDList
66 |
67 | def checkData(self, utt: str, label: str):
68 | if not self.regression:
69 | if len(label) == 0 or len(utt) == 0:
70 | logger.warning("Illegal label %s or utterance %s: zero length", label, utt)
71 | return 1
72 | return 0
73 |
74 | def loadDataset(self, dataDirList):
75 | self.domList, self.labList, self.uttList = [], [], []
76 | dataFilePathList = \
77 | [os.path.join(DATA_PATH, dataDir, FILE_NAME_DATASET) for dataDir in dataDirList]
78 |
79 | dataList = []
80 | for dataFilePath in dataFilePathList:
81 | with open(dataFilePath, 'r') as json_file:
82 | dataList.append(json.load(json_file))
83 |
84 | delDataNum = 0
85 | for data in dataList:
86 | for datasetName in data:
87 | dataset = data[datasetName]
88 | for domainName in dataset:
89 | domain = dataset[domainName]
90 | for dataItem in domain:
91 | utt = dataItem[0]
92 | labList = dataItem[1]
93 |
94 | if self.multi_label:
95 | lab = labList
96 | else:
97 | lab = labList[0]
98 |
99 | if self.checkData(utt, lab) != 0:
100 | # checkData has already logged the offending utterance and label
101 | delDataNum = delDataNum+1
102 | else:
103 | self.domList.append(domainName)
104 | self.labList.append(lab)
105 | self.uttList.append(utt)
106 |
107 | # report deleted data number
108 | if delDataNum > 0:
109 | logger.warning("%d examples were deleted from the dataset.", delDataNum)
110 |
111 | # sanity check
112 | countSet = set()
113 | countSet.add(len(self.domList))
114 | countSet.add(len(self.labList))
115 | countSet.add(len(self.uttList))
116 | if len(countSet) > 1:
117 | logger.error("Unaligned data list. Length of data list: dataset %d, domain %d, lab %d, utterance %d", len(self.domainList), len(self.labList), len(self.uttList))
118 | exit(1)
119 | if not self.regression:
120 | self.createLabID()
121 | else:
122 | self.labIDList = self.labList
123 | logger.info(f"{countSet} data collected")
124 | return 0
125 |
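
From the nested loops in `loadDataset`, each `dataset.json` is expected to map a dataset name to domains, and each domain to a list of `[utterance, [label, ...]]` pairs. A hypothetical minimal file, sketched as a Python literal (entry names invented for illustration):

```python
# contents of data/<dataDir>/dataset.json
{
    "bank77": {
        "banking": [
            ["how do I activate my card", ["activate_my_card"]],
            ["my card has not arrived yet", ["card_arrival"]]
        ]
    }
}
```

With `multi_label=False`, only the first label of each pair is kept; otherwise the whole label list is stored.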
126 | def removeStopWord(self):
127 | raise NotImplementedError
128 |
129 | # print info
130 | logger.info("Removing stop words ...")
131 | logger.info("Before removing stop words: data count is %d", len(self.uttList))
132 |
133 | # remove stop word
134 | stopwordsEnglish = stopwords.words('english')
135 | uttListNew = []
136 | labListNew = []
137 | delLabListNew = []
138 | delUttListNew = [] # Utt for utterance
139 | maxLen = -1
140 | for lab, utt in zip(self.labList, self.uttList):
141 | uttWordListNew = [w for w in utt.split(' ') if w not in stopwordsEnglish]
142 | uttNew = ' '.join(uttWordListNew)
143 |
144 | uttNewLen = len(uttWordListNew)
145 | if uttNewLen <= 0: # too short utterance, delete it from dataset
146 | delLabListNew.append(lab)
147 | delUttListNew.append(uttNew)
148 | else: # utt with normal length
149 | if uttNewLen > maxLen:
150 | maxLen = uttNewLen
151 | labListNew.append(lab)
152 | uttListNew.append(uttNew)
153 | self.labList = labListNew
154 | self.uttList = uttListNew
155 | self.delLabList = delLabListNew
156 | self.delUttList = delUttListNew
157 |
158 | # update data list
159 | logger.info("After removing stop words: data count is %d", len(self.uttList))
160 | logger.info("Removing stop words ... done.")
161 |
162 | return 0
163 |
164 | def splitDomain(self, domainName: list, regression=False, multi_label=False):
165 | domList = self.getDomList()
166 |
167 | # collect index
168 | indList = []
169 | for ind, domain in enumerate(domList):
170 | if domain in domainName:
171 | indList.append(ind)
172 |
173 | # sanity check
174 | dataCount = len(indList)
175 | if dataCount<1:
176 | logger.error("Empty data for domain %s", domainName)
177 | exit(1)
178 |
179 | logger.info("For domain %s, %d data is selected from %d data in the dataset.", domainName, dataCount, len(domList))
180 |
181 | # get all data from dataset
182 | domList, labList, uttList, tokList = self.getAllData()
183 | domDomList = [domList[i] for i in indList]
184 | domLabList = [labList[i] for i in indList]
185 | domUttList = [uttList[i] for i in indList]
186 | if self.tokList:
187 | domTokList = [tokList[i] for i in indList]
188 | else:
189 | domTokList = []
190 | domDataset = IntentDataset(domDomList, domLabList, domUttList, domTokList, regression=regression, multi_label=multi_label)
191 |
192 | return domDataset
193 |
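
A usage sketch for `splitDomain`, assuming a dataset has been loaded as above (the domain names passed in are hypothetical):

```python
dataset = IntentDataset()
dataset.loadDataset(['oos'])  # reads data/oos/dataset.json
testSet = dataset.splitDomain(['travel', 'banking'])
print(testSet.getLabNum())
```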
194 | def tokenize(self, tokenizer):
195 | self.tokList = []
196 | for u in self.uttList:
197 | ut = tokenizer(u)
198 | if 'token_type_ids' not in ut:
199 | ut['token_type_ids'] = [0]*len(ut['input_ids'])
200 | self.tokList.append(ut)
201 |
202 | def shuffle_words(self):
203 | newList = []
204 | for u in self.uttList:
205 | replace = copy.deepcopy(u)
206 | replace = replace.split(' ')
207 | random.shuffle(replace)
208 | replace = ' '.join(replace)
209 | newList.append(replace)
210 | self.uttList = newList
211 |
212 | # convert label names to label IDs: 0, 1, 2, 3
213 | def createLabID(self):
214 | # get unique label
215 | labSet = set()
216 | for lab in self.labList:
217 | if self.multi_label:
218 | for l in lab:
219 | labSet.add(l)
220 | else:
221 | labSet.add(lab)
222 |
223 | # get number
224 | self.labNum = len(labSet)
225 | sortedLabList = list(labSet)
226 | sortedLabList.sort()
227 |
228 | # fill up dict: lab -> labID
229 | self.name2LabID = {}
230 | for ind, lab in enumerate(sortedLabList):
231 | if not lab in self.name2LabID:
232 | self.name2LabID[lab] = ind
233 |
234 | # fill up label ID list
235 | self.labIDList =[]
236 | for lab in self.labList:
237 | if self.multi_label:
238 | labID = []
239 | for l in lab:
240 | labID.append(self.name2LabID[l])
241 | self.labIDList.append(labID)
242 | else:
243 | self.labIDList.append(self.name2LabID[lab])
244 |
245 | # sanity check
246 | if not len(self.labIDList) == len(self.uttList):
247 | logger.error("create labID error. Not consistence labe ID list length and utterance list length.")
248 | exit(1)
249 |
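
To make the mapping concrete: label names are deduplicated, sorted alphabetically, and indexed from 0. A small illustration (label names invented):

```python
labList = ["card_arrival", "activate_my_card", "card_arrival"]
sortedLabList = sorted(set(labList))                  # ['activate_my_card', 'card_arrival']
name2LabID = {lab: i for i, lab in enumerate(sortedLabList)}
labIDList = [name2LabID[lab] for lab in labList]
print(labIDList)                                      # [1, 0, 1]
```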
250 | def getLabID2dataInd(self):
251 | if self.labID2DataInd is not None:
252 | return self.labID2DataInd
253 | else:
254 | self.labID2DataInd = {}
255 | for dataInd, labID in enumerate(self.labIDList):
256 | if self.multi_label:
257 | for l in labID:
258 | if not l in self.labID2DataInd:
259 | self.labID2DataInd[l] = []
260 | self.labID2DataInd[l].append(dataInd)
261 | else:
262 | if not labID in self.labID2DataInd:
263 | self.labID2DataInd[labID] = []
264 | self.labID2DataInd[labID].append(dataInd)
265 |
266 | # sanity check
267 | if not self.multi_label:
268 | dataCount = 0
269 | for labID in self.labID2DataInd:
270 | dataCount = dataCount + len(self.labID2DataInd[labID])
271 | if not dataCount == len(self.uttList):
272 | logger.error("Inconsistent data count %d and %d when generating dict, labID2DataInd", dataCount, len(self.uttList))
273 | exit(1)
274 |
275 | return self.labID2DataInd
276 |
277 | def getDataInd2labID(self):
278 | if self.dataInd2LabID is not None:
279 | return self.dataInd2LabID
280 | else:
281 | self.dataInd2LabID = {}
282 | for dataInd, labID in enumerate(self.labIDList):
283 | self.dataInd2LabID[dataInd] = labID
284 | return self.dataInd2LabID
285 |
286 | def convertLabs(self):
287 | # when the dataset is not multi-label, convert labels from list to a single instance
288 | if self.labList:
289 | if isinstance(self.labList[0], list):
290 | newList = []
291 | for l in self.labList:
292 | newList.append(l[0])
293 | self.labList = newList
294 | if self.labIDList:
295 | if isinstance(self.labIDList[0], list):
296 | newList = []
297 | for l in self.labIDList:
298 | newList.append(l[0])
299 | self.labIDList = newList
--------------------------------------------------------------------------------
/utils/Logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 |
4 | logging.basicConfig(stream=sys.stdout, format='[%(levelname)s] %(message)s')
5 | logger = logging.getLogger('globalLogger')
6 |
--------------------------------------------------------------------------------
/utils/TaskSampler.py:
--------------------------------------------------------------------------------
1 | from utils.IntentDataset import IntentDataset
2 | from utils.Logger import logger
3 | import random
4 | from utils.commonVar import *
5 | import numpy as np
6 | import copy
7 | from sklearn.preprocessing import MultiLabelBinarizer
8 | # random.seed(0)
9 | # base class for task samplers
10 | # sample meta-task from a dataset for training and evaluation
11 | class TaskSampler():
12 | def __init__(self, dataset:IntentDataset):
13 | self.dataset = dataset
14 |
15 | def sampleOneTask(self):
16 | raise NotImplementedError("sampleOneTask() is not implemented.")
17 |
18 | class UniformTaskSampler(TaskSampler):
19 | def __init__(self, dataset:IntentDataset, way, shot, query):
20 | super(UniformTaskSampler, self).__init__(dataset)
21 | self.way = way
22 | self.shot = shot
23 | self.query = query
24 | self.taskPool = None
25 | self.dataset = dataset
26 |
27 | ##
28 | # @brief sample data index for a task. Class global IDs are also sampled.
29 | #
30 | # @return a dict, print it to see what's there
31 | def sampleClassIDsDataInd(self):
32 | taskInfo = {}
33 | glbLabNum = self.dataset.getLabNum()
34 | labID2DataInd = self.dataset.getLabID2dataInd()
35 |
36 | uniqueGlbLabIDs = list(range(glbLabNum))
37 | # random sample global label IDs
38 | taskGlbLabIDs = random.sample(uniqueGlbLabIDs, self.way)
39 | taskInfo[META_TASK_GLB_LABID] = taskGlbLabIDs
40 |
41 | # sample data for each labID
42 | taskInfo[META_TASK_SHOT_GLB_LABID] = []
43 | taskInfo[META_TASK_QUERY_GLB_LABID] = []
44 | taskInfo[META_TASK_SHOT_DATAIND] = []
45 | taskInfo[META_TASK_QUERY_DATAIND] = []
46 | for labID in taskGlbLabIDs:
47 | # random sample support data and query data
48 | dataInds = random.sample(labID2DataInd[labID], self.shot+self.query)
49 | random.shuffle(dataInds)
50 | shotDataInds = dataInds[:self.shot]
51 | queryDataInds = dataInds[(-self.query):]
52 |
53 | taskInfo[META_TASK_SHOT_GLB_LABID].extend([labID]*(self.shot))
54 | taskInfo[META_TASK_SHOT_DATAIND].extend(shotDataInds)
55 | taskInfo[META_TASK_QUERY_GLB_LABID].extend([labID]*(self.query))
56 | taskInfo[META_TASK_QUERY_DATAIND].extend(queryDataInds)
57 |
58 | return taskInfo
59 |
60 |
61 | ##
62 | # @brief it works with sampleClassIDsDataInd(), taking a taskInfo returned by sampleClassIDsDataInd(), then return data in the task.
63 | #
64 | # @param taskDataInds a dict containing data index
65 | #
66 | # @return a dict containing task data, print it to see what's there
67 | def collectDataForTask(self, taskDataInds):
68 | task = {}
69 |
70 | # compose local labID from glbLabID
71 | glbLabIDList = taskDataInds[META_TASK_GLB_LABID]
72 | glbLabID2LocLabID = {}
73 | for pos, glbLabID in enumerate(glbLabIDList):
74 | glbLabID2LocLabID[glbLabID] = pos
75 | tokList = self.dataset.getTokList()
76 | labList = self.dataset.getLabList()
77 |
78 | # support
79 | task[META_TASK_SHOT_LOC_LABID] = [glbLabID2LocLabID[glbLabID] for glbLabID in taskDataInds[META_TASK_SHOT_GLB_LABID]]
80 | task[META_TASK_SHOT_TOKEN] = [tokList[i] for i in taskDataInds[META_TASK_SHOT_DATAIND]]
81 | task[META_TASK_SHOT_LAB] = [labList[i] for i in taskDataInds[META_TASK_SHOT_DATAIND]]
82 |
83 | # query
84 | task[META_TASK_QUERY_LOC_LABID] = [glbLabID2LocLabID[glbLabID] for glbLabID in taskDataInds[META_TASK_QUERY_GLB_LABID]]
85 | task[META_TASK_QUERY_TOKEN] = [tokList[i] for i in taskDataInds[META_TASK_QUERY_DATAIND]]
86 | task[META_TASK_QUERY_LAB] = [labList[i] for i in taskDataInds[META_TASK_QUERY_DATAIND]]
87 |
88 | return task
89 |
90 | def sampleOneTask(self):
91 | # 1. sample classes and data index
92 | taskDataInds = self.sampleClassIDsDataInd()
93 |
94 | # 2. according to the data index, select data such as tokens, lengths, label names, etc.
95 | task = self.collectDataForTask(taskDataInds)
96 |
97 | return task
98 |
99 |
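
A usage sketch for 5-way 1-shot sampling with `UniformTaskSampler`, assuming `dataset` is a tokenized `IntentDataset` with at least 5 classes and at least `shot + query` examples per class:

```python
sampler = UniformTaskSampler(dataset, way=5, shot=1, query=5)
task = sampler.sampleOneTask()
print(task[META_TASK_SHOT_LOC_LABID])    # 5 local label IDs, one per support shot
print(len(task[META_TASK_QUERY_TOKEN]))  # 25 query examples (5 per class)
```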
100 | class MultiLabTaskSampler(TaskSampler):
101 | def __init__(self, dataset:IntentDataset, shot, query):
102 | super(MultiLabTaskSampler, self).__init__(dataset)
103 | self.shot = shot
104 | self.query = query
105 | self.taskPool = None
106 | self.dataset = dataset
107 |
108 | ##
109 | # @brief sample data index for a task. Class global IDs are also sampled.
110 | #
111 | # @return a dict, print it to see what's there
112 | def sampleClassIDsDataInd(self):
113 | taskInfo = {}
114 | glbLabNum = self.dataset.getLabNum()
115 | labID2DataInd = self.dataset.getLabID2dataInd()
116 |
117 | uniqueGlbLabIDs = list(range(glbLabNum))
118 | # random sample global label IDs
119 | taskInfo[META_TASK_GLB_LABID] = uniqueGlbLabIDs
120 |
121 | # sample data for each labID
122 | taskInfo[META_TASK_SHOT_GLB_LABID] = []
123 | taskInfo[META_TASK_QUERY_GLB_LABID] = []
124 | taskInfo[META_TASK_SHOT_DATAIND] = []
125 | taskInfo[META_TASK_QUERY_DATAIND] = []
126 | shotDataInds, queryDataInds = [], []
127 | for labID in uniqueGlbLabIDs:
128 | # random sample support data
129 | dataInds = random.sample(labID2DataInd[labID], self.shot)
130 | shotDataInds += dataInds
131 | # random sample query data
132 | remain = list(set(labID2DataInd[labID])-set(dataInds))
133 | dataInds = random.sample(remain, self.query)
134 | queryDataInds += dataInds
135 | shotDataInds = list(set(shotDataInds))
136 | queryDataInds = list(set(queryDataInds))
137 | shotDataInds = self.checkForDuplicate(shotDataInds, required_num=self.shot)
138 | queryDataInds = self.checkForDuplicate(queryDataInds, required_num=self.query)
139 |
140 | taskInfo[META_TASK_SHOT_DATAIND] = shotDataInds
141 | taskInfo[META_TASK_QUERY_DATAIND] = queryDataInds
142 |
143 | dataInd2LabID = self.dataset.getDataInd2labID()
144 | for d in shotDataInds:
145 | taskInfo[META_TASK_SHOT_GLB_LABID].append(dataInd2LabID[d])
146 | for d in queryDataInds:
147 | taskInfo[META_TASK_QUERY_GLB_LABID].append(dataInd2LabID[d])
148 | return taskInfo
149 |
150 | def checkForDuplicate(self, dataInds, required_num):
151 | dataInd2LabID = self.dataset.getDataInd2labID()
152 | label_lists = []
153 | for di in dataInds:
154 | label_lists.extend(dataInd2LabID[di])
155 | label_names, counts = np.unique(label_lists, return_counts=True)
156 | shot_counts = {ln: c for ln, c in zip(label_names, counts)}
157 | loopInds = copy.deepcopy(dataInds)
158 | for di in loopInds:
159 | can_remove = True
160 | for l in dataInd2LabID[di]:
161 | if (l in shot_counts) and (shot_counts[l] - 1 < required_num):
162 | can_remove = False
163 | if can_remove:
164 | dataInds.remove(di)
165 | for l in dataInd2LabID[di]:
166 | shot_counts[l] -= 1
167 | return dataInds
168 |
169 | ##
170 | # @brief it works with sampleClassIDsDataInd(), taking a taskInfo returned by sampleClassIDsDataInd(), then return data in the task.
171 | #
172 | # @param taskDataInds a dict containing data index
173 | #
174 | # @return a dict containing task data, print it to see what's there
175 | def collectDataForTask(self, taskDataInds):
176 | task = {}
177 |
178 | # compose local labID from glbLabID
179 | tokList = self.dataset.getTokList()
180 | labList = self.dataset.getLabList()
181 |
182 | mlb = MultiLabelBinarizer()
183 |
184 | # support
185 | task[META_TASK_SHOT_LOC_LABID] = mlb.fit_transform(taskDataInds[META_TASK_SHOT_GLB_LABID]).tolist()
186 | task[META_TASK_SHOT_TOKEN] = [tokList[i] for i in taskDataInds[META_TASK_SHOT_DATAIND]]
187 | task[META_TASK_SHOT_LAB] = [labList[i] for i in taskDataInds[META_TASK_SHOT_DATAIND]]
188 |
189 | # query
190 | task[META_TASK_QUERY_LOC_LABID] = mlb.fit_transform(taskDataInds[META_TASK_QUERY_GLB_LABID]).tolist()
191 | task[META_TASK_QUERY_TOKEN] = [tokList[i] for i in taskDataInds[META_TASK_QUERY_DATAIND]]
192 | task[META_TASK_QUERY_LAB] = [labList[i] for i in taskDataInds[META_TASK_QUERY_DATAIND]]
193 |
194 | return task
195 |
196 | def sampleOneTask(self):
197 | # 1. sample classes and data index
198 | taskDataInds = self.sampleClassIDsDataInd()
199 |
200 | # 2. according to the data index, select data such as tokens, lengths, label names, etc.
201 | task = self.collectDataForTask(taskDataInds)
202 |
203 | return task
--------------------------------------------------------------------------------
/utils/Trainer.py:
--------------------------------------------------------------------------------
1 | from utils.IntentDataset import IntentDataset
2 | from utils.Evaluator import EvaluatorBase
3 | from utils.Logger import logger
4 | from utils.commonVar import *
5 | from utils.tools import mask_tokens, makeTrainExamples
6 | import time
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import numpy as np
10 | import copy
11 | from sklearn.metrics import accuracy_score, r2_score
12 | from torch.utils.tensorboard import SummaryWriter
13 |
14 | ##
15 | # @brief base class of trainer
16 | class TrainerBase():
17 | def __init__(self):
18 | self.finished=False
19 | self.bestModelStateDict = None
20 | self.roundN = 4
21 | pass
22 |
23 | def round(self, floatNum):
24 | return round(floatNum, self.roundN)
25 |
26 | def train(self):
27 | raise NotImplementedError("train() is not implemented.")
28 |
29 | def getBestModelStateDict(self):
30 | return self.bestModelStateDict
31 |
32 | ##
33 | # @brief TransferTrainer performs transfer training in a supervised manner: all available labeled data is used for training. By contrast, meta-training proceeds task by task.
34 | class TransferTrainer(TrainerBase):
35 | def __init__(self,
36 | trainingParam:dict,
37 | optimizer,
38 | dataset:IntentDataset,
39 | unlabeled:IntentDataset,
40 | valEvaluator: EvaluatorBase,
41 | testEvaluator:EvaluatorBase):
42 | super(TransferTrainer, self).__init__()
43 | self.epoch = trainingParam['epoch']
44 | self.batch_size = trainingParam['batch']
45 | self.validation = trainingParam['validation']
46 | self.patience = trainingParam['patience']
47 | self.tensorboard = trainingParam['tensorboard']
48 | self.mlm = trainingParam['mlm']
49 | self.lambda_mlm = trainingParam['lambda mlm']
50 | self.regression = trainingParam['regression']
51 |
52 | self.dataset = dataset
53 | self.unlabeled = unlabeled
54 | self.optimizer = optimizer
55 | self.valEvaluator = valEvaluator
56 | self.testEvaluator = testEvaluator
57 |
58 | if self.tensorboard:
59 | self.writer = SummaryWriter()
60 |
61 | def train(self, model, tokenizer, mode='multi-class'):
62 | self.bestModelStateDict = copy.deepcopy(model.state_dict())
63 | durationOverallTrain = 0.0
64 | durationOverallVal = 0.0
65 | valBestAcc = -1
66 | accumulateStep = 0
67 |
68 | # evaluate before training
69 | valAcc, valPre, valRec, valFsc = self.valEvaluator.evaluate(model, tokenizer, mode)
70 | teAcc, tePre, teRec, teFsc = self.testEvaluator.evaluate(model, tokenizer, mode)
71 | logger.info('---- Before training ----')
72 | logger.info("ValAcc %f, Val pre %f, Val rec %f , Val Fsc %f", valAcc, valPre, valRec, valFsc)
73 | logger.info("TestAcc %f, Test pre %f, Test rec %f, Test Fsc %f", teAcc, tePre, teRec, teFsc)
74 |
75 | if mode == 'multi-class':
76 | labTensorData = makeTrainExamples(self.dataset.getTokList(), tokenizer, self.dataset.getLabID(), mode=mode)
77 | else:
78 | logger.error("Invalid model %d"%(mode))
79 | dataloader = DataLoader(labTensorData, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)
80 |
81 | if self.mlm:
82 | unlabTensorData = makeTrainExamples(self.unlabeled.getTokList(), tokenizer, mode='unlabel')
83 | unlabeledloader = DataLoader(unlabTensorData, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)
84 | unlabelediter = iter(unlabeledloader)
85 |
86 | for epoch in range(self.epoch): # one full pass over the labeled dataloader
87 | model.train()
88 | batchTrAccSum = 0.0
89 | batchTrLossSPSum = 0.0
90 | batchTrLossMLMSum = 0.0
91 | timeEpochStart = time.time()
92 |
93 | for batch in dataloader:
94 | # task data
95 | Y, ids, types, masks = batch
96 | X = {'input_ids':ids.to(model.device),
97 | 'token_type_ids':types.to(model.device),
98 | 'attention_mask':masks.to(model.device)}
99 |
100 | # forward
101 | logits = model(X)
102 | # loss
103 | if self.regression:
104 | lossSP = model.loss_mse(logits, Y.to(model.device))
105 | else:
106 | lossSP = model.loss_ce(logits, Y.to(model.device))
107 |
108 | if self.mlm:
109 | try:
110 | ids, types, masks = next(unlabelediter)
111 | except StopIteration:
112 | unlabelediter = iter(unlabeledloader)
113 | ids, types, masks = next(unlabelediter)
114 | X_un = {'input_ids':ids.to(model.device),
115 | 'token_type_ids':types.to(model.device),
116 | 'attention_mask':masks.to(model.device)}
117 | mask_ids, mask_lb = mask_tokens(X_un['input_ids'].cpu(), tokenizer)
118 | X_un = {'input_ids':mask_ids.to(model.device),
119 | 'token_type_ids':X_un['token_type_ids'],
120 | 'attention_mask':X_un['attention_mask']}
121 | lossMLM = model.mlmForward(X_un, mask_lb.to(model.device))
122 | lossTOT = lossSP + self.lambda_mlm * lossMLM
123 | else:
124 | lossTOT = lossSP
125 |
126 | # backward
127 | self.optimizer.zero_grad()
128 | lossTOT.backward()
129 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
130 | self.optimizer.step()
131 |
132 | # calculate train acc
133 | YTensor = Y.cpu()
134 | logits = logits.detach().clone()
135 | if torch.cuda.is_available():
136 | logits = logits.cpu()
137 | if self.regression:
138 | predictResult = torch.sigmoid(logits).numpy()
139 | acc = r2_score(YTensor, predictResult)
140 | else:
141 | logits = logits.numpy()
142 | predictResult = np.argmax(logits, 1)
143 | acc = accuracy_score(YTensor, predictResult)
144 |
145 | # accumulate statistics
146 | batchTrAccSum += acc
147 | batchTrLossSPSum += lossSP.item()
148 | if self.mlm:
149 | batchTrLossMLMSum += lossMLM.item()
150 |
151 | # current epoch training done, collect data
152 | durationTrain = self.round(time.time() - timeEpochStart)
153 | durationOverallTrain += durationTrain
154 | batchTrAccAvrg = self.round(batchTrAccSum/len(dataloader))
155 | batchTrLossSPAvrg = batchTrLossSPSum/len(dataloader)
156 | batchTrLossMLMAvrg = batchTrLossMLMSum/len(dataloader)
157 |
158 | valAcc, valPre, valRec, valFsc = self.valEvaluator.evaluate(model, tokenizer, mode)
159 | teAcc, tePre, teRec, teFsc = self.testEvaluator.evaluate(model, tokenizer, mode)
160 |
161 | # display current epoch's info
162 | logger.info("---- epoch: %d/%d, train_time %f ----", epoch, self.epoch, durationTrain)
163 | logger.info("SPLoss %f, MLMLoss %f, TrainAcc %f", batchTrLossSPAvrg, batchTrLossMLMAvrg, batchTrAccAvrg)
164 | logger.info("ValAcc %f, Val pre %f, Val rec %f , Val Fsc %f", valAcc, valPre, valRec, valFsc)
165 | logger.info("TestAcc %f, Test pre %f, Test rec %f, Test Fsc %f", teAcc, tePre, teRec, teFsc)
166 | if self.tensorboard:
167 | self.writer.add_scalar('train loss', batchTrLossSPAvrg+self.lambda_mlm*batchTrLossMLMAvrg, global_step=epoch)
168 | self.writer.add_scalar('val acc', valAcc, global_step=epoch)
169 | self.writer.add_scalar('test acc', teAcc, global_step=epoch)
170 |
171 | # early stop
172 | if not self.validation:
173 | valAcc = -1
174 | if (valAcc >= valBestAcc): # better validation result
175 | print("[INFO] Find a better model. Val acc: %f -> %f"%(valBestAcc, valAcc))
176 | valBestAcc = valAcc
177 | accumulateStep = 0
178 |
179 | # cache current model, used for evaluation later
180 | self.bestModelStateDict = copy.deepcopy(model.state_dict())
181 | else:
182 | accumulateStep += 1
183 | if accumulateStep > self.patience/2:
184 | print('[INFO] accumulateStep: ', accumulateStep)
185 | if accumulateStep == self.patience: # early stop
186 | logger.info('Early stop.')
187 | logger.debug("Overall training time %f", durationOverallTrain)
188 | logger.debug("Overall validation time %f", durationOverallVal)
189 | logger.debug("best_val_acc: %f", valBestAcc)
190 | break
191 |
192 | logger.info("best_val_acc: %f", valBestAcc)
193 |
194 |
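
A configuration sketch for `TransferTrainer`, using only the keys read in its `__init__`; the values are illustrative, and the model, tokenizer, optimizer, datasets, and evaluators are assumed to be constructed elsewhere (e.g. in transfer.py):

```python
trainingParam = {
    'epoch': 10,
    'batch': 32,
    'validation': True,
    'patience': 3,
    'tensorboard': False,
    'mlm': True,         # add the MLM loss on unlabeled data
    'lambda mlm': 1.0,   # weight of the MLM term in lossTOT
    'regression': False,
}
trainer = TransferTrainer(trainingParam, optimizer, dataset, unlabeled,
                          valEvaluator, testEvaluator)
trainer.train(model, tokenizer, mode='multi-class')
```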
195 | ##
196 | # @brief MLMOnlyTrainer trains with the masked-language-modeling objective only; no supervised intent-classification loss is used.
197 | class MLMOnlyTrainer(TrainerBase):
198 | def __init__(self,
199 | trainingParam:dict,
200 | optimizer,
201 | dataset:IntentDataset,
202 | unlabeled:IntentDataset,
203 | testEvaluator:EvaluatorBase):
204 | super(MLMOnlyTrainer, self).__init__()
205 | self.epoch = trainingParam['epoch']
206 | self.batch_size = trainingParam['batch']
207 | self.tensorboard = trainingParam['tensorboard']
208 |
209 | self.dataset = dataset
210 | self.unlabeled = unlabeled
211 | self.optimizer = optimizer
212 | self.testEvaluator = testEvaluator
213 |
214 | if self.tensorboard:
215 | self.writer = SummaryWriter()
216 |
217 | def train(self, model, tokenizer):
218 | durationOverallTrain = 0.0
219 |
220 | # evaluate before training
221 | teAcc, tePre, teRec, teFsc = self.testEvaluator.evaluate(model, tokenizer, 'multi-class')
222 | logger.info('---- Before training ----')
223 | logger.info("TestAcc %f, Test pre %f, Test rec %f, Test Fsc %f", teAcc, tePre, teRec, teFsc)
224 |
225 | labTensorData = makeTrainExamples(self.dataset.getTokList(), tokenizer, mode='unlabel')
226 | dataloader = DataLoader(labTensorData, batch_size=self.batch_size, shuffle=True, num_workers=4, pin_memory=True)
227 |
228 | for epoch in range(self.epoch): # one full pass over the dataloader
229 | model.train()
230 | batchTrLossSum = 0.0
231 | timeEpochStart = time.time()
232 |
233 | for batch in dataloader:
234 | # task data
235 | ids, types, masks = batch
236 | X = {'input_ids':ids.to(model.device),
237 | 'token_type_ids':types.to(model.device),
238 | 'attention_mask':masks.to(model.device)}
239 |
240 | # forward
241 | mask_ids, mask_lb = mask_tokens(X['input_ids'].cpu(), tokenizer)
242 | X = {'input_ids':mask_ids.to(model.device),
243 | 'token_type_ids':X['token_type_ids'],
244 | 'attention_mask':X['attention_mask']}
245 | loss = model.mlmForward(X, mask_lb.to(model.device))
246 |
247 | # backward
248 | self.optimizer.zero_grad()
249 | loss.backward()
250 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
251 | self.optimizer.step()
252 | batchTrLossSum += loss.item() # per-batch accumulation; without it the epoch average below is always zero
253 | durationTrain = self.round(time.time() - timeEpochStart)
254 | durationOverallTrain += durationTrain
255 | batchTrLossAvrg = batchTrLossSum/len(dataloader)
256 |
257 | teAcc, tePre, teRec, teFsc = self.testEvaluator.evaluate(model, tokenizer, 'multi-class')
258 |
259 | # display current epoch's info
260 | logger.info("---- epoch: %d/%d, train_time %f ----", epoch, self.epoch, durationTrain)
261 | logger.info("TrainLoss %f", batchTrLossAvrg)
262 | logger.info("TestAcc %f, Test pre %f, Test rec %f, Test Fsc %f", teAcc, tePre, teRec, teFsc)
263 | if self.tensorboard:
264 | self.writer.add_scalar('train loss', batchTrLossAvrg, global_step=epoch)
265 | self.writer.add_scalar('test acc', teAcc, global_step=epoch)
266 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fanolabs/IntentBert/833ffdd16f004a8f5500d19b59a2bdf4ccd23674/utils/__init__.py
--------------------------------------------------------------------------------
/utils/commonVar.py:
--------------------------------------------------------------------------------
1 | # embedding name
2 | LM_NAME_FASTTEXT="fasttext"
3 | LM_NAME_BERT_BASE_UNCASED='bert-base-uncased'
4 |
5 | # debug level
6 | LOGGING_LEVEL_CRITICAL="CRITICAL"
7 | LOGGING_LEVEL_ERROR="ERROR"
8 | LOGGING_LEVEL_WARNING="WARNING"
9 | LOGGING_LEVEL_INFO="INFO"
10 | LOGGING_LEVEL_DEBUG="DEBUG"
11 | LOGGING_LEVEL_NOTSET="NOTSET"
12 |
13 | # dir and file name
14 | FILE_NAME_DATASET = "dataset.json"
15 |
16 | # meta-task information
17 | META_TASK_GLB_LABID = "META_TASK_GLB_LABID"
18 | META_TASK_SHOT_GLB_LABID = "META_TASK_SHOT_GLB_LABID"
19 | META_TASK_SHOT_LOC_LABID = "META_TASK_SHOT_LOC_LABID"
20 | META_TASK_SHOT_DATAIND = "META_TASK_SHOT_DATAIND"
21 | META_TASK_SHOT_TOKEN = "META_TASK_SHOT_TOKEN"
22 | META_TASK_SHOT_LAB = "META_TASK_SHOT_LAB"
23 | META_TASK_QUERY_GLB_LABID = "META_TASK_QUERY_GLB_LABID"
24 | META_TASK_QUERY_LOC_LABID = "META_TASK_QUERY_LOC_LABID"
25 | META_TASK_QUERY_DATAIND = "META_TASK_QUERY_DATAIND"
26 | META_TASK_QUERY_TOKEN = "META_TASK_QUERY_TOKEN"
27 | META_TASK_QUERY_LAB = "META_TASK_QUERY_LAB"
28 |
29 | # classifier name for validation and evaluation, meta-evaluation
30 | CLSFIER_LINEAR_REGRESSION = "Linear"
31 | CLSFIER_SVM = "SVM"
32 | CLSFIER_NN = "NN"
33 | CLSFIER_COSINE = "Cosine"
34 | CLSFIER_MULTI_LABEL = "MultiLabel"
35 |
36 | # optimizer name
37 | OPTER_ADAM = "Adam"
38 | OPTER_SGD = "SGD"
39 |
40 | # path
41 | SAVE_PATH = './saved_models'
42 | DATA_PATH = './data'
--------------------------------------------------------------------------------
/utils/models.py:
--------------------------------------------------------------------------------
1 | #coding=utf-8
2 | import os
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from transformers import AutoModelForMaskedLM
8 | from utils.commonVar import *
9 |
10 | from sklearn.linear_model import LogisticRegression
11 | from sklearn.svm import SVC
12 | from sklearn.multioutput import MultiOutputClassifier
13 | from sklearn.pipeline import make_pipeline
14 | from sklearn.preprocessing import StandardScaler
15 |
16 | class IntentBERT(nn.Module):
17 | def __init__(self, config):
18 | super(IntentBERT, self).__init__()
19 | self.device = config['device']
20 | self.LMName = config['LMName']
21 | self.clsNum = config['clsNumber']
22 | try:
23 | self.word_embedding = AutoModelForMaskedLM.from_pretrained(self.LMName)
24 | except:
25 | self.word_embedding = AutoModelForMaskedLM.from_pretrained(os.path.join(SAVE_PATH, self.LMName))
26 | self.linearClsfier = nn.Linear(768, self.clsNum)
27 | self.dropout = nn.Dropout(0.1) # follow the default in bert model
28 | # self.word_embedding = nn.DataParallel(self.word_embedding)
29 | self.word_embedding.to(self.device)
30 | self.linearClsfier.to(self.device)
31 |
32 | def loss_ce(self, logits, Y):
33 | loss = nn.CrossEntropyLoss()
34 | output = loss(logits, Y)
35 | return output
36 |
37 | def loss_mse(self, logits, Y):
38 | loss = nn.MSELoss()
39 | output = loss(torch.sigmoid(logits).squeeze(), Y)
40 | return output
41 |
42 | def loss_kl(self, logits, label):
43 | # KL-div loss
44 | probs = F.log_softmax(logits, dim=1)
45 | # label_probs = F.log_softmax(label, dim=1)
46 | loss = F.kl_div(probs, label, reduction='batchmean')
47 | return loss
48 |
49 | def forward(self, X):
50 | # BERT forward
51 | outputs = self.word_embedding(**X, output_hidden_states=True)
52 |
53 | # extract [CLS] for utterance representation
54 | CLSEmbedding = outputs.hidden_states[-1][:,0]
55 |
56 | # linear classifier
57 | CLSEmbedding = self.dropout(CLSEmbedding)
58 | logits = self.linearClsfier(CLSEmbedding)
59 |
60 | return logits
61 |
62 | def mlmForward(self, X, Y):
63 | # BERT forward
64 | outputs = self.word_embedding(**X, labels=Y)
65 |
66 | return outputs.loss
67 |
68 | def fewShotPredict(self, supportX, supportY, queryX, clsFierName, mode='multi-class'):
69 | # calculate word embedding
70 | # BERT forward
71 | s_embedding = self.word_embedding(**supportX, output_hidden_states=True).hidden_states[-1]
72 | q_embedding = self.word_embedding(**queryX, output_hidden_states=True).hidden_states[-1]
73 |
74 | # extract [CLS] for utterance representation
75 | supportEmbedding = s_embedding[:,0]
76 | queryEmbedding = q_embedding[:,0]
77 | support_features = self.normalize(supportEmbedding).cpu()
78 | query_features = self.normalize(queryEmbedding).cpu()
79 |
80 | # select clsfier
81 | clf = None
82 | if clsFierName == CLSFIER_LINEAR_REGRESSION:
83 | clf = LogisticRegression(penalty='l2',
84 | random_state=0,
85 | C=1.0,
86 | solver='lbfgs',
87 | max_iter=1000,
88 | multi_class='multinomial')
89 | # fit and predict
90 | clf.fit(support_features, supportY)
91 | elif clsFierName == CLSFIER_SVM:
92 | clf = make_pipeline(StandardScaler(),
93 | SVC(gamma='auto',C=1,
94 | kernel='linear',
95 | decision_function_shape='ovr'))
96 | # fit and predict
97 | clf.fit(support_features, supportY)
98 | elif clsFierName == CLSFIER_MULTI_LABEL:
99 | clf = MultiOutputClassifier(LogisticRegression(penalty='l2',
100 | random_state=0,
101 | C=1.0,
102 | solver='liblinear',
103 | max_iter=1000,
104 | multi_class='ovr',
105 | class_weight='balanced'))
106 |
107 | clf.fit(support_features, supportY)
108 | else:
109 | raise NotImplementedError("Unsupported classifier name %s"%(clsFierName))
110 |
111 | if mode == 'multi-class':
112 | query_pred = clf.predict(query_features)
113 | else:
114 | logger.error("Invalid model %d"%(mode))
115 |
116 | return query_pred
117 |
118 | def reinit_clsfier(self):
119 | self.linearClsfier.weight.data.normal_(mean=0.0, std=0.02)
120 | self.linearClsfier.bias.data.zero_()
121 |
122 | def set_dropout_layer(self, dropout_rate):
123 | self.dropout = nn.Dropout(dropout_rate)
124 |
125 | def set_linear_layer(self, clsNum):
126 | self.linearClsfier = nn.Linear(768, clsNum)
127 |
128 | def normalize(self, x):
129 | norm = x.pow(2).sum(1, keepdim=True).pow(1. / 2)
130 | out = x.div(norm)
131 | return out
132 |
133 | def NN(self, support, support_ys, query):
134 | """nearest classifier"""
135 | support = np.expand_dims(support.transpose(), 0)
136 | query = np.expand_dims(query, 2)
137 |
138 | diff = np.multiply(query - support, query - support)
139 | distance = diff.sum(1)
140 | min_idx = np.argmin(distance, axis=1)
141 | pred = [support_ys[idx] for idx in min_idx]
142 | return pred
143 |
144 | def CosineClsfier(self, support, support_ys, query):
145 | """Cosine classifier"""
146 | support_norm = np.linalg.norm(support, axis=1, keepdims=True)
147 | support = support / support_norm
148 | query_norm = np.linalg.norm(query, axis=1, keepdims=True)
149 | query = query / query_norm
150 |
151 | cosine_distance = query @ support.transpose()
152 | max_idx = np.argmax(cosine_distance, axis=1)
153 | pred = [support_ys[idx] for idx in max_idx]
154 | return pred
155 |
156 | def save(self, path):
157 | self.word_embedding.save_pretrained(path)
158 |
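
An instantiation sketch for `IntentBERT`, using only the config keys read in `__init__` (values illustrative). Note that `from_pretrained` first tries the Hugging Face hub name and falls back to a checkpoint under `SAVE_PATH`:

```python
import torch

config = {
    'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    'LMName': 'bert-base-uncased',  # any AutoModelForMaskedLM checkpoint
    'clsNumber': 77,                # e.g. the number of bank77 intents
}
model = IntentBERT(config)
# X is a dict of input_ids / token_type_ids / attention_mask tensors;
# logits = model(X) has shape [batch, clsNumber]
```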
--------------------------------------------------------------------------------
/utils/printHelper.py:
--------------------------------------------------------------------------------
1 | from utils.Logger import logger
2 | import logging
3 |
4 | ##
5 | # @brief print means value, std value and item names
6 | #
7 | # @param meanList: example [1.1, 2.2, 1.2]
8 | # @param stdList: example [0.1, 0.15, 0.001]
9 | # @param itemList: example ['acc', 'pre', 'recall', 'fsc']
10 | # @param debugLevel: example logging.INFO
11 | #
12 | # @return
13 | def printMeanStd(meanList, stdList, itemList, debugLevel=logging.INFO):
14 | # select logging function
15 | loggingFunc = None
16 | if (debugLevel == logging.INFO):
17 | loggingFunc = logger.info
18 | elif (debugLevel == logging.DEBUG):
19 | loggingFunc = logger.debug
20 | else:
21 | raise NotImplementedError("Not supported logging level.")
22 |
23 | lengthSet = set()
24 | lengthSet.add(len(meanList))
25 | lengthSet.add(len(stdList))
26 | lengthSet.add(len(itemList))
27 | if not len(lengthSet) == 1:
28 | logger.error("Inconsisten list lengths when printing statistics.")
29 | exit(1)
30 |
31 | for mean, std, item in zip(meanList, stdList, itemList):
32 | loggingFunc("%-6s: %f +- %f", item, mean, std)
33 |
34 |
35 | ##
36 | # @brief print means value and item names
37 | #
38 | # @param meanList: example [1.1, 2.2, 1.2]
39 | # @param itemList: example ['acc', 'pre', 'recall', 'fsc']
40 | # @param debugLevel: example logging.INFO
41 | #
42 | # @return
43 | def printMean(meanList, itemList, debugLevel=logging.INFO):
44 | # select logging function
45 | loggingFunc = None
46 | if (debugLevel == logging.INFO):
47 | loggingFunc = logger.info
48 | elif (debugLevel == logging.DEBUG):
49 | loggingFunc = logger.debug
50 | else:
51 | raise NotImplementedError("Not supported logging level.")
52 |
53 | lengthSet = set()
54 | lengthSet.add(len(meanList))
55 | lengthSet.add(len(itemList))
56 | if not len(lengthSet) == 1:
57 | logger.error("Inconsisten list lengths when printing statistics.")
58 | exit(1)
59 |
60 | for mean, item in zip(meanList, itemList):
61 | loggingFunc("%-6s: %f", item, mean)
62 |
--------------------------------------------------------------------------------
/utils/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import TensorDataset
3 | import numpy as np
4 | from random import sample
5 | import random
6 | random.seed(0)
7 |
8 | entail_label_map = {'entail':1, 'nonentail':0}
9 |
10 | def getDomainName(name):
11 | if name=="auto_commute":
12 | return "auto commute"
13 | elif name=="credit_cards":
14 | return "credit cards"
15 | elif name=="kitchen_dining":
16 | return "kitchen dining"
17 | elif name=="small_talk":
18 | return "small talk"
19 | elif ' ' not in name:
20 | return name
21 | else:
22 | raise NotImplementedError("Not supported domain name %s"%(name))
23 |
24 | def splitName(dom):
25 | domList = []
26 | for name in dom.split(','):
27 | domList.append(getDomainName(name))
28 | return domList
29 |
30 | def makeTrainExamples(data:list, tokenizer, label=None, mode='unlabel'):
31 | """
32 | unlabel: simply pad data and then convert into tensor
33 | multi-class: pad data and compose tensor dataset with labels
34 | """
35 | if mode != "unlabel":
36 | assert label is not None, f"Label must be provided for setting {mode}"
37 | if mode == "multi-class":
38 | examples = tokenizer.pad(data, padding='longest', return_tensors='pt')
39 | if not isinstance(label, torch.Tensor):
40 | label = torch.tensor(label)
41 | examples = TensorDataset(label,
42 | examples['input_ids'],
43 | examples['token_type_ids'],
44 | examples['attention_mask'])
45 | else:
46 | raise ValueError(f"Undefined setting {mode}")
47 | else:
48 | examples = tokenizer.pad(data, padding='longest', return_tensors='pt')
49 | examples = TensorDataset(examples['input_ids'],
50 | examples['token_type_ids'],
51 | examples['attention_mask'])
52 | return examples
53 |
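
A usage sketch for `makeTrainExamples`, assuming `tokenizer` is a Hugging Face tokenizer and the token lists come from `IntentDataset.tokenize`:

```python
# labeled: batches unpack as (label, input_ids, token_type_ids, attention_mask)
labTensorData = makeTrainExamples(dataset.getTokList(), tokenizer,
                                  label=dataset.getLabID(), mode='multi-class')
# unlabeled: batches unpack as (input_ids, token_type_ids, attention_mask)
unlabTensorData = makeTrainExamples(unlabeled.getTokList(), tokenizer, mode='unlabel')
```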
54 | def makeEvalExamples(supportX, supportY, queryX, queryY, tokenizer, mode='multi-class'):
55 | """
56 | multi-class: simply pad data
57 | """
58 | if mode == "multi-class":
59 | supportX = tokenizer.pad(supportX, padding='longest', return_tensors='pt')
60 | queryX = tokenizer.pad(queryX, padding='longest', return_tensors='pt')
61 | else:
62 | raise ValueError("Invalid mode %d."%(mode))
63 | return supportX, supportY, queryX, queryY
64 |
65 | #https://github.com/huggingface/transformers/blob/master/src/transformers/data/data_collator.py#L70
66 | def mask_tokens(inputs, tokenizer,\
67 | special_tokens_mask=None, mlm_probability=0.15):
68 | """
69 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
70 | """
71 | labels = inputs.clone()
72 | # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
73 | probability_matrix = torch.full(labels.shape, mlm_probability)
74 | if special_tokens_mask is None:
75 | special_tokens_mask = [
76 | tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
77 | ]
78 | special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
79 | else:
80 | special_tokens_mask = special_tokens_mask.bool()
81 |
82 | probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
83 | probability_matrix[torch.where(inputs==0)] = 0.0
84 | masked_indices = torch.bernoulli(probability_matrix).bool()
85 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
86 |
87 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
88 | indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
89 | inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
90 |
91 | # 10% of the time, we replace masked input tokens with random word
92 | indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
93 | random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
94 | inputs[indices_random] = random_words[indices_random]
95 |
96 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
97 | return inputs, labels
98 |
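
Why `0.5` on the random-replacement line yields 10% overall: 15% of tokens are selected for masking; 80% of those become `[MASK]`, and of the remaining 20%, a fair coin (`bernoulli(0.5)`) sends half to random tokens (10% of the masked set), leaving the last 10% unchanged. A quick empirical check (illustrative, not part of the repo):

```python
import torch

masked = torch.ones(100000, dtype=torch.bool)  # pretend every token was selected for masking
replaced = torch.bernoulli(torch.full(masked.shape, 0.8)).bool() & masked
randomized = torch.bernoulli(torch.full(masked.shape, 0.5)).bool() & masked & ~replaced
print(replaced.float().mean(), randomized.float().mean())  # ~0.80 and ~0.10
```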
--------------------------------------------------------------------------------