├── .gitattributes ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── demo.png ├── nfcats ├── __init__.py ├── classifier.py ├── data │ ├── test.csv │ ├── train.csv │ └── valid.csv ├── dataset_reader.py ├── embedder.py ├── nft_classifier.jsonnet ├── predict.py ├── sampler.py ├── tf_idf.py ├── train.py ├── utils.py └── wandb_callback.py ├── poetry.lock ├── pyproject.toml └── setup.cfg /.gitattributes: -------------------------------------------------------------------------------- 1 | nfcats/pretrained_model filter=lfs diff=lfs merge=lfs -text 2 | nfcats/data filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | .idea 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Lurunchik 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | CODE = nfcats 2 | VENV = poetry run 3 | WIDTH = 120 4 | 5 | .PHONY: pretty lint 6 | 7 | pretty: 8 | $(VENV) black --skip-string-normalization --line-length $(WIDTH) $(CODE) 9 | $(VENV) isort --apply --recursive --line-width $(WIDTH) $(CODE) 10 | $(VENV) unify --in-place --recursive $(CODE) 11 | 12 | lint: 13 | $(VENV) black --check --skip-string-normalization --line-length $(WIDTH) $(CODE) 14 | $(VENV) flake8 --statistics --max-line-length $(WIDTH) $(CODE) 15 | $(VENV) pylint --rcfile=setup.cfg $(CODE) 16 | $(VENV) mypy $(CODE) 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Non-Factoid Question Category Classification 2 | 3 | 4 | This repository contains code for the following paper: 5 | 6 | >["A Non-Factoid Question-Answering Taxonomy", published at SIGIR '22 and winner of the Best Paper Award](https://dl.acm.org/doi/pdf/10.1145/3477495.3531926) 7 | > Valeriia Bolotova, Vladislav Blinov, Falk Scholer, W. Bruce Croft, Mark Sanderson 8 | > ACM SIGIR Conference on Research and Development in Information Retrieval (SIGIR), 2022 9 | 10 | 11 | ## NF_CATS Dataset 12 | The dataset used for training and evaluation is located in [nfcats/data](nfcats/data/) as `train.csv`, `valid.csv`, and `test.csv`.
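Each row pairs a `question` with its labelled `category`, the two fields consumed by the dataset reader and the scripts below. A minimal sketch for inspecting a split with pandas (run from the repository root; any extra columns in the CSVs are simply ignored here):

```python
import pandas as pd

# Load the training split; every row pairs a question with its category label.
train_df = pd.read_csv('nfcats/data/train.csv')

print(train_df[['question', 'category']].head())
print(train_df['category'].value_counts())  # class distribution over the taxonomy
```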
13 | 14 | ## Model 15 | The trained model can be downloaded from [the Hugging Face repository](https://huggingface.co/Lurunchik/nf-cats), and you can try it out interactively in the [Hugging Face Space](https://huggingface.co/spaces/Lurunchik/nf-cats). 16 | 17 | [![demo.png](demo.png)](https://huggingface.co/spaces/Lurunchik/nf-cats) 18 | 19 | ## Installation 20 | 21 | From source: 22 | 23 | cd NF-CATS 24 | pip install "poetry>=1.0.5" 25 | poetry install 26 | 27 | 28 | ## Usage 29 | Run test-set validation of the best fine-tuned model: 30 | 31 | python nfcats/predict.py 32 | 33 | Train the transformer model: 34 | 35 | python nfcats/train.py 36 | 37 | Run the TF-IDF experiments: 38 | 39 | python nfcats/tf_idf.py 40 | 41 | ## Citation 42 | 43 | If you use `NF-CATS` in your work, please cite [this paper](https://dl.acm.org/doi/10.1145/3477495.3531926): 44 | 45 | ``` 46 | @misc{bolotova2022nfcats, 47 | author = {Bolotova, Valeriia and Blinov, Vladislav and Scholer, Falk and Croft, W. Bruce and Sanderson, Mark}, 48 | title = {A Non-Factoid Question-Answering Taxonomy}, 49 | year = {2022}, 50 | isbn = {9781450387323}, 51 | publisher = {Association for Computing Machinery}, 52 | address = {New York, NY, USA}, 53 | url = {https://doi.org/10.1145/3477495.3531926}, 54 | doi = {10.1145/3477495.3531926}, 55 | booktitle = {Proceedings of the 45th International ACM SIGIR Conference on Research and Development in Information Retrieval}, 56 | pages = {1196–1207}, 57 | numpages = {12}, 58 | keywords = {question taxonomy, non-factoid question-answering, editorial study, dataset analysis}, 59 | location = {Madrid, Spain}, 60 | series = {SIGIR '22} 61 | } 62 | ``` 63 | 64 | 65 | -------------------------------------------------------------------------------- /demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lurunchik/NF-CATS/1b8c3b83337dfe0889245e817f1ba104e69939bd/demo.png -------------------------------------------------------------------------------- /nfcats/__init__.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | ROOT_PATH = pathlib.Path(__file__).parent.absolute() 4 | DATA_PATH = ROOT_PATH / 'data' 5 | MODEL_PATH = ROOT_PATH / 'pretrained_model' 6 | -------------------------------------------------------------------------------- /nfcats/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | import torch 4 | from allennlp.data import TextFieldTensors, Vocabulary 5 | from allennlp.models.model import Model 6 | from allennlp.modules import FeedForward 7 | from allennlp.training.metrics import FBetaMeasure 8 | from overrides import overrides 9 | 10 | from nfcats.embedder import SentenceEmbedder 11 | 12 | 13 | @Model.register('nfq_cats_classifier') 14 | class NFQCatsClassifier(Model): 15 | default_predictor = 'text_classifier' 16 | label_namespace = 'labels' 17 | namespace = 'tokens' 18 | 19 | def __init__( 20 | self, 21 | vocab: Vocabulary, 22 | embedder: Optional[SentenceEmbedder] = None, 23 | embeddings_dim: Optional[int] = None, 24 | dropout: Optional[float] = None, 25 | feedforward: Optional[FeedForward] = None, 26 | ) -> None: 27 | super().__init__(vocab) 28 | self._embedder = embedder 29 | self._feedforward = feedforward 30 | 31 | if feedforward is not None: 32 | self._classifier_input_dim = self._feedforward.get_output_dim() 33 | elif self._embedder is not None: 34 | self._classifier_input_dim =
self._embedder.get_output_dim() 35 | else: 36 | self._classifier_input_dim = embeddings_dim 37 | 38 | if self._embedder is None and embeddings_dim is None: 39 | raise ValueError('You must pass `Embedder` or `embeddings_dim`') 40 | 41 | self._dropout = torch.nn.Dropout(dropout) if dropout else None 42 | 43 | self._num_labels = vocab.get_vocab_size(namespace=self.label_namespace) 44 | self._classification_layer = torch.nn.Linear(self._classifier_input_dim, self._num_labels) 45 | # weighted F1 is updated in `forward` and reported via `get_metrics` (used as the validation metric) 46 | self._f1 = FBetaMeasure(beta=1.0, average='weighted') 47 | 48 | def get_metrics(self, reset: bool = False) -> Dict[str, float]: 49 | metrics = super().get_metrics(reset) 50 | metrics['weighted_f1'] = self._f1.get_metric(reset)['fscore'] 51 | return metrics 52 | 53 | @overrides 54 | def forward( 55 | self, 56 | tokens: Optional[TextFieldTensors] = None, 57 | texts: Optional[List[str]] = None, 58 | embeddings: Optional[torch.FloatTensor] = None, 59 | label: Optional[torch.IntTensor] = None, 60 | ) -> Dict[str, torch.Tensor]: 61 | 62 | if self._embedder is not None: 63 | device = self._get_prediction_device() 64 | embeddings = self._embedder.forward(tokens=tokens, texts=texts) 65 | 66 | if device != -1: 67 | embeddings = embeddings.to(device) 68 | 69 | if self._dropout: 70 | embeddings = self._dropout(embeddings) # pylint: disable=not-callable 71 | 72 | if self._feedforward is not None: 73 | embeddings = self._feedforward(embeddings) 74 | 75 | logits = self._classification_layer(embeddings) 76 | output = {'logits': logits} 77 | 78 | if label is not None: 79 | output['loss'] = torch.nn.functional.cross_entropy(logits, label.long().view(-1)) 80 | self._f1(logits, label.long().view(-1)) # accumulate predictions for the weighted F1 metric 81 | 82 | return output 83 | 84 | def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, Any]: 85 | logits = output_dict['logits'] 86 | idx2label = self.vocab.get_index_to_token_vocabulary(self.label_namespace) 87 | probabilities = torch.nn.functional.softmax(logits, dim=-1) 88 | _, top_indices = probabilities.topk(1, dim=1) 89 | 90 | output = {'label': [idx2label.get(idx, str(idx)) for idxs in top_indices.cpu().tolist() for idx in idxs]} 91 | 92 | return output 93 | -------------------------------------------------------------------------------- /nfcats/dataset_reader.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import logging 3 | from typing import Dict 4 | 5 | from allennlp.common.file_utils import cached_path 6 | from allennlp.data.dataset_readers import TextClassificationJsonReader 7 | from allennlp.data.dataset_readers.dataset_reader import DatasetReader 8 | from allennlp.data.token_indexers import PretrainedTransformerIndexer, TokenIndexer 9 | from allennlp.data.tokenizers import PretrainedTransformerTokenizer, Tokenizer 10 | from overrides import overrides 11 | 12 | LOGGER = logging.getLogger(__name__) 13 | 14 | 15 | @DatasetReader.register('csv_text_label') 16 | class TextClassificationCsvReader(TextClassificationJsonReader): 17 | _default_transformer_model = 'bert-base-uncased' 18 | 19 | def __init__( 20 | self, 21 | lower: bool = True, 22 | sep: str = ',', 23 | label_field: str = 'label', 24 | text_field: str = 'text', 25 | filter_rows_by_values: Dict = None, 26 | token_indexers: Dict[str, TokenIndexer] = None, 27 | tokenizer: Tokenizer = None, 28 | **kwargs, 29 | ) -> None: 30 | super().__init__( 31 | tokenizer=tokenizer or PretrainedTransformerTokenizer(self._default_transformer_model), 32 | token_indexers=token_indexers or {'tokens': PretrainedTransformerIndexer(self._default_transformer_model)}, 33 | **kwargs, 34 | )
35 | self._filters = filter_rows_by_values or {} 36 | self._sep = sep 37 | self._lower = lower 38 | self._label, self._text = label_field, text_field 39 | 40 | @overrides 41 | def _read(self, file_path: str): 42 | file_path = cached_path(file_path) # if `file_path` is a URL, redirect to the cache 43 | 44 | LOGGER.info('Reading file at %s', file_path) 45 | 46 | with open(file_path) as dataset_file: 47 | data = csv.DictReader(dataset_file, delimiter=self._sep) 48 | for row in data: 49 | for field_name, field_value in self._filters.items(): 50 | if row[field_name] != field_value: 51 | break 52 | else: 53 | text = row[self._text].lower() if self._lower else row[self._text] 54 | instance = self.text_to_instance(text=text, label=row[self._label]) 55 | if instance is not None: 56 | yield instance 57 | -------------------------------------------------------------------------------- /nfcats/embedder.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=W0223 2 | from abc import ABC, abstractmethod 3 | from typing import List, Optional 4 | 5 | import torch 6 | from allennlp.common import Registrable 7 | from allennlp.data import TextFieldTensors 8 | from allennlp.modules import Seq2SeqEncoder, Seq2VecEncoder, TextFieldEmbedder 9 | from allennlp.nn.util import get_text_field_mask 10 | 11 | 12 | class SentenceEmbedder(Registrable, ABC): 13 | @abstractmethod 14 | def get_output_dim(self) -> int: 15 | pass 16 | 17 | @abstractmethod 18 | def forward(self, tokens: Optional[TextFieldTensors] = None, texts: Optional[List[str]] = None) -> torch.Tensor: 19 | pass 20 | 21 | 22 | @SentenceEmbedder.register('trainable') 23 | class TrainableEmbedder(torch.nn.Module, SentenceEmbedder): 24 | def __init__( 25 | self, 26 | text_field_embedder: TextFieldEmbedder, 27 | seq2vec_encoder: Seq2VecEncoder, 28 | seq2seq_encoder: Optional[Seq2SeqEncoder] = None, 29 | ) -> None: 30 | super().__init__() 31 | self._text_field_embedder = text_field_embedder 32 | self._seq2seq_encoder = seq2seq_encoder 33 | self._seq2vec_encoder = seq2vec_encoder 34 | 35 | def get_output_dim(self) -> int: 36 | return self._seq2vec_encoder.get_output_dim() 37 | 38 | def forward(self, tokens: Optional[TextFieldTensors] = None, texts: Optional[List[str]] = None) -> torch.Tensor: 39 | embedded_text = self._text_field_embedder(tokens) 40 | mask = get_text_field_mask(tokens) 41 | 42 | if self._seq2seq_encoder is not None: 43 | embedded_text = self._seq2seq_encoder(embedded_text, mask=mask) 44 | 45 | embedded_text = self._seq2vec_encoder(embedded_text, mask=mask) 46 | 47 | return embedded_text 48 | -------------------------------------------------------------------------------- /nfcats/nft_classifier.jsonnet: -------------------------------------------------------------------------------- 1 | local transformer_model = "deepset/roberta-base-squad2"; 2 | local transformer_dim = 768; 3 | 4 | { 5 | "dataset_reader": { 6 | "type": "csv_text_label", 7 | "label_field": "category", 8 | "text_field": "question", 9 | "tokenizer": { 10 | "type": "pretrained_transformer", 11 | "model_name": transformer_model, 12 | "add_special_tokens": true 13 | }, 14 | "token_indexers": { 15 | "tokens": { 16 | "type": "pretrained_transformer", 17 | "model_name": transformer_model, 18 | "max_length": 512 19 | } 20 | } 21 | }, 22 | "train_data_path": std.extVar("TRAIN_DATA_PATH"), 23 | "validation_data_path": std.extVar("VALID_DATA_PATH"), 24 | "model": { 25 | "type": "nfq_cats_classifier", 26 | "embedder": { 27 | "type": 
"trainable", 28 | "text_field_embedder": { 29 | "token_embedders": { 30 | "tokens": { 31 | "type": "pretrained_transformer", 32 | "model_name": transformer_model, 33 | "max_length": 512, 34 | "train_parameters": true 35 | } 36 | } 37 | }, 38 | "seq2vec_encoder": { 39 | "type": "bert_pooler", 40 | "pretrained_model": transformer_model, 41 | "dropout": 0.0, 42 | } 43 | }, 44 | "feedforward": { 45 | "input_dim": transformer_dim, 46 | "num_layers": 2, 47 | "hidden_dims": [768, 512], 48 | "activations": "mish", 49 | "dropout": 0.6 50 | }, 51 | "dropout": 0.3, 52 | }, 53 | "data_loader": { 54 | "batch_sampler": { 55 | "type": "balanced", 56 | "num_classes_per_batch": 8, 57 | "num_examples_per_class": 8 58 | }, 59 | }, 60 | "validation_data_loader": { 61 | "batch_size": 128, 62 | }, 63 | "trainer": { 64 | "num_epochs": 100, 65 | "patience": 5, 66 | "validation_metric": "+weighted_f1", 67 | "cuda_device": 0, 68 | "learning_rate_scheduler": { 69 | "type": "slanted_triangular", 70 | "cut_frac": 0.06 71 | }, 72 | "optimizer": { 73 | "type": "huggingface_adamw", 74 | "lr": 2e-5, 75 | "weight_decay": 0.1, 76 | }, 77 | "callbacks": ["wandb"], 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /nfcats/predict.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=F401 2 | import pandas as pd 3 | import tqdm 4 | from allennlp.models.archival import load_archive 5 | from allennlp.predictors.predictor import Predictor 6 | 7 | from nfcats import DATA_PATH, MODEL_PATH 8 | from nfcats.classifier import NFQCatsClassifier 9 | from nfcats.dataset_reader import TextClassificationCsvReader 10 | from nfcats.sampler import BalancedBatchSampler 11 | from nfcats.utils import pandas_classification_report 12 | from nfcats.wandb_callback import WnBCallback 13 | 14 | 15 | def main(): 16 | archive = load_archive(MODEL_PATH, cuda_device=-1) 17 | predictor = Predictor.from_archive( 18 | archive, predictor_name='text_classifier', dataset_reader_to_load='csv_text_label' 19 | ) 20 | 21 | test_df = pd.read_csv(f'{DATA_PATH}/test.csv', sep=',') 22 | 23 | y_pred = [predictor.predict(q) for q in tqdm.tqdm(test_df.question)] 24 | print(pandas_classification_report(y_true=test_df.category, y_pred=[p['label'] for p in y_pred])) 25 | 26 | print(predictor.predict('why do we need a taxonomy?')) 27 | 28 | 29 | if __name__ == "__main__": 30 | main() 31 | -------------------------------------------------------------------------------- /nfcats/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import warnings 4 | from typing import Iterable, Iterator, List, Sequence, TypeVar 5 | 6 | import numpy as np 7 | from allennlp.data import Instance 8 | from allennlp.data.samplers import BatchSampler 9 | 10 | Example = TypeVar('Example') 11 | 12 | 13 | def iterate_random_batches(data: Iterable[Example], batch_size: int) -> Iterator[List[Example]]: 14 | """ 15 | Uniformly sample random batches of the same size from the data indefinitely (without replacement) 16 | 17 | Args: 18 | data: Iterable with data examples 19 | batch_size: Batch size to use for all batches 20 | 21 | Returns: 22 | Iterator over batches 23 | """ 24 | population = list(data) 25 | 26 | if len(population) < batch_size: 27 | raise ValueError(f'Population size {len(population)} must be greater than batch size {batch_size}') 28 | 29 | seen: List[Example] = [] 30 | while True: 31 | random.shuffle(population) 32 
| 33 | num_full, num_trailing = divmod(len(population), batch_size) 34 | 35 | for start in range(0, num_full * batch_size, batch_size): 36 | batch = population[start : start + batch_size] 37 | seen.extend(batch) 38 | yield batch 39 | 40 | if num_trailing > 0: 41 | trailing = population[-num_trailing:] 42 | random.shuffle(seen) 43 | num_missing = batch_size - num_trailing 44 | seen, population = seen[:num_missing], seen[num_missing:] + trailing 45 | yield trailing + seen 46 | else: 47 | population = seen 48 | seen = [] 49 | 50 | 51 | @BatchSampler.register('balanced') 52 | class BalancedBatchSampler(BatchSampler): 53 | def __init__(self, num_classes_per_batch: int = 8, num_examples_per_class: int = 32) -> None: 54 | 55 | self._num_classes_per_batch = num_classes_per_batch 56 | self._num_examples_per_class = num_examples_per_class 57 | 58 | def get_batch_indices(self, instances: Sequence[Instance]) -> Iterable[List[int]]: 59 | labels = np.array([instance.fields['label'].label for instance in instances]) 60 | unique_labels, counts = np.unique(labels, return_counts=True) 61 | unique_labels = list(unique_labels) 62 | 63 | min_number_of_examples = counts.min() 64 | if self._num_examples_per_class > min_number_of_examples: 65 | warnings.warn( 66 | f'Setting `num_examples_per_class` to {min_number_of_examples}, ' 67 | f'since there are classes with less examples than {self._num_examples_per_class}' 68 | ) 69 | self._num_examples_per_class = min_number_of_examples 70 | 71 | num_unique_classes = len(unique_labels) 72 | if self._num_classes_per_batch > num_unique_classes: 73 | warnings.warn( 74 | f'Setting `num_classes_per_batch` to {num_unique_classes}, ' 75 | f'since there are only {num_unique_classes} classes (not {self._num_classes_per_batch})' 76 | ) 77 | self._num_classes_per_batch = num_unique_classes 78 | 79 | class_examples_generators = { 80 | label: iterate_random_batches(np.flatnonzero(labels == label), self._num_examples_per_class) 81 | for label in unique_labels 82 | } 83 | 84 | batch_classes_generator = iterate_random_batches(unique_labels, self._num_classes_per_batch) 85 | 86 | for _ in range(self.get_num_batches(instances)): 87 | batch = [] 88 | chosen_labels = next(batch_classes_generator) # noqa # pylint: disable=stop-iteration-return 89 | for label in chosen_labels: 90 | batch.extend(next(class_examples_generators[label])) # noqa # pylint: disable=stop-iteration-return 91 | yield batch 92 | 93 | def get_num_batches(self, instances: Sequence[Instance]) -> int: 94 | return math.ceil(len(instances) / (self._num_classes_per_batch * self._num_examples_per_class)) 95 | -------------------------------------------------------------------------------- /nfcats/tf_idf.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import csv 3 | import json 4 | import logging 5 | import pathlib 6 | import pickle 7 | import shutil 8 | import threading 9 | import time 10 | from typing import List, Tuple 11 | 12 | import dill 13 | import optuna 14 | from sklearn.feature_extraction.text import TfidfVectorizer 15 | from sklearn.linear_model import LogisticRegression 16 | from sklearn.metrics import f1_score 17 | from sklearn.pipeline import FeatureUnion, Pipeline 18 | 19 | from nfcats import DATA_PATH, ROOT_PATH 20 | 21 | LOGGER = logging.getLogger('tf_idf') 22 | 23 | 24 | class TrialSaver: 25 | def __init__(self, folder: pathlib.Path, num_best: int = 1): 26 | if num_best < 1: 27 | raise ValueError('num_best must be greater than 0') 28 | 29 | self._folder 
= folder 30 | self._num_best = num_best 31 | 32 | self._lock = threading.Lock() 33 | self._results: List[float] = [] 34 | self._folders: List[pathlib.Path] = [] 35 | 36 | def save(self, trial: optuna.Trial, pipeline: Pipeline, result: float): 37 | with self._lock: 38 | result = -result # negation allows to keep the smallest result at the end of the sorted result list 39 | if len(self._results) >= self._num_best and self._results[-1] <= result: 40 | return 41 | 42 | trial_folder = self._folder / f'trial_{trial.number}' 43 | trial_folder.mkdir(parents=True, exist_ok=True) 44 | 45 | with (trial_folder / f'{trial.number}.model').open('wb') as f: 46 | dill.dump(pipeline, f, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | with (trial_folder / 'params.json').open('w', encoding='utf-8') as f: 49 | json.dump(trial.params, f, ensure_ascii=False, indent=4) 50 | 51 | index = bisect.bisect_left(self._results, result) 52 | self._results.insert(index, result) 53 | self._folders.insert(index, trial_folder) 54 | 55 | while len(self._results) > self._num_best: 56 | self._results.pop() 57 | shutil.rmtree(self._folders.pop()) 58 | 59 | 60 | def get_word_vectorizer(trial: optuna.Trial) -> TfidfVectorizer: 61 | max_ngram = trial.suggest_int('features.word.max_ngram', 1, 2) 62 | min_df = trial.suggest_int('features.word.min_df', 2, 50) 63 | 64 | return TfidfVectorizer( 65 | analyzer='word', 66 | min_df=min_df, 67 | ngram_range=(1, max_ngram), 68 | sublinear_tf=trial.suggest_categorical('features.word.sublinear_tf', [False, True]), 69 | use_idf=trial.suggest_categorical('features.word.use_idf', [False, True]), 70 | smooth_idf=trial.suggest_categorical('features.word.smooth_idf', [False, True]), 71 | binary=trial.suggest_categorical('features.word.binary', [False, True]), 72 | ) 73 | 74 | 75 | def get_char_vectorizer(trial: optuna.Trial) -> TfidfVectorizer: 76 | min_ngram = trial.suggest_int('features.char.min_ngram', 1, 5) 77 | max_ngram = trial.suggest_int('features.char.max_ngram', min_ngram, 5) 78 | min_df = trial.suggest_int('features.char.min_df', 2, 50) 79 | 80 | if trial.suggest_categorical('features.char.word_boundaries', [False, True]): 81 | 82 | def char_analyzer(tokens): 83 | return ' '.join(f'<{token}>' for token in tokens) 84 | 85 | else: 86 | 87 | def char_analyzer(tokens): 88 | return ' '.join(tokens) 89 | 90 | return TfidfVectorizer( 91 | analyzer=char_analyzer, 92 | min_df=min_df, 93 | ngram_range=(min_ngram, max_ngram), 94 | sublinear_tf=trial.suggest_categorical('features.char.sublinear_tf', [False, True]), 95 | use_idf=trial.suggest_categorical('features.char.use_idf', [False, True]), 96 | smooth_idf=trial.suggest_categorical('features.char.smooth_idf', [False, True]), 97 | binary=trial.suggest_categorical('features.char.binary', [False, True]), 98 | ) 99 | 100 | 101 | def parametrize(trial: optuna.Trial): 102 | feature_type = trial.suggest_categorical('features.type', ['word', 'char', 'all']) 103 | 104 | featurizers = [] 105 | if feature_type in ['word', 'all']: 106 | featurizers.append(('word', get_word_vectorizer(trial))) 107 | if feature_type in ['char', 'all']: 108 | featurizers.append(('char', get_char_vectorizer(trial))) 109 | 110 | features = FeatureUnion(featurizers) 111 | 112 | classifier = LogisticRegression( 113 | solver='lbfgs', 114 | multi_class='multinomial', 115 | max_iter=1000000, 116 | random_state=42, 117 | class_weight=trial.suggest_categorical('class_weight', [None, 'balanced']), 118 | C=trial.suggest_loguniform('C', low=1e-3, high=1e3), 119 | ) 120 | 121 | return 
Pipeline([('features', features), ('classifier', classifier)]) 122 | 123 | 124 | def load_question_dataset(path: pathlib.Path) -> Tuple[List[str], List[str]]: 125 | with path.open(encoding='utf-8') as f: 126 | data = list(csv.DictReader(f)) 127 | 128 | texts = [item['question'] for item in data] 129 | labels = [item['category'] for item in data] 130 | return texts, labels 131 | 132 | 133 | def main(): 134 | logging.basicConfig(format='%(asctime)s [%(name)s] %(levelname)s - %(message)s', level=logging.INFO) 135 | 136 | train_texts, train_labels = load_question_dataset(DATA_PATH / 'train.csv') 137 | test_texts, test_labels = load_question_dataset(DATA_PATH / 'test.csv') 138 | val_texts, val_labels = load_question_dataset(DATA_PATH / 'valid.csv') 139 | 140 | trial_saver = TrialSaver(num_best=5, folder=ROOT_PATH / 'tf_idf_models') 141 | 142 | def run_trial(trial: optuna.Trial) -> float: 143 | pipeline = parametrize(trial) 144 | 145 | start = time.perf_counter() 146 | pipeline.fit(train_texts, train_labels) 147 | training_time = time.perf_counter() - start 148 | LOGGER.info('Training time: %.3fs', training_time) 149 | 150 | test_score = f1_score(y_true=test_labels, y_pred=pipeline.predict(test_texts), average='macro') 151 | val_score = f1_score(y_true=val_labels, y_pred=pipeline.predict(val_texts), average='macro') 152 | 153 | trial.user_attrs['test'] = test_score 154 | trial.user_attrs['val'] = val_score 155 | 156 | LOGGER.info('Trial %i test score: %s', trial.number, test_score) 157 | LOGGER.info('Trial %i val score: %s', trial.number, val_score) 158 | 159 | trial_saver.save(trial=trial, pipeline=pipeline, result=val_score) 160 | return val_score 161 | 162 | study = optuna.create_study( 163 | storage=f'sqlite:///{ROOT_PATH}/hypertuning.db', study_name='tf-idf', direction='maximize', load_if_exists=True, 164 | ) 165 | study.optimize(run_trial, n_trials=1000) 166 | 167 | 168 | if __name__ == "__main__": 169 | main() 170 | -------------------------------------------------------------------------------- /nfcats/train.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=F401 2 | import json 3 | import logging 4 | from datetime import datetime 5 | from pathlib import Path 6 | 7 | import wandb 8 | from allennlp.commands.train import train_model 9 | from allennlp.common import Params 10 | from allennlp.models import load_archive 11 | 12 | from nfcats import DATA_PATH, ROOT_PATH 13 | from nfcats.classifier import NFQCatsClassifier 14 | from nfcats.dataset_reader import TextClassificationCsvReader 15 | from nfcats.sampler import BalancedBatchSampler 16 | from nfcats.wandb_callback import WnBCallback 17 | 18 | WB_LOGIN = 'your_login' 19 | 20 | 21 | def main(): 22 | logging.basicConfig(format='%(asctime)s [%(name)s] %(levelname)s - %(message)s', level=logging.DEBUG) 23 | 24 | model_params = ROOT_PATH / 'nft_classifier.jsonnet' 25 | 26 | wandb.init(entity=WB_LOGIN, project='non-factoid-classification', reinit=True, name='roberta-tuned-on-squad') 27 | params = Params.from_file( 28 | model_params, 29 | ext_vars={ 30 | 'TRAIN_DATA_PATH': f'{DATA_PATH}/train.csv', 31 | 'VALID_DATA_PATH': f'{DATA_PATH}/valid.csv', 32 | 'EPOCHS': '10', 33 | 'DEVICE': '0', 34 | }, 35 | ) 36 | 37 | date = datetime.utcnow().strftime('%H%M%S-%d%m') 38 | serialization_dir = Path(f'./logs/{date}_{model_params}') 39 | 40 | train_model(params=params, serialization_dir=str(serialization_dir)) 41 | 42 | wandb.config.update( # noqa: E1101 pylint: disable=no-member 43 | 
{'serialization_dir': str(serialization_dir), **params.as_flat_dict()} 44 | ) 45 | 46 | with (serialization_dir / 'metrics.json').open() as f: 47 | metrics = json.load(f) 48 | wandb.run.summary.update(metrics) 49 | 50 | wandb.save(str(serialization_dir / 'model.tar.gz')) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /nfcats/utils.py: -------------------------------------------------------------------------------- 1 | import enum 2 | 3 | import pandas as pd 4 | from sklearn.metrics import accuracy_score, precision_recall_fscore_support 5 | 6 | PLACEHOLDER = ' ' 7 | 8 | 9 | class MetricAggregation(str, enum.Enum): 10 | WEIGHTED = 'weighted' 11 | MACRO = 'macro' 12 | MICRO = 'micro' 13 | 14 | 15 | class Metric(str, enum.Enum): 16 | PRECISION = 'precision' 17 | RECALL = 'recall' 18 | F1 = 'f1-score' 19 | F0_5 = 'f0.5-score' 20 | ACCURACY = 'accuracy' 21 | 22 | 23 | def pandas_classification_report(y_true, y_pred) -> pd.DataFrame: 24 | """ 25 | Create a report with classification metrics (precision, recall, f-score, and accuracy) 26 | """ 27 | precision, recall, f1, support = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred) 28 | 29 | avg_precision, avg_recall, avg_f1, _ = precision_recall_fscore_support( 30 | y_true=y_true, y_pred=y_pred, average='weighted' 31 | ) 32 | 33 | macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support( 34 | y_true=y_true, y_pred=y_pred, average='macro' 35 | ) 36 | 37 | micro_precision, micro_recall, micro_f1, _ = precision_recall_fscore_support( 38 | y_true=y_true, y_pred=y_pred, average='micro' 39 | ) 40 | 41 | _, _, f05, _ = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, beta=0.5) 42 | _, _, avg_f05, _ = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, beta=0.5, average='weighted') 43 | _, _, macro_f05, _ = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, beta=0.5, average='macro') 44 | _, _, micro_f05, _ = precision_recall_fscore_support(y_true=y_true, y_pred=y_pred, beta=0.5, average='micro') 45 | 46 | metrics_sum_index = ['precision', 'recall', 'f1-score', 'f0.5-score', 'support'] 47 | labels = sorted(set(y_true).union(y_pred)) 48 | 49 | class_report_df = pd.DataFrame([precision, recall, f1, f05, support], columns=labels, index=metrics_sum_index) 50 | 51 | support = class_report_df.loc['support'] 52 | total = support.sum() 53 | class_report_df[PLACEHOLDER] = [PLACEHOLDER] * 5 54 | class_report_df[MetricAggregation.MACRO.value] = macro_precision, macro_recall, macro_f1, macro_f05, total 55 | class_report_df[MetricAggregation.WEIGHTED.value] = avg_precision, avg_recall, avg_f1, avg_f05, total 56 | class_report_df[MetricAggregation.MICRO.value] = micro_precision, micro_recall, micro_f1, micro_f05, total 57 | class_report_df[Metric.ACCURACY.value] = ( 58 | accuracy_score(y_true, y_pred), 59 | PLACEHOLDER, 60 | PLACEHOLDER, 61 | PLACEHOLDER, 62 | PLACEHOLDER, 63 | ) 64 | 65 | return class_report_df.T 66 | -------------------------------------------------------------------------------- /nfcats/wandb_callback.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional 2 | 3 | import wandb 4 | from allennlp.training import GradientDescentTrainer, TrackEpochCallback, TrainerCallback 5 | 6 | DEFAULT_METRICS = [ 7 | 'training_loss', 8 | 'training_reg_loss', 9 | 'training_f1', 10 | 'training_accuracy', 11 | 'training_weighted_f05', 12 
| 'training_pos_weighted_f05', 13 | 'validation_loss', 14 | 'validation_f1', 15 | 'validation_accuracy', 16 | 'validation_weighted_f05', 17 | 'validation_pos_weighted_f05', 18 | ] 19 | 20 | 21 | @TrainerCallback.register('wandb') 22 | class WnBCallback(TrackEpochCallback): 23 | def __init__(self, metrics_to_include: Optional[List[str]] = None, **kwargs): 24 | super().__init__(**kwargs) 25 | self._metrics_to_include = metrics_to_include or DEFAULT_METRICS 26 | 27 | def on_epoch( 28 | self, trainer: 'GradientDescentTrainer', metrics: Dict[str, Any], epoch: int, is_primary: bool = True, **kwargs, 29 | ) -> None: 30 | if is_primary: 31 | wandb.log({key: val for key, val in metrics.items() if key in self._metrics_to_include}) 32 | super().on_epoch(trainer, metrics, epoch, is_primary, **kwargs) 33 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "nfcats" 3 | version = "1.0.0" 4 | description = "Non-Factoid Question Classification" 5 | authors = [] 6 | repository = "https://github.com/Lurunchik/NF-CATS" 7 | readme = "README.md" 8 | license = "MIT" 9 | classifiers = [ 10 | "Programming Language :: Python :: 3.7", 11 | "Programming Language :: Python :: 3.8", 12 | "Programming Language :: Python :: 3.9", 13 | "Programming Language :: Python :: Implementation :: CPython", 14 | ] 15 | 16 | [tool.poetry.dependencies] 17 | python = "^3.7.0" 18 | torch = "*" 19 | scikit-learn = "*" 20 | spacy = "2.3.5" 21 | dill = "*" 22 | optuna = "*" 23 | allennlp = "2.0.1" 24 | allennlp-models = "2.0.1" 25 | wandb = "^0.10.18" 26 | pandas = "*" 27 | overrides = "^3.1.0" 28 | 29 | [tool.poetry.dev-dependencies] 30 | mypy = "^0.740" 31 | flake8 = "^3.7.9" 32 | flake8-isort = "^2.7.0" 33 | flake8-builtins = "^1.4.1" 34 | flake8-comprehensions = "^3.0.1" 35 | flake8-debugger = "^3.2.1" 36 | flake8-eradicate = "^0.2.3" 37 | pep8-naming = "^0.9.0" 38 | black = "^19.10b0" 39 | isort = "^4.3.21" 40 | unify = "^0.5" 41 | pylint = "^2.4.3" 42 | 43 | [tool.black] 44 | line-length = 88 45 | target-version = ["py37"] 46 | 47 | [build-system] 48 | requires = ["poetry>=1.0.5"] 49 | build-backend = "poetry.masonry.api" 50 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-complexity = 20 3 | show-source = true 4 | exclude = .git __pycache__ setup.py *.txt 5 | enable-extensions = G 6 | 7 | ignore = 8 | E203 ; Whitespace 9 | W293 ; Blank line contains whitespace 10 | 11 | [isort] 12 | multi_line_output = 3 13 | include_trailing_comma = True 14 | force_grid_wrap = 0 15 | use_parentheses = True 16 | balanced_wrapping = true 17 | default_section = THIRDPARTY 18 | known_first_party = nfcats 19 | 20 | [pylint] 21 | max-module-lines = 500 22 | min-public-methods = 1 23 | output-format = colorized 24 | 25 | disable= 26 | C0330, ; Wrong hanging indentation before block (add 4 spaces) 27 | C0114, ; Missing module docstring 28 | E0401, ; Import error 29 | 30 | [mypy] 31 | python_version = 3.7 32 | ignore_missing_imports = True 33 | warn_unused_configs = True --------------------------------------------------------------------------------
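For quick experimentation outside the AllenNLP stack used by `nfcats/predict.py`, the checkpoint referenced in the README can also be queried directly from the Hugging Face hub. The sketch below is an assumption-based example: it presumes the `Lurunchik/nf-cats` hub repository exposes a standard `transformers` sequence-classification checkpoint (as served by the demo Space) and that the `transformers` package, pulled in transitively by `allennlp`, is available.

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Assumption: the hub checkpoint is a standard sequence-classification model;
# this is an inference sketch, not the training code of this repository.
tokenizer = AutoTokenizer.from_pretrained('Lurunchik/nf-cats')
model = AutoModelForSequenceClassification.from_pretrained('Lurunchik/nf-cats')
model.eval()

question = 'why do we need a taxonomy?'
inputs = tokenizer(question, return_tensors='pt')
with torch.no_grad():
    logits = model(**inputs).logits

label_id = int(logits.argmax(dim=-1))
print(model.config.id2label[label_id])  # predicted non-factoid category
```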