├── requirements.in
├── slot_tag.sh
├── environment.yml
├── intent_cls.sh
├── preprocess.sh
├── README.md
├── Makefile
├── model.py
├── dataset.py
├── utils.py
├── test_slot.py
├── train_slot.py
├── preprocess_slot.py
├── test_intent.py
├── train_intent.py
└── preprocess_intent.py
/requirements.in:
--------------------------------------------------------------------------------
# requirements
torch==1.12.1
tensorflow==2.10.0
seqeval==1.2.2
tqdm
numpy
pandas
scikit-learn==1.1.2
--------------------------------------------------------------------------------
/slot_tag.sh:
--------------------------------------------------------------------------------
# "${1}" is the first argument passed to the script
# "${2}" is the second argument passed to the script
python3 test_slot.py --test_file ...
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
name: adl-hw1
channels:
  - defaults
dependencies:
  - python=3.9
  - cudatoolkit=10.2
  - cudnn=7.6
  - pip
  - pip:
      - pip-tools
--------------------------------------------------------------------------------
/intent_cls.sh:
--------------------------------------------------------------------------------
# "${1}" is the first argument passed to the script
# "${2}" is the second argument passed to the script
python3 test_intent.py --test_file "${1}" --ckpt_path ckpt/intent/best.pt --pred_file "${2}"
--------------------------------------------------------------------------------
/preprocess.sh:
--------------------------------------------------------------------------------
if [ ! -f glove.840B.300d.txt ]; then
  wget http://nlp.stanford.edu/data/glove.840B.300d.zip -O glove.840B.300d.zip
  unzip glove.840B.300d.zip
fi
python preprocess_intent.py
python preprocess_slot.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sample Code for Homework 1 ADL NTU

## Environment
```shell
# If you have conda, we recommend you build a conda environment called "adl-hw1"
make
conda activate adl-hw1
pip install -r requirements.txt
# Otherwise
pip install -r requirements.in
```

## Preprocessing
```shell
# To preprocess intent detection and slot tagging datasets
bash preprocess.sh
```

## Intent detection
```shell
python train_intent.py
```
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
# Thanks to http://blog.ianpreston.ca/2020/05/13/conda_envs.html for working some of this out!

# Oneshell means all lines in a recipe run in the same shell
.ONESHELL:

# Need to specify bash in order for conda activate to work
SHELL=/bin/bash

# Note that the extra activate is needed to ensure that the activate floats env to the front of PATH
CONDA_ACTIVATE=source $$(conda info --base)/etc/profile.d/conda.sh ; conda activate ; conda activate

# Same name as in environment.yml
CONDA_ENV=adl-hw1

all: conda-env-update pip-compile pip-sync

# Create or update conda env
conda-env-update:
	conda env update --prune

# Compile exact pip packages
pip-compile:
	$(CONDA_ACTIVATE) $(CONDA_ENV)
	pip-compile -v requirements.in

# Install pip packages
pip-sync:
	$(CONDA_ACTIVATE) $(CONDA_ENV)
	pip-sync requirements.txt
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
from typing import Dict

import torch
from torch.nn import Embedding


class SeqClassifier(torch.nn.Module):
    def __init__(
        self,
        embeddings: torch.Tensor,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        bidirectional: bool,
        num_class: int,
    ) -> None:
        super(SeqClassifier, self).__init__()
        self.embed = Embedding.from_pretrained(embeddings, freeze=False)
        # TODO: model architecture

    @property
    def encoder_output_size(self) -> int:
        # TODO: calculate the output dimension of rnn
        raise NotImplementedError

    def forward(self, batch) -> Dict[str, torch.Tensor]:
        # TODO: implement model forward
        raise NotImplementedError


class SeqTagger(SeqClassifier):
    def forward(self, batch) -> Dict[str, torch.Tensor]:
        # TODO: implement model forward
        raise NotImplementedError
--------------------------------------------------------------------------------
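
The TODOs in `model.py` above are intentionally left for the student. For orientation only, one plausible way to complete `SeqClassifier` is sketched below; the GRU encoder, the pooling of the final hidden state, and the `"logits"` output key are assumptions made for this sketch, not part of the provided scaffold. `SeqTagger` would differ mainly in `forward`, classifying every position of `outputs` rather than a single pooled state.

```python
from typing import Dict

import torch
from torch.nn import Embedding


class ExampleSeqClassifier(torch.nn.Module):
    """Hypothetical completion of SeqClassifier; names and design are assumptions."""

    def __init__(
        self,
        embeddings: torch.Tensor,
        hidden_size: int,
        num_layers: int,
        dropout: float,
        bidirectional: bool,
        num_class: int,
    ) -> None:
        super().__init__()
        self.embed = Embedding.from_pretrained(embeddings, freeze=False)
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        # GRU encoder over the embedded token sequence
        self.rnn = torch.nn.GRU(
            input_size=embeddings.size(1),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=bidirectional,
            batch_first=True,
        )
        self.classifier = torch.nn.Linear(self.encoder_output_size, num_class)

    @property
    def encoder_output_size(self) -> int:
        # a bidirectional RNN concatenates forward and backward hidden states
        return self.hidden_size * (2 if self.bidirectional else 1)

    def forward(self, batch: torch.Tensor) -> Dict[str, torch.Tensor]:
        # batch: (batch_size, seq_len) tensor of token ids
        embedded = self.embed(batch)          # (B, L, embed_dim)
        outputs, hidden = self.rnn(embedded)  # hidden: (num_layers * num_dirs, B, H)
        if self.bidirectional:
            last = torch.cat([hidden[-2], hidden[-1]], dim=-1)  # (B, 2 * H)
        else:
            last = hidden[-1]                 # (B, H)
        return {"logits": self.classifier(last)}
```
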
/dataset.py:
--------------------------------------------------------------------------------
from typing import List, Dict

from torch.utils.data import Dataset

from utils import Vocab


class SeqClsDataset(Dataset):
    def __init__(
        self,
        data: List[Dict],
        vocab: Vocab,
        label_mapping: Dict[str, int],
        max_len: int,
    ):
        self.data = data
        self.vocab = vocab
        self.label_mapping = label_mapping
        self._idx2label = {idx: intent for intent, idx in self.label_mapping.items()}
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index) -> Dict:
        instance = self.data[index]
        return instance

    @property
    def num_classes(self) -> int:
        return len(self.label_mapping)

    def collate_fn(self, samples: List[Dict]) -> Dict:
        # TODO: implement collate_fn
        raise NotImplementedError

    def label2idx(self, label: str):
        return self.label_mapping[label]

    def idx2label(self, idx: int):
        return self._idx2label[idx]


class SeqTaggingClsDataset(SeqClsDataset):
    ignore_idx = -100

    def collate_fn(self, samples):
        # TODO: implement collate_fn
        raise NotImplementedError
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from typing import Iterable, List


class Vocab:
    PAD = "[PAD]"
    UNK = "[UNK]"

    def __init__(self, vocab: Iterable[str]) -> None:
        self.token2idx = {
            Vocab.PAD: 0,
            Vocab.UNK: 1,
            **{token: i for i, token in enumerate(vocab, 2)},
        }

    @property
    def pad_id(self) -> int:
        return self.token2idx[Vocab.PAD]

    @property
    def unk_id(self) -> int:
        return self.token2idx[Vocab.UNK]

    @property
    def tokens(self) -> List[str]:
        return list(self.token2idx.keys())

    def token_to_id(self, token: str) -> int:
        return self.token2idx.get(token, self.unk_id)

    def encode(self, tokens: List[str]) -> List[int]:
        return [self.token_to_id(token) for token in tokens]

    def encode_batch(
        self, batch_tokens: List[List[str]], to_len: int = None
    ) -> List[List[int]]:
        batch_ids = [self.encode(tokens) for tokens in batch_tokens]
        to_len = max(len(ids) for ids in batch_ids) if to_len is None else to_len
        padded_ids = pad_to_len(batch_ids, to_len, self.pad_id)
        return padded_ids


def pad_to_len(seqs: List[List[int]], to_len: int, padding: int) -> List[List[int]]:
    paddeds = [seq[:to_len] + [padding] * max(0, to_len - len(seq)) for seq in seqs]
    return paddeds
--------------------------------------------------------------------------------
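
The `collate_fn` methods above are also left as TODOs. A minimal sketch for the intent task is shown below, assuming each sample carries `"text"` and `"intent"` fields (as used in `preprocess_intent.py`) plus an `"id"` field; the field names and batch layout are assumptions of this sketch:

```python
from typing import Dict, List

import torch

from dataset import SeqClsDataset


def example_collate_fn(dataset: SeqClsDataset, samples: List[Dict]) -> Dict:
    """Hypothetical collate_fn for the intent task; field names are assumptions."""
    tokens = [sample["text"].split() for sample in samples]
    # encode_batch truncates/pads every utterance in the batch to max_len with the PAD id
    token_ids = dataset.vocab.encode_batch(tokens, to_len=dataset.max_len)
    batch = {
        "id": [sample["id"] for sample in samples],
        "text": torch.tensor(token_ids, dtype=torch.long),
    }
    if "intent" in samples[0]:  # the test split carries no labels
        batch["intent"] = torch.tensor(
            [dataset.label2idx(sample["intent"]) for sample in samples],
            dtype=torch.long,
        )
    return batch
```

In practice this logic would live inside `SeqClsDataset.collate_fn` and be handed to the `DataLoader` through its `collate_fn` argument; `SeqTaggingClsDataset` would additionally pad the tag sequences, e.g. with `ignore_idx`.
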
/test_slot.py:
--------------------------------------------------------------------------------
import json
import pickle
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import SeqTaggingClsDataset
from model import SeqTagger
from utils import Vocab


def main(args):
    # TODO: implement main function
    raise NotImplementedError


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=Path,
        help="Directory to the dataset.",
        default="./data/slot/",
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        help="Directory to the preprocessed caches.",
        default="./cache/slot/",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=Path,
        help="Directory to save the model file.",
        default="./ckpt/slot/",
    )
    parser.add_argument("--pred_file", type=Path, default="pred.slot.csv")

    # data
    parser.add_argument("--max_len", type=int, default=128)

    # model
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--bidirectional", type=bool, default=True)

    # data loader
    parser.add_argument("--batch_size", type=int, default=128)

    parser.add_argument(
        "--device", type=torch.device, help="cpu, cuda, cuda:0, cuda:1", default="cpu"
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)
--------------------------------------------------------------------------------
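
`main` in `test_slot.py` is left unimplemented. The `seqeval` pin in `requirements.in` suggests that predicted tag sequences are meant to be scored against the gold ones, roughly as below; the tag names and the IOB2 scheme here are illustrative assumptions:

```python
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

# one list of tag strings per sentence
y_true = [["B-date", "I-date", "O", "B-people"]]
y_pred = [["B-date", "O", "O", "B-people"]]

# strict span-level precision / recall / F1 per slot type
print(classification_report(y_true, y_pred, mode="strict", scheme=IOB2))
```
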
/train_slot.py:
--------------------------------------------------------------------------------
import json
import pickle
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from dataset import SeqTaggingClsDataset
from model import SeqTagger
from utils import Vocab

TRAIN = "train"
DEV = "eval"
SPLITS = [TRAIN, DEV]


def main(args):
    # TODO: implement main function
    raise NotImplementedError


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=Path,
        help="Directory to the dataset.",
        default="./data/slot/",
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        help="Directory to the preprocessed caches.",
        default="./cache/slot/",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=Path,
        help="Directory to save the model file.",
        default="./ckpt/slot/",
    )

    # data
    parser.add_argument("--max_len", type=int, default=128)

    # model
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--bidirectional", type=bool, default=True)

    # optimizer
    parser.add_argument("--lr", type=float, default=1e-3)

    # data loader
    parser.add_argument("--batch_size", type=int, default=128)

    # training
    parser.add_argument(
        "--device", type=torch.device, help="cpu, cuda, cuda:0, cuda:1", default="cpu"
    )
    parser.add_argument("--num_epoch", type=int, default=100)

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    args.ckpt_dir.mkdir(parents=True, exist_ok=True)
    main(args)
--------------------------------------------------------------------------------
/preprocess_slot.py:
--------------------------------------------------------------------------------
import json
import logging
from argparse import ArgumentParser, Namespace
from collections import Counter
from pathlib import Path
from random import seed

from preprocess_intent import build_vocab

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


def main(args):
    seed(args.rand_seed)

    tags = set()
    words = Counter()
    for split in ["train", "eval"]:
        dataset_path = args.data_dir / f"{split}.json"
        dataset = json.loads(dataset_path.read_text())
        logging.info(f"Dataset loaded at {str(dataset_path.resolve())}")

        tags.update({tag for instance in dataset for tag in instance["tags"]})
        words.update([token for instance in dataset for token in instance["tokens"]])

    tag2idx = {tag: i for i, tag in enumerate(tags)}
    tag_idx_path = args.output_dir / "tag2idx.json"
    tag_idx_path.write_text(json.dumps(tag2idx, indent=2))
    logging.info(f"Tag-to-index mapping saved at {str(tag_idx_path.resolve())}")

    build_vocab(words, args.vocab_size, args.output_dir, args.glove_path)


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=Path,
        help="Directory to the dataset.",
        default="./data/slot/",
    )
    parser.add_argument(
        "--glove_path",
        type=Path,
        help="Path to GloVe embeddings.",
        default="./glove.840B.300d.txt",
    )
    parser.add_argument("--rand_seed", type=int, help="Random seed.", default=13)
    parser.add_argument(
        "--output_dir",
        type=Path,
        help="Directory to save the processed file.",
        default="./cache/slot/",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        help="Number of tokens in the vocabulary",
        default=10_000,
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    main(args)
--------------------------------------------------------------------------------
/test_intent.py:
--------------------------------------------------------------------------------
import json
import pickle
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import torch

from dataset import SeqClsDataset
from model import SeqClassifier
from utils import Vocab


def main(args):
    with open(args.cache_dir / "vocab.pkl", "rb") as f:
        vocab: Vocab = pickle.load(f)

    intent_idx_path = args.cache_dir / "intent2idx.json"
    intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())

    data = json.loads(args.test_file.read_text())
    dataset = SeqClsDataset(data, vocab, intent2idx, args.max_len)
    # TODO: create DataLoader for test dataset

    embeddings = torch.load(args.cache_dir / "embeddings.pt")

    model = SeqClassifier(
        embeddings,
        args.hidden_size,
        args.num_layers,
        args.dropout,
        args.bidirectional,
        dataset.num_classes,
    )
    model.eval()

    ckpt = torch.load(args.ckpt_path)
    # TODO: load weights into model

    # TODO: predict dataset

    # TODO: write prediction to file (args.pred_file)


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument(
        "--test_file",
        type=Path,
        help="Path to the test file.",
        required=True
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        help="Directory to the preprocessed caches.",
        default="./cache/intent/",
    )
    parser.add_argument(
        "--ckpt_path",
        type=Path,
        help="Path to model checkpoint.",
        required=True
    )
    parser.add_argument("--pred_file", type=Path, default="pred.intent.csv")

    # data
    parser.add_argument("--max_len", type=int, default=128)

    # model
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--bidirectional", type=bool, default=True)

    # data loader
    parser.add_argument("--batch_size", type=int, default=128)

    parser.add_argument(
        "--device", type=torch.device, help="cpu, cuda, cuda:0, cuda:1", default="cpu"
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    main(args)
--------------------------------------------------------------------------------
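
The remaining TODOs in `test_intent.py` amount to batching the test set, restoring the checkpoint, and writing predictions. A rough sketch follows; the batch keys, the `"logits"` output, the bare `state_dict` checkpoint layout, and the `id,intent` CSV header are assumptions carried over from the earlier sketches, not a specification of the required submission format:

```python
import csv

import torch
from torch.utils.data import DataLoader


def example_predict(args, dataset, model) -> None:
    """Hypothetical completion of the TODOs in test_intent.py's main()."""
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )

    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    model.load_state_dict(ckpt)  # assumes the checkpoint is a bare state_dict
    model.to(args.device)
    model.eval()

    ids, preds = [], []
    with torch.no_grad():
        for batch in loader:
            logits = model(batch["text"].to(args.device))["logits"]
            preds.extend(logits.argmax(dim=-1).cpu().tolist())
            ids.extend(batch["id"])

    with open(args.pred_file, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "intent"])  # assumed header
        for idx, pred in zip(ids, preds):
            writer.writerow([idx, dataset.idx2label(pred)])
```
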
/train_intent.py:
--------------------------------------------------------------------------------
import json
import pickle
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import torch
from tqdm import trange

from dataset import SeqClsDataset
from utils import Vocab

TRAIN = "train"
DEV = "eval"
SPLITS = [TRAIN, DEV]


def main(args):
    with open(args.cache_dir / "vocab.pkl", "rb") as f:
        vocab: Vocab = pickle.load(f)

    intent_idx_path = args.cache_dir / "intent2idx.json"
    intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())

    data_paths = {split: args.data_dir / f"{split}.json" for split in SPLITS}
    data = {split: json.loads(path.read_text()) for split, path in data_paths.items()}
    datasets: Dict[str, SeqClsDataset] = {
        split: SeqClsDataset(split_data, vocab, intent2idx, args.max_len)
        for split, split_data in data.items()
    }
    # TODO: create DataLoader for train / dev datasets

    embeddings = torch.load(args.cache_dir / "embeddings.pt")
    # TODO: init model and move model to target device (cpu / gpu)
    model = None

    # TODO: init optimizer
    optimizer = None

    epoch_pbar = trange(args.num_epoch, desc="Epoch")
    for epoch in epoch_pbar:
        # TODO: Training loop - iterate over train dataloader and update model weights
        # TODO: Evaluation loop - calculate accuracy and save model weights
        pass

    # TODO: Inference on test set


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=Path,
        help="Directory to the dataset.",
        default="./data/intent/",
    )
    parser.add_argument(
        "--cache_dir",
        type=Path,
        help="Directory to the preprocessed caches.",
        default="./cache/intent/",
    )
    parser.add_argument(
        "--ckpt_dir",
        type=Path,
        help="Directory to save the model file.",
        default="./ckpt/intent/",
    )

    # data
    parser.add_argument("--max_len", type=int, default=128)

    # model
    parser.add_argument("--hidden_size", type=int, default=512)
    parser.add_argument("--num_layers", type=int, default=2)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--bidirectional", type=bool, default=True)

    # optimizer
    parser.add_argument("--lr", type=float, default=1e-3)

    # data loader
    parser.add_argument("--batch_size", type=int, default=128)

    # training
    parser.add_argument(
        "--device", type=torch.device, help="cpu, cuda, cuda:0, cuda:1", default="cpu"
    )
    parser.add_argument("--num_epoch", type=int, default=100)

    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    args.ckpt_dir.mkdir(parents=True, exist_ok=True)
    main(args)
--------------------------------------------------------------------------------
/preprocess_intent.py:
--------------------------------------------------------------------------------
import json
import logging
import pickle
import re
from argparse import ArgumentParser, Namespace
from collections import Counter
from pathlib import Path
from random import random, seed
from typing import List, Dict

import torch
from tqdm.auto import tqdm

from utils import Vocab

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
)


def build_vocab(
    words: Counter, vocab_size: int, output_dir: Path, glove_path: Path
) -> None:
    common_words = {w for w, _ in words.most_common(vocab_size)}
    vocab = Vocab(common_words)
    vocab_path = output_dir / "vocab.pkl"
    with open(vocab_path, "wb") as f:
        pickle.dump(vocab, f)
    logging.info(f"Vocab saved at {str(vocab_path.resolve())}")

    glove: Dict[str, List[float]] = {}
    logging.info(f"Loading glove: {str(glove_path.resolve())}")
    with open(glove_path) as fp:
        row1 = fp.readline()
        # if the first row is not a header, seek back to the beginning;
        # otherwise skip the header
        if not re.match("^[0-9]+ [0-9]+$", row1):
            fp.seek(0)

        for i, line in tqdm(enumerate(fp)):
            cols = line.rstrip().split(" ")
            word = cols[0]
            vector = [float(v) for v in cols[1:]]

            # skip words that are not in the selected vocabulary
            if word not in common_words:
                continue
            glove[word] = vector
            glove_dim = len(vector)

    assert all(len(v) == glove_dim for v in glove.values())
    assert len(glove) <= vocab_size

    num_matched = sum([token in glove for token in vocab.tokens])
    logging.info(
        f"Token covered: {num_matched} / {len(vocab.tokens)} = {num_matched / len(vocab.tokens)}"
    )
    embeddings: List[List[float]] = [
        glove.get(token, [random() * 2 - 1 for _ in range(glove_dim)])
        for token in vocab.tokens
    ]
    embeddings = torch.tensor(embeddings)
    embedding_path = output_dir / "embeddings.pt"
    torch.save(embeddings, str(embedding_path))
    logging.info(f"Embedding shape: {embeddings.shape}")
    logging.info(f"Embedding saved at {str(embedding_path.resolve())}")


def main(args):
    seed(args.rand_seed)

    intents = set()
    words = Counter()
    for split in ["train", "eval"]:
        dataset_path = args.data_dir / f"{split}.json"
        dataset = json.loads(dataset_path.read_text())
        logging.info(f"Dataset loaded at {str(dataset_path.resolve())}")

        intents.update({instance["intent"] for instance in dataset})
        words.update(
            [token for instance in dataset for token in instance["text"].split()]
        )

    intent2idx = {tag: i for i, tag in enumerate(intents)}
    intent_tag_path = args.output_dir / "intent2idx.json"
    intent_tag_path.write_text(json.dumps(intent2idx, indent=2))
    logging.info(f"Intent-to-index mapping saved at {str(intent_tag_path.resolve())}")

    build_vocab(words, args.vocab_size, args.output_dir, args.glove_path)


def parse_args() -> Namespace:
    parser = ArgumentParser()
    parser.add_argument(
        "--data_dir",
        type=Path,
        help="Directory to the dataset.",
        default="./data/intent/",
    )
    parser.add_argument(
        "--glove_path",
        type=Path,
        help="Path to GloVe embeddings.",
        default="./glove.840B.300d.txt",
    )
    parser.add_argument("--rand_seed", type=int, help="Random seed.", default=13)
    parser.add_argument(
        "--output_dir",
        type=Path,
        help="Directory to save the processed file.",
        default="./cache/intent/",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        help="Number of tokens in the vocabulary",
        default=10_000,
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    main(args)
--------------------------------------------------------------------------------
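
Finally, the TODO blocks in `train_intent.py`'s `main` could be filled in roughly as follows; the DataLoader setup, the cross-entropy objective, and saving the best checkpoint as `best.pt` (the name `intent_cls.sh` expects) are one possible design, with the batch keys and `"logits"` output again assumed from the earlier sketches:

```python
import torch
from torch.utils.data import DataLoader
from tqdm import trange


def example_training_loop(args, datasets, model) -> None:
    """Hypothetical completion of the TODO blocks in train_intent.py's main()."""
    loaders = {
        split: DataLoader(
            ds,
            batch_size=args.batch_size,
            shuffle=(split == "train"),
            collate_fn=ds.collate_fn,
        )
        for split, ds in datasets.items()
    }

    model.to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    best_acc = 0.0
    for epoch in trange(args.num_epoch, desc="Epoch"):
        model.train()
        for batch in loaders["train"]:
            optimizer.zero_grad()
            logits = model(batch["text"].to(args.device))["logits"]
            loss = criterion(logits, batch["intent"].to(args.device))
            loss.backward()
            optimizer.step()

        model.eval()
        correct = total = 0
        with torch.no_grad():
            for batch in loaders["eval"]:  # the dev split is named "eval"
                logits = model(batch["text"].to(args.device))["logits"]
                pred = logits.argmax(dim=-1).cpu()
                correct += (pred == batch["intent"]).sum().item()
                total += len(pred)
        acc = correct / total
        if acc > best_acc:  # keep only the best checkpoint
            best_acc = acc
            torch.save(model.state_dict(), args.ckpt_dir / "best.pt")
```
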