├── .gitignore ├── README.md ├── __init__.py ├── config ├── ecthr_a │ ├── bert_default.yml │ ├── bert_reproduced.yml │ ├── bert_tune.yml │ ├── bert_tuned.yml │ └── l2svm.yml ├── ecthr_b │ ├── bert_default.yml │ ├── bert_reproduced.yml │ ├── bert_tune.yml │ ├── bert_tuned.yml │ └── l2svm.yml ├── eurlex │ ├── bert_default.yml │ ├── bert_reproduced.yml │ ├── bert_tune.yml │ ├── bert_tuned.yml │ └── l2svm.yml ├── ledgar │ ├── bert_default.yml │ ├── bert_reproduced.yml │ ├── bert_tune.yml │ ├── bert_tuned.yml │ └── l2svm.yml ├── scotus │ ├── bert_default.yml │ ├── bert_reproduced.yml │ ├── bert_tune.yml │ ├── bert_tuned.yml │ └── l2svm.yml └── unfair_tos │ ├── bert_default.yml │ ├── bert_reproduced.yml │ ├── bert_tune.yml │ ├── bert_tuned.yml │ └── l2svm.yml ├── generate_data.py ├── generate_data.sh ├── libmultilabel ├── __init__.py ├── common_utils.py ├── linear │ ├── __init__.py │ ├── linear.py │ ├── metrics.py │ ├── preprocessor.py │ └── utils.py └── nn │ ├── __init__.py │ ├── data_utils.py │ ├── metrics.py │ ├── model.py │ ├── networks │ ├── __init__.py │ ├── bert.py │ ├── bert_attention.py │ ├── caml.py │ ├── hierbert.py │ ├── kim_cnn.py │ ├── labelwise_attention_networks.py │ ├── modules.py │ └── xml_cnn.py │ └── nn_utils.py ├── linear_trainer.py ├── main.py ├── requirements.txt ├── requirements_parameter_search.txt ├── run_experiments.sh ├── search_params.py ├── search_params.sh └── torch_trainer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | runs/ 23 | wheels/ 24 | prof/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 
33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 
94 | #Pipfile.lock 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv*/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | 129 | # jupyter 130 | *.ipynb 131 | 132 | *.swp 133 | .editorconfig 134 | .vscode/ 135 | 136 | dataset/* 137 | data/* 138 | output/* 139 | model_output/* 140 | *.xlsx 141 | tmp* 142 | *.csv 143 | *report*.txt 144 | *.pkl 145 | *.txt 146 | *.json 147 | *.bin 148 | *.lock 149 | *.pt 150 | *.zip 151 | *.tar.gz 152 | *.ckpt 153 | *.tfevents.* 154 | 155 | # Sphinx build 156 | docs/_build/* 157 | docs/cli/*.include 158 | !docs/requirements.txt 159 | !requirements.txt 160 | !requirements_parameter_search.txt 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Linear Classifier: An Often-Forgotten Baseline for Text Classification 2 | 3 | This is the code for the ACL 2023 paper "[Linear Classifier: An Often-Forgotten Baseline for Text Classification](https://www.csie.ntu.edu.tw/~cjlin/papers/text_classification_baseline/text_classification_baseline.pdf)". The repository is used to reproduce the experimental results in our paper. 
If you find our work useful, please consider citing the following paper: 4 | ```bib 5 | @InProceedings{YCL22a, 6 | author = {Yu-Chen Lin and Si-An Chen and Jie-Jyun Liu and Chih-Jen Lin}, 7 | title = {Linear Classifier: An Often-Forgotten Baseline for Text Classification}, 8 | booktitle = {Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (ACL)}, 9 | year = {2023}, 10 | url = {https://www.csie.ntu.edu.tw/~cjlin/papers/text_classification_baseline/text_classification_baseline.pdf}, 11 | note = {Short paper} 12 | } 13 | ``` 14 | Please feel free to contact [Yu-Chen Lin](mailto:b06504025@csie.ntu.edu.tw) if you have any questions about the code/paper. 15 | 16 | ## Setup Environment 17 | 18 | It is optional but highly recommended to create a virtual environment. For example, you can first follow the Miniconda installation guide at this [link](https://docs.conda.io/en/latest/miniconda.html) and then create a virtual environment as follows. 19 | ```bash 20 | conda create -n baseline python=3.8 21 | conda activate baseline 22 | ``` 23 | 24 | We recommend using **Python 3.8** because that is the version we used in our experiments. Please also install the following packages. 25 | ```bash 26 | pip install -r requirements.txt 27 | pip install -r requirements_parameter_search.txt 28 | ``` 29 | 30 | If you have a different version of CUDA, follow the installation instructions for PyTorch LTS on their [website](https://pytorch.org/get-started/previous-versions/). 31 | 32 | ## Generate Data 33 | 34 | You can simply run the following script to generate all the needed data. 35 | ```bash 36 | bash generate_data.sh 37 | ``` 38 | There are three formats of data. For all the formats, we use the same sets as the ones used in [LexGLUE](https://github.com/coastalcph/lex-glue). We convert the sets to the [LibMultiLabel](https://github.com/ASUS-AICS/LibMultiLabel) format and then modify them so that they can be run with LibMultiLabel.
We briefly explain what we do as follows. 39 | * linear: We combine training and validation subsets as the new training set, so only training and test sets are available. 40 | * nn: Training, validation, and test sets are available. 41 | * hier (hierarchical): We need to employ the hierarchical setting of BERT to reproduce the results in Chalkidis et al. (2022). Specifically, we add a special symbol to the data and modify the code in LibMultiLabel for conducting the experiments of the hierarchical BERT. Training, validation, and test sets are available. 42 | 43 | ## Experimental Results 44 | 45 | In Table 2, we present our investigation on two types of approaches: Linear SVM and BERT. 46 | 47 | | | ECtHR (A) | ECtHR (B) | SCOTUS | EUR-LEX | LEDGAR | UNFAIR-ToS | 48 | |:------------------------|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|:----------:| 49 | | **Method** | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | 50 | | **Linear** | 51 | | one-vs-rest | 64.0 / 53.1 | 72.8 / 63.9 | 78.1 / 68.9 | 72.0 / 55.4 | 86.4 / 80.0 | 94.9 / 75.1 | 52 | | thresholding | 68.6 / 64.9 | 76.1 / 68.7 | 78.9 / 71.5 | 74.7 / 62.7 | 86.2 / 79.9 | 95.1 / 79.9 | 53 | | cost-sensitive | 67.4 / 60.5 | 75.5 / 67.3 | 78.3 / 71.5 | 73.4 / 60.5 | 86.2 / 80.1 | 95.3 / 77.9 | 54 | | Chalkidis et al. (2022) | 64.5 / 51.7 | 74.6 / 65.1 | 78.2 / 69.5 | 71.3 / 51.4 | 87.2 / 82.4 | 95.4 / 78.8 | 55 | | **BERT** | 56 | | Ours | 61.9 / 55.6 | 69.8 / 60.5 | 67.1 / 55.9 | 70.8 / 55.3 | 87.0 / 80.7 | 95.4 / 80.3 | 57 | | Chalkidis et al. (2022) | 71.2 / 63.6 | 79.7 / 73.4 | 68.3 / 58.3 | 71.4 / 57.2 | 87.6 / 81.8 | 95.6 / 81.3 | 58 | 59 | In Table 9, we present additional results from BERT. Note that the **Ours** setting (BERT) in Table 2 is the same as the **tuned** setting (BERT in LibMultiLabel) in Table 9. 
60 | 61 | | | ECtHR (A) | ECtHR (B) | SCOTUS | EUR-LEX | LEDGAR | UNFAIR-ToS | 62 | |:------------|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|:----------:| 63 | | **Method** | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | 64 | | **BERT in LibMultiLabel** | 65 | | default | 60.5 / 53.4 | 68.9 / 60.8 | 66.3 / 54.8 | 70.8 / 55.3 | 85.2 / 77.9 | 95.2 / 78.2 | 66 | | tuned | 61.9 / 55.6 | 69.8 / 60.5 | 67.1 / 55.9 | 70.8 / 55.3 | 87.0 / 80.7 | 95.4 / 80.3 | 67 | | reproduced | 70.2 / 63.7 | 78.8 / 73.1 | 70.8 / 62.6 | 71.6 / 56.1 | 88.1 / 82.6 | 95.3 / 80.6 | 68 | | **BERT in Chalkidis et al. (2022)** | 69 | | paper | 71.2 / 63.6 | 79.7 / 73.4 | 68.3 / 58.3 | 71.4 / 57.2 | 87.6 / 81.8 | 95.6 / 81.3 | 70 | | reproduced | 70.8 / 64.8 | 78.7 / 72.5 | 70.9 / 61.9 | 71.7 / 57.9 | 87.7 / 82.1 | 95.6 / 80.3 | 71 | 72 | ## Run Experiments 73 | 74 | We show how to conduct the experiments of each method as follows. 75 | ```bash 76 | bash run_experiments.sh [DATA] [METHOD] 77 | ``` 78 | 79 | First, you need to determine which data and method you want to try. Then, you should refer to the following lookup table for the value of **\[DATA\]** and **\[METHOD\]** arguments. 80 | 81 | 82 | 83 | 95 | 108 | 109 |
84 | 85 | | Dataset | \[Data\] | 86 | |:-----------|:----------:| 87 | | ECtHR (A) | ecthr_a | 88 | | ECtHR (B) | ecthr_b | 89 | | SCOTUS | scotus | 90 | | EUR-LEX | eurlex | 91 | | LEDGAR | ledgar | 92 | | UNFAIR-ToS | unfair_tos | 93 | 94 | 96 | 97 | | Method | \[METHOD\] | 98 | |:--------------------------------|:----------------:| 99 | | Linear_one-vs-rest (Table 2) | 1vsrest | 100 | | Linear_thresholding (Table 2) | thresholding | 101 | | Linear_cost-sensitive (Table 2) | cost_sensitive | 102 | | BERT_Ours (Table 2) | bert_tuned | 103 | | BERT_default (Table 9) | bert_default | 104 | | BERT_tuned (Table 9) | bert_tuned | 105 | | BERT_reproduced (Table 9) | bert_reproduced | 106 | 107 |
110 | 111 | For example, if you want to use the **thresholding** technique on the data set **UNFAIR-ToS**, you should run the following command. 112 | ```bash 113 | bash run_experiments.sh unfair_tos thresholding 114 | ``` 115 | If you want to run the **BERT_default** setting on the data set **ECtHR (B)**, you should run the following command. 116 | ```bash 117 | bash run_experiments.sh ecthr_b bert_default 118 | ``` 119 | 120 | Additional information is shown as follows. 121 | 122 | * For the **BERT_tuned** setting, we have already tuned the parameters for you. The script only runs the experiments using the tuned parameters. The running time will be different from Table 10 because, in Table 10, the time for the parameter search is also included. However, if you want to tune the parameters yourself, we also provide a script and a search space configuration to do that. Please check the following commands. 123 | ```bash 124 | # Conduct the hyper-parameter search 125 | bash search_params.sh [DATA] 126 | 127 | # Replace the given tuned configuration with the searched parameters 128 | mv runs/[DATA]_bert_tune_XXX/trial_best_params/params.yml config/[DATA]/bert_tuned.yml 129 | 130 | # Run the BERT_tuned setting 131 | bash run_experiments.sh [DATA] bert_tuned 132 | ``` 133 | * To conduct the **BERT_reproduced** method on the data sets **ECtHR (A)**, **ECtHR (B)**, and **SCOTUS**, you need a GPU with more than 16GB of memory. 134 | 135 | ## Evaluation 136 | 137 | For comparison purposes, we followed Chalkidis et al. (2022) in handling instances without labels during the evaluation process, 138 | though this setting is not a standard practice in multi-label classification, nor is it supported by [LibMultiLabel](https://github.com/ASUS-AICS/LibMultiLabel). 139 | 140 | ## Reproducibility 141 | 142 | For our experimental results, linear methods were run on the CPU **Intel Xeon E5-2690**, while for BERT we used the GPU **Nvidia V100**.
You may notice some minor differences between the results you obtain by running our scripts and those reported in our paper, especially for BERT. If you want to fully reproduce our results, you should carefully follow the items below. 143 | * Make sure you install our suggested package versions. 144 | * Use the same device as ours. 145 | * Because our BERT results are averaged over five runs with different seeds (1, 2, 3, 4, 5), you should modify [run_experiments.sh](run_experiments.sh) to perform the same five runs. 146 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JamesLYC88/text_classification_baseline_code/6436f567372f3c36547cd52da79516900e5a6148/__init__.py -------------------------------------------------------------------------------- /config/ecthr_a/bert_default.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/ecthr_a 3 | data_name: ecthr_a 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 0.00005 15 | weight_decay: 0.001 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | 29 | # pretrained vocab / embeddings 30 | embed_file: null 31 | -------------------------------------------------------------------------------- /config/ecthr_a/bert_reproduced.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_hier/ecthr_a 3 | data_name: ecthr_a 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 |
include_test_labels: true 7 | add_special_tokens: true 8 | hierarchical: true 9 | accumulate_grad_batches: 4 10 | 11 | # train 12 | seed: 1337 13 | epochs: 15 14 | batch_size: 2 15 | optimizer: adamw 16 | learning_rate: 0.00003 17 | weight_decay: 0 18 | patience: 5 19 | 20 | # eval 21 | eval_batch_size: 2 22 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 23 | val_metric: Micro-F1 24 | 25 | # model 26 | model_name: BERT 27 | init_weight: null 28 | network_config: 29 | lm_weight: bert-base-uncased 30 | 31 | # pretrained vocab / embeddings 32 | embed_file: null 33 | -------------------------------------------------------------------------------- /config/ecthr_a/bert_tune.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/ecthr_a 3 | data_name: ecthr_a 4 | min_vocab_freq: 1 5 | max_seq_length: ['grid_search', [128, 512]] 6 | include_test_labels: true 7 | remove_no_label_data: false 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]] 15 | momentum: 0 16 | weight_decay: ['grid_search', [0, 0.001]] 17 | patience: 5 18 | shuffle: true 19 | 20 | # eval 21 | eval_batch_size: 8 22 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 23 | val_metric: Micro-F1 24 | 25 | # model 26 | model_name: BERT 27 | init_weight: null 28 | network_config: 29 | dropout: ['grid_search', [0.1, 0.2]] 30 | lm_weight: bert-base-uncased 31 | 32 | # pretrained vocab / embeddings 33 | vocab_file: null 34 | embed_file: null 35 | normalize_embed: false 36 | 37 | # hyperparameter search 38 | search_alg: basic_variant 39 | embed_cache_dir: null 40 | num_samples: 1 41 | scheduler: null 42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler 44 | #scheduler: 45 | #time_attr: training_iteration 46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper) 47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper) 48 | #reduction_factor: 3 # reduce the number of configuration to floor(1/reduction_factor) each round of successive halving (called rung in ASHA paper) 49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used) 50 | 51 | # other parameters specified in main.py::get_args 52 | checkpoint_path: null 53 | cpu: false 54 | data_workers: 4 55 | eval: false 56 | label_file: null 57 | limit_train_batches: 1.0 58 | limit_val_batches: 1.0 59 | limit_test_batches: 1.0 60 | metric_threshold: 0.5 61 | result_dir: runs 62 | save_k_predictions: 0 63 | silent: true 64 | test_path: null 65 | train_path: null 66 | val_path: null 67 | val_size: 0.2 68 | 69 | # LexGLUE 70 | zero: true 71 | multi_class: false 72 | add_special_tokens: true 73 | enable_ce_loss: false 74 | hierarchical: false 75 | accumulate_grad_batches: 1 76 | enable_transformer_trainer: false 77 | -------------------------------------------------------------------------------- /config/ecthr_a/bert_tuned.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/ecthr_a 3 | data_name: ecthr_a 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 2e-05 15 | weight_decay: 0 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 | 
init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | dropout: 0.1 29 | 30 | # pretrained vocab / embeddings 31 | embed_file: null 32 | -------------------------------------------------------------------------------- /config/ecthr_a/l2svm.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_linear/ecthr_a 3 | data_name: ecthr_a 4 | 5 | # train 6 | seed: 1337 7 | linear: true 8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q" 9 | linear_technique: 1vsrest 10 | 11 | # eval 12 | eval_batch_size: 256 13 | monitor_metrics: [Micro-F1, Macro-F1] 14 | metric_threshold: 0 15 | 16 | data_format: txt 17 | -------------------------------------------------------------------------------- /config/ecthr_b/bert_default.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/ecthr_b 3 | data_name: ecthr_b 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 0.00005 15 | weight_decay: 0.001 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | 29 | # pretrained vocab / embeddings 30 | embed_file: null 31 | -------------------------------------------------------------------------------- /config/ecthr_b/bert_reproduced.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_hier/ecthr_b 3 | data_name: ecthr_b 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | hierarchical: true 9 | accumulate_grad_batches: 4 10 | 11 | # train 12 | seed: 1337 13 | epochs: 15 14 | 
batch_size: 2 15 | optimizer: adamw 16 | learning_rate: 0.00003 17 | weight_decay: 0 18 | patience: 5 19 | 20 | # eval 21 | eval_batch_size: 2 22 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 23 | val_metric: Micro-F1 24 | 25 | # model 26 | model_name: BERT 27 | init_weight: null 28 | network_config: 29 | lm_weight: bert-base-uncased 30 | 31 | # pretrained vocab / embeddings 32 | embed_file: null 33 | -------------------------------------------------------------------------------- /config/ecthr_b/bert_tune.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/ecthr_b 3 | data_name: ecthr_b 4 | min_vocab_freq: 1 5 | max_seq_length: ['grid_search', [128, 512]] 6 | include_test_labels: true 7 | remove_no_label_data: false 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]] 15 | momentum: 0 16 | weight_decay: ['grid_search', [0, 0.001]] 17 | patience: 5 18 | shuffle: true 19 | 20 | # eval 21 | eval_batch_size: 8 22 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 23 | val_metric: Micro-F1 24 | 25 | # model 26 | model_name: BERT 27 | init_weight: null 28 | network_config: 29 | dropout: ['grid_search', [0.1, 0.2]] 30 | lm_weight: bert-base-uncased 31 | 32 | # pretrained vocab / embeddings 33 | vocab_file: null 34 | embed_file: null 35 | normalize_embed: false 36 | 37 | # hyperparameter search 38 | search_alg: basic_variant 39 | embed_cache_dir: null 40 | num_samples: 1 41 | scheduler: null 42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler 44 | #scheduler: 45 | #time_attr: training_iteration 46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper) 47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper) 48 | #reduction_factor: 3 # reduce the number of configuration to floor(1/reduction_factor) each round of successive halving (called rung in ASHA paper) 49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used) 50 | 51 | # other parameters specified in main.py::get_args 52 | checkpoint_path: null 53 | cpu: false 54 | data_workers: 4 55 | eval: false 56 | label_file: null 57 | limit_train_batches: 1.0 58 | limit_val_batches: 1.0 59 | limit_test_batches: 1.0 60 | metric_threshold: 0.5 61 | result_dir: runs 62 | save_k_predictions: 0 63 | silent: true 64 | test_path: null 65 | train_path: null 66 | val_path: null 67 | val_size: 0.2 68 | 69 | # LexGLUE 70 | zero: true 71 | multi_class: false 72 | add_special_tokens: true 73 | enable_ce_loss: false 74 | hierarchical: false 75 | accumulate_grad_batches: 1 76 | enable_transformer_trainer: false 77 | -------------------------------------------------------------------------------- /config/ecthr_b/bert_tuned.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/ecthr_b 3 | data_name: ecthr_b 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 3e-05 15 | weight_decay: 0.001 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 
25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | dropout: 0.2 29 | 30 | # pretrained vocab / embeddings 31 | embed_file: null 32 | -------------------------------------------------------------------------------- /config/ecthr_b/l2svm.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_linear/ecthr_b 3 | data_name: ecthr_b 4 | 5 | # train 6 | seed: 1337 7 | linear: true 8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q" 9 | linear_technique: 1vsrest 10 | 11 | # eval 12 | eval_batch_size: 256 13 | monitor_metrics: [Micro-F1, Macro-F1] 14 | metric_threshold: 0 15 | 16 | data_format: txt 17 | -------------------------------------------------------------------------------- /config/eurlex/bert_default.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/eurlex 3 | data_name: eurlex 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 0.00005 15 | weight_decay: 0.001 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | 29 | # pretrained vocab / embeddings 30 | embed_file: null 31 | -------------------------------------------------------------------------------- /config/eurlex/bert_reproduced.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/eurlex 3 | data_name: eurlex 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 20 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 0.00003 15 | 
weight_decay: 0 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | 29 | # pretrained vocab / embeddings 30 | embed_file: null 31 | -------------------------------------------------------------------------------- /config/eurlex/bert_tune.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/eurlex 3 | data_name: eurlex 4 | min_vocab_freq: 1 5 | max_seq_length: ['grid_search', [128, 512]] 6 | include_test_labels: true 7 | remove_no_label_data: false 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]] 15 | momentum: 0 16 | weight_decay: ['grid_search', [0, 0.001]] 17 | patience: 5 18 | shuffle: true 19 | 20 | # eval 21 | eval_batch_size: 8 22 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 23 | val_metric: Micro-F1 24 | 25 | # model 26 | model_name: BERT 27 | init_weight: null 28 | network_config: 29 | dropout: ['grid_search', [0.1, 0.2]] 30 | lm_weight: bert-base-uncased 31 | 32 | # pretrained vocab / embeddings 33 | vocab_file: null 34 | embed_file: null 35 | normalize_embed: false 36 | 37 | # hyperparameter search 38 | search_alg: basic_variant 39 | embed_cache_dir: null 40 | num_samples: 1 41 | scheduler: null 42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler 44 | #scheduler: 45 | #time_attr: training_iteration 46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper) 47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper) 48 | #reduction_factor: 3 # reduce the number of configuration to floor(1/reduction_factor) each round of successive halving (called rung in ASHA paper) 49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used) 50 | 51 | # other parameters specified in main.py::get_args 52 | checkpoint_path: null 53 | cpu: false 54 | data_workers: 4 55 | eval: false 56 | label_file: null 57 | limit_train_batches: 1.0 58 | limit_val_batches: 1.0 59 | limit_test_batches: 1.0 60 | metric_threshold: 0.5 61 | result_dir: runs 62 | save_k_predictions: 0 63 | silent: true 64 | test_path: null 65 | train_path: null 66 | val_path: null 67 | val_size: 0.2 68 | 69 | # LexGLUE 70 | zero: false 71 | multi_class: false 72 | add_special_tokens: true 73 | enable_ce_loss: false 74 | hierarchical: false 75 | accumulate_grad_batches: 1 76 | enable_transformer_trainer: false 77 | -------------------------------------------------------------------------------- /config/eurlex/bert_tuned.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/eurlex 3 | data_name: eurlex 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 5e-05 15 | weight_decay: 0.001 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 
| init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | dropout: 0.1 29 | 30 | # pretrained vocab / embeddings 31 | embed_file: null 32 | -------------------------------------------------------------------------------- /config/eurlex/l2svm.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_linear/eurlex 3 | data_name: eurlex 4 | 5 | # train 6 | seed: 1337 7 | linear: true 8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q" 9 | linear_technique: 1vsrest 10 | 11 | # eval 12 | eval_batch_size: 256 13 | monitor_metrics: [Micro-F1, Macro-F1] 14 | metric_threshold: 0 15 | 16 | data_format: txt 17 | -------------------------------------------------------------------------------- /config/ledgar/bert_default.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/ledgar 3 | data_name: ledgar 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 0.00005 15 | weight_decay: 0.001 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | 29 | # pretrained vocab / embeddings 30 | embed_file: null 31 | -------------------------------------------------------------------------------- /config/ledgar/bert_reproduced.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/ledgar 3 | data_name: ledgar 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 20 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 0.00003 15 | 
weight_decay: 0 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | 29 | # pretrained vocab / embeddings 30 | embed_file: null 31 | -------------------------------------------------------------------------------- /config/ledgar/bert_tune.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/ledgar 3 | data_name: ledgar 4 | min_vocab_freq: 1 5 | max_seq_length: ['grid_search', [128, 512]] 6 | include_test_labels: true 7 | remove_no_label_data: false 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]] 15 | momentum: 0 16 | weight_decay: ['grid_search', [0, 0.001]] 17 | patience: 5 18 | shuffle: true 19 | 20 | # eval 21 | eval_batch_size: 8 22 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 23 | val_metric: Micro-F1 24 | 25 | # model 26 | model_name: BERT 27 | init_weight: null 28 | network_config: 29 | dropout: ['grid_search', [0.1, 0.2]] 30 | lm_weight: bert-base-uncased 31 | 32 | # pretrained vocab / embeddings 33 | vocab_file: null 34 | embed_file: null 35 | normalize_embed: false 36 | 37 | # hyperparameter search 38 | search_alg: basic_variant 39 | embed_cache_dir: null 40 | num_samples: 1 41 | scheduler: null 42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler 44 | #scheduler: 45 | #time_attr: training_iteration 46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper) 47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper) 48 | #reduction_factor: 3 # keep only the top floor(n/reduction_factor) of the n configurations at each round of successive halving (called a rung in the ASHA paper) 49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used) 50 | 51 | # other parameters specified in main.py::get_args 52 | checkpoint_path: null 53 | cpu: false 54 | data_workers: 4 55 | eval: false 56 | label_file: null 57 | limit_train_batches: 1.0 58 | limit_val_batches: 1.0 59 | limit_test_batches: 1.0 60 | metric_threshold: 0.5 61 | result_dir: runs 62 | save_k_predictions: 0 63 | silent: true 64 | test_path: null 65 | train_path: null 66 | val_path: null 67 | val_size: 0.2 68 | 69 | # LexGLUE 70 | zero: false 71 | multi_class: true 72 | add_special_tokens: true 73 | enable_ce_loss: true 74 | hierarchical: false 75 | accumulate_grad_batches: 1 76 | enable_transformer_trainer: false 77 | -------------------------------------------------------------------------------- /config/ledgar/bert_tuned.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/ledgar 3 | data_name: ledgar 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 2e-05 15 | weight_decay: 0 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 |
init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | dropout: 0.2 29 | 30 | # pretrained vocab / embeddings 31 | embed_file: null 32 | -------------------------------------------------------------------------------- /config/ledgar/l2svm.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_linear/ledgar 3 | data_name: ledgar 4 | 5 | # train 6 | seed: 1337 7 | linear: true 8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q" 9 | linear_technique: 1vsrest 10 | 11 | # eval 12 | eval_batch_size: 256 13 | monitor_metrics: [Micro-F1, Macro-F1] 14 | metric_threshold: 0 15 | 16 | data_format: txt 17 | -------------------------------------------------------------------------------- /config/scotus/bert_default.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/scotus 3 | data_name: scotus 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 0.00005 15 | weight_decay: 0.001 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | 29 | # pretrained vocab / embeddings 30 | embed_file: null 31 | -------------------------------------------------------------------------------- /config/scotus/bert_reproduced.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_hier/scotus 3 | data_name: scotus 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | hierarchical: true 9 | accumulate_grad_batches: 4 10 | 11 | # train 12 | seed: 1337 13 | epochs: 20 14 | batch_size: 2 15 | 
optimizer: adamw 16 | learning_rate: 0.00003 17 | weight_decay: 0 18 | patience: 5 19 | 20 | # eval 21 | eval_batch_size: 2 22 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 23 | val_metric: Micro-F1 24 | 25 | # model 26 | model_name: BERT 27 | init_weight: null 28 | network_config: 29 | lm_weight: bert-base-uncased 30 | 31 | # pretrained vocab / embeddings 32 | embed_file: null 33 | -------------------------------------------------------------------------------- /config/scotus/bert_tune.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/scotus 3 | data_name: scotus 4 | min_vocab_freq: 1 5 | max_seq_length: ['grid_search', [128, 512]] 6 | include_test_labels: true 7 | remove_no_label_data: false 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]] 15 | momentum: 0 16 | weight_decay: ['grid_search', [0, 0.001]] 17 | patience: 5 18 | shuffle: true 19 | 20 | # eval 21 | eval_batch_size: 8 22 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 23 | val_metric: Micro-F1 24 | 25 | # model 26 | model_name: BERT 27 | init_weight: null 28 | network_config: 29 | dropout: ['grid_search', [0.1, 0.2]] 30 | lm_weight: bert-base-uncased 31 | 32 | # pretrained vocab / embeddings 33 | vocab_file: null 34 | embed_file: null 35 | normalize_embed: false 36 | 37 | # hyperparameter search 38 | search_alg: basic_variant 39 | embed_cache_dir: null 40 | num_samples: 1 41 | scheduler: null 42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler 44 | #scheduler: 45 | #time_attr: training_iteration 46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper) 47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper) 48 | #reduction_factor: 3 # keep only the top floor(n/reduction_factor) of the n configurations at each round of successive halving (called a rung in the ASHA paper) 49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used) 50 | 51 | # other parameters specified in main.py::get_args 52 | checkpoint_path: null 53 | cpu: false 54 | data_workers: 4 55 | eval: false 56 | label_file: null 57 | limit_train_batches: 1.0 58 | limit_val_batches: 1.0 59 | limit_test_batches: 1.0 60 | metric_threshold: 0.5 61 | result_dir: runs 62 | save_k_predictions: 0 63 | silent: true 64 | test_path: null 65 | train_path: null 66 | val_path: null 67 | val_size: 0.2 68 | 69 | # LexGLUE 70 | zero: false 71 | multi_class: true 72 | add_special_tokens: true 73 | enable_ce_loss: true 74 | hierarchical: false 75 | accumulate_grad_batches: 1 76 | enable_transformer_trainer: false 77 | -------------------------------------------------------------------------------- /config/scotus/bert_tuned.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/scotus 3 | data_name: scotus 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 2e-05 15 | weight_decay: 0 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 |
init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | dropout: 0.1 29 | 30 | # pretrained vocab / embeddings 31 | embed_file: null 32 | -------------------------------------------------------------------------------- /config/scotus/l2svm.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_linear/scotus 3 | data_name: scotus 4 | 5 | # train 6 | seed: 1337 7 | linear: true 8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q" 9 | linear_technique: 1vsrest 10 | 11 | # eval 12 | eval_batch_size: 256 13 | monitor_metrics: [Micro-F1, Macro-F1] 14 | metric_threshold: 0 15 | 16 | data_format: txt 17 | -------------------------------------------------------------------------------- /config/unfair_tos/bert_default.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/unfair_tos 3 | data_name: unfair_tos 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 0.00005 15 | weight_decay: 0.001 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | 29 | # pretrained vocab / embeddings 30 | embed_file: null 31 | -------------------------------------------------------------------------------- /config/unfair_tos/bert_reproduced.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/unfair_tos 3 | data_name: unfair_tos 4 | min_vocab_freq: 1 5 | max_seq_length: 128 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | 
learning_rate: 0.00003 15 | weight_decay: 0 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name: BERT 25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | 29 | # pretrained vocab / embeddings 30 | embed_file: null 31 | -------------------------------------------------------------------------------- /config/unfair_tos/bert_tune.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/unfair_tos 3 | data_name: unfair_tos 4 | min_vocab_freq: 1 5 | max_seq_length: ['grid_search', [128, 512]] 6 | include_test_labels: true 7 | remove_no_label_data: false 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]] 15 | momentum: 0 16 | weight_decay: ['grid_search', [0, 0.001]] 17 | patience: 5 18 | shuffle: true 19 | 20 | # eval 21 | eval_batch_size: 8 22 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 23 | val_metric: Micro-F1 24 | 25 | # model 26 | model_name: BERT 27 | init_weight: null 28 | network_config: 29 | dropout: ['grid_search', [0.1, 0.2]] 30 | lm_weight: bert-base-uncased 31 | 32 | # pretrained vocab / embeddings 33 | vocab_file: null 34 | embed_file: null 35 | normalize_embed: false 36 | 37 | # hyperparameter search 38 | search_alg: basic_variant 39 | embed_cache_dir: null 40 | num_samples: 1 41 | scheduler: null 42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler 44 | #scheduler: 45 | #time_attr: training_iteration 46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper) 47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper) 48 | #reduction_factor: 3 # keep only the top floor(n/reduction_factor) of the n configurations at each round of successive halving (called a rung in the ASHA paper) 49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used) 50 | 51 | # other parameters specified in main.py::get_args 52 | checkpoint_path: null 53 | cpu: false 54 | data_workers: 4 55 | eval: false 56 | label_file: null 57 | limit_train_batches: 1.0 58 | limit_val_batches: 1.0 59 | limit_test_batches: 1.0 60 | metric_threshold: 0.5 61 | result_dir: runs 62 | save_k_predictions: 0 63 | silent: true 64 | test_path: null 65 | train_path: null 66 | val_path: null 67 | val_size: 0.2 68 | 69 | # LexGLUE 70 | zero: true 71 | multi_class: false 72 | add_special_tokens: true 73 | enable_ce_loss: false 74 | hierarchical: false 75 | accumulate_grad_batches: 1 76 | enable_transformer_trainer: false 77 | -------------------------------------------------------------------------------- /config/unfair_tos/bert_tuned.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_nn/unfair_tos 3 | data_name: unfair_tos 4 | min_vocab_freq: 1 5 | max_seq_length: 512 6 | include_test_labels: true 7 | add_special_tokens: true 8 | 9 | # train 10 | seed: 1337 11 | epochs: 15 12 | batch_size: 8 13 | optimizer: adamw 14 | learning_rate: 3e-05 15 | weight_decay: 0 16 | patience: 5 17 | 18 | # eval 19 | eval_batch_size: 8 20 | monitor_metrics: ['Micro-F1', 'Macro-F1'] 21 | val_metric: Micro-F1 22 | 23 | # model 24 | model_name:
BERT 25 | init_weight: null 26 | network_config: 27 | lm_weight: bert-base-uncased 28 | dropout: 0.1 29 | 30 | # pretrained vocab / embeddings 31 | embed_file: null 32 | -------------------------------------------------------------------------------- /config/unfair_tos/l2svm.yml: -------------------------------------------------------------------------------- 1 | # data 2 | data_dir: data_linear/unfair_tos 3 | data_name: unfair_tos 4 | 5 | # train 6 | seed: 1337 7 | linear: true 8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q" 9 | linear_technique: 1vsrest 10 | 11 | # eval 12 | eval_batch_size: 256 13 | monitor_metrics: [Micro-F1, Macro-F1] 14 | metric_threshold: 0 15 | 16 | data_format: txt 17 | -------------------------------------------------------------------------------- /generate_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | from datasets import load_dataset 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('-dl', '--data_list', type=list, 10 | default=['ecthr_a', 'ecthr_b', 'scotus', 'eurlex', 'ledgar', 'unfair_tos']) 11 | parser.add_argument('-sl', '--split_list', type=list, default = ['train', 'validation', 'test']) 12 | parser.add_argument('-ddp', '--data_dir_prefix', type=str, default='data') 13 | parser.add_argument('-f', '--format', type=str, choices=['linear', 'nn', 'hier'], required=True) 14 | args = parser.parse_args() 15 | return args 16 | 17 | 18 | def data2task(data): 19 | return 'multi_label' if data not in ['scotus', 'ledgar'] else 'multi_class' 20 | 21 | 22 | def data2hier(data): 23 | return True if data in ['ecthr_a', 'ecthr_b', 'scotus'] else False 24 | 25 | 26 | def split2name(split): 27 | _split2name = {'train': 'train.txt', 'validation': 'valid.txt', 'test': 'test.txt'} 28 | return _split2name[split] 29 | 30 | 31 | def get_texts(data, dataset, hier=False): 32 | if 'ecthr' in data: 33 | if not hier: 34 | 
texts = [' '.join(text) for text in dataset['text']] 35 | else: 36 | texts = [' [HIER] '.join(text) for text in dataset['text']] 37 | return [' '.join(text.split()) for text in texts] 38 | elif data == 'scotus' and hier: 39 | texts = [' [HIER] '.join(re.split('\n{2,}', text)) for text in dataset['text']] 40 | # Huggingface tokenizer ignores newline and tab, 41 | # so it's okay to replace them with a space here. 42 | for i in range(len(texts)): 43 | texts[i] = texts[i].replace('\n', ' ') 44 | texts[i] = texts[i].replace('\r', ' ') 45 | texts[i] = texts[i].replace('\t', ' ') 46 | return texts 47 | elif data == 'case_hold': 48 | return [contexts[0] + ' [SEP] '.join(holdings) 49 | for contexts, holdings in zip(dataset['contexts'], dataset['endings'])] 50 | else: 51 | return [' '.join(text.split()) for text in dataset['text']] 52 | 53 | 54 | def get_labels(data, dataset, task): 55 | if task == 'multi_class': 56 | return list(map(str, dataset['label'])) 57 | else: 58 | if data == 'eurlex': 59 | return [' '.join(map(str, [l for l in label if l < 100])) for label in dataset['labels']] 60 | else: 61 | return [' '.join(map(str, label)) for label in dataset['labels']] 62 | 63 | 64 | def save_data(data_path, data): 65 | with open(data_path, 'w') as f: 66 | for text, label in zip(data['text'], data['labels']): 67 | assert '\n' not in label+text 68 | assert '\r' not in label+text 69 | assert '\t' not in label+text 70 | formatted_instance = '\t'.join([label, text]) 71 | f.write(f'{formatted_instance}\n') 72 | 73 | 74 | def main(): 75 | # args 76 | args = get_args() 77 | data_dir = f'{args.data_dir_prefix}_{args.format}' 78 | os.makedirs(data_dir, exist_ok=True) 79 | 80 | # generate 81 | for data in args.data_list: 82 | if args.format == 'hier' and not data2hier(data): 83 | continue 84 | data_path = os.path.join(data_dir, data) 85 | os.makedirs(data_path, exist_ok=True) 86 | processed_data = {} 87 | for split in args.split_list: 88 | dataset = load_dataset('coastalcph/lex_glue', 
data, split=split, trust_remote_code=True) 89 | texts = get_texts(data, dataset, hier=args.format == 'hier') 90 | labels = get_labels(data, dataset, data2task(data)) 91 | assert len(texts) == len(labels) 92 | print(f'{data} ({split}): num_instance = {len(texts)}') 93 | processed_data[split] = {'text': texts, 'labels': labels} 94 | # format 95 | if args.format == 'linear': 96 | # train 97 | train_path = os.path.join(data_path, split2name('train')) 98 | train_data = { 99 | 'text': processed_data['train']['text'] + processed_data['validation']['text'], 100 | 'labels': processed_data['train']['labels'] + processed_data['validation']['labels'] 101 | } 102 | save_data(train_path, train_data) 103 | # test 104 | test_path = os.path.join(data_path, split2name('test')) 105 | test_data = processed_data['test'] 106 | save_data(test_path, test_data) 107 | elif args.format == 'nn' or args.format == 'hier': 108 | # train/validation/test 109 | for split in processed_data: 110 | split_path = os.path.join(data_path, split2name(split)) 111 | save_data(split_path, processed_data[split]) 112 | 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- /generate_data.sh: -------------------------------------------------------------------------------- 1 | format_list=(linear nn hier) 2 | 3 | format=$1 4 | 5 | if [ $# -eq 0 ]; then 6 | for format in "${format_list[@]}"; do 7 | python3 generate_data.py -f ${format} 8 | done 9 | else 10 | if [[ ! " ${format_list[*]} " =~ " ${format} " ]]; then 11 | echo "Invalid argument! Format ${format} is not in (${format_list[*]})." 
12 | exit 13 | else 14 | python3 generate_data.py -f ${format} 15 | fi 16 | fi 17 | -------------------------------------------------------------------------------- /libmultilabel/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JamesLYC88/text_classification_baseline_code/6436f567372f3c36547cd52da79516900e5a6148/libmultilabel/__init__.py -------------------------------------------------------------------------------- /libmultilabel/common_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import logging 4 | import os 5 | import time 6 | 7 | import numpy as np 8 | 9 | 10 | class Timer(object): 11 | """Computes elasped time.""" 12 | 13 | def __init__(self): 14 | self.reset() 15 | 16 | def reset(self): 17 | self.running = True 18 | self.total = 0 19 | self.start = time.time() 20 | return self 21 | 22 | def resume(self): 23 | if not self.running: 24 | self.running = True 25 | self.start = time.time() 26 | return self 27 | 28 | def stop(self): 29 | if self.running: 30 | self.running = False 31 | self.total += time.time() - self.start 32 | return self 33 | 34 | def time(self): 35 | if self.running: 36 | return self.total + time.time() - self.start 37 | return self.total 38 | 39 | 40 | def dump_log(log_path, metrics=None, split=None, config=None): 41 | """Write log including config and the evaluation scores. 
42 | 43 | Args: 44 | log_path(str): path to log path 45 | metrics (dict): metric and scores in dictionary format, defaults to None 46 | split (str): val or test, defaults to None 47 | config (dict): config to save, defaults to None 48 | """ 49 | os.makedirs(os.path.dirname(log_path), exist_ok=True) 50 | if os.path.isfile(log_path): 51 | with open(log_path) as fp: 52 | result = json.load(fp) 53 | else: 54 | result = dict() 55 | 56 | if config: 57 | config_to_save = copy.deepcopy(dict(config)) 58 | config_to_save.pop('device', None) # delete if device exists 59 | result['config'] = config_to_save 60 | if split and metrics: 61 | if split in result: 62 | result[split].append(metrics) 63 | else: 64 | result[split] = [metrics] 65 | 66 | with open(log_path, 'w') as fp: 67 | json.dump(result, fp) 68 | 69 | logging.info(f'Finish writing log to {log_path}.') 70 | 71 | 72 | def argsort_top_k(vals, k, axis=-1): 73 | unsorted_top_k_idx = np.argpartition(vals, -k, axis=axis)[:, -k:] 74 | unsorted_top_k_scores = np.take_along_axis( 75 | vals, unsorted_top_k_idx, axis=axis) 76 | sorted_order = np.argsort(-unsorted_top_k_scores, axis=axis) 77 | sorted_top_k_idx = np.take_along_axis( 78 | unsorted_top_k_idx, sorted_order, axis=axis) 79 | return sorted_top_k_idx 80 | 81 | 82 | class AttributeDict(dict): 83 | """AttributeDict is an extended dict that can access 84 | stored items as attributes. 
85 | 86 | >>> ad = AttributeDict({'ans': 42}) 87 | >>> ad.ans 88 | >>> 42 89 | """ 90 | 91 | def __getattr__(self, key: str) -> any: 92 | try: 93 | return self[key] 94 | except KeyError: 95 | raise AttributeError(f'Missing attribute "{key}"') 96 | 97 | def __setattr__(self, key: str, value: any) -> None: 98 | self[key] = value 99 | -------------------------------------------------------------------------------- /libmultilabel/linear/__init__.py: -------------------------------------------------------------------------------- 1 | from .linear import * 2 | from .metrics import get_metrics, tabulate_metrics 3 | from .preprocessor import * 4 | from .utils import * 5 | -------------------------------------------------------------------------------- /libmultilabel/linear/linear.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | import scipy.sparse as sparse 6 | from liblinear.liblinearutil import train 7 | from tqdm import tqdm 8 | 9 | __all__ = ['train_1vsrest', 10 | 'train_thresholding', 11 | 'train_cost_sensitive', 12 | 'train_cost_sensitive_micro', 13 | 'predict_values'] 14 | 15 | 16 | def train_1vsrest(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str): 17 | """Trains a linear model for multiabel data using a one-vs-rest strategy. 18 | 19 | Args: 20 | y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes. 21 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features. 22 | options (str): The option string passed to liblinear. 23 | 24 | Returns: 25 | A model which can be used in predict_values. 
26 | """ 27 | # Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/ 28 | x, options, bias = prepare_options(x, options) 29 | 30 | y = y.tocsc() 31 | num_class = y.shape[1] 32 | num_feature = x.shape[1] 33 | weights = np.zeros((num_feature, num_class), order='F') 34 | 35 | logging.info(f'Training one-vs-rest model on {num_class} labels') 36 | for i in tqdm(range(num_class)): 37 | yi = y[:, i].toarray().reshape(-1) 38 | modeli = train(2*yi - 1, x, options) 39 | w = np.ctypeslib.as_array(modeli.w, (num_feature,)) 40 | # Liblinear flips +1/-1 labels so +1 is always the first label, 41 | # but not if all labels are -1. 42 | # For our usage, we need +1 to always be the first label, 43 | # so the check is necessary. 44 | if modeli.get_labels()[0] == -1: 45 | weights[:, i] = -w 46 | else: 47 | weights[:, i] = w 48 | 49 | return {'weights': np.asmatrix(weights), '-B': bias, 'threshold': 0} 50 | 51 | 52 | def prepare_options(x: sparse.csr_matrix, options: str) -> 'tuple[sparse.csr_matrix, str, float]': 53 | """Prepare options and x for multi-label training. Called in the first line of 54 | any training function. 55 | 56 | Args: 57 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features. 58 | options (str): The option string passed to liblinear. 59 | 60 | Returns: 61 | tuple[sparse.csr_matrix, str, float]: Transformed x, transformed options and 62 | bias parsed from options. 63 | """ 64 | if options is None: 65 | options = '' 66 | if any(o in options for o in ['-R', '-C', '-v']): 67 | raise ValueError('-R, -C and -v are not supported') 68 | 69 | bias = -1. 
70 | if options.find('-B') != -1: 71 | options_split = options.split() 72 | i = options_split.index('-B') 73 | bias = float(options_split[i+1]) 74 | options = ' '.join(options_split[:i] + options_split[i+2:]) 75 | x = sparse.hstack([ 76 | x, 77 | np.full((x.shape[0], 1), bias), 78 | ], 'csr') 79 | 80 | if not '-q' in options: 81 | options += ' -q' 82 | 83 | if not '-m' in options: 84 | options += f' -m {int(os.cpu_count() / 2)}' 85 | 86 | return x, options, bias 87 | 88 | 89 | def train_thresholding(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str): 90 | """Trains a linear model for multilabel data using a one-vs-rest strategy 91 | and cross-validation to pick an optimal decision threshold for Macro-F1. 92 | Outperforms train_1vsrest in most aspects at the cost of higher 93 | time complexity. 94 | See user guide for more details. 95 | 96 | Args: 97 | y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes. 98 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features. 99 | options (str): The option string passed to liblinear. 100 | 101 | Returns: 102 | A model which can be used in predict_values. 
103 | """ 104 | # Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/ 105 | x, options, bias = prepare_options(x, options) 106 | 107 | y = y.tocsc() 108 | num_class = y.shape[1] 109 | num_feature = x.shape[1] 110 | weights = np.zeros((num_feature, num_class), order='F') 111 | thresholds = np.zeros(num_class) 112 | 113 | logging.info(f'Training thresholding model on {num_class} labels') 114 | for i in tqdm(range(num_class)): 115 | yi = y[:, i].toarray().reshape(-1) 116 | w, t = thresholding_one_label(2*yi - 1, x, options) 117 | weights[:, i] = w.ravel() 118 | thresholds[i] = t 119 | 120 | return {'weights': np.asmatrix(weights), '-B': bias, 'threshold': thresholds} 121 | 122 | 123 | def thresholding_one_label(y: np.ndarray, 124 | x: sparse.csr_matrix, 125 | options: str 126 | ) -> 'tuple[np.ndarray, float]': 127 | """Outer cross-validation for thresholding on a single label. 128 | 129 | Args: 130 | y (np.ndarray): A +1/-1 array with dimensions number of instances * 1. 131 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features. 132 | options (str): The option string passed to liblinear. 133 | 134 | Returns: 135 | tuple[np.ndarray, float]: tuple of the weights and threshold. 
136 | """ 137 | fbr_list = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]) 138 | 139 | nr_fold = 3 140 | 141 | l = y.shape[0] 142 | 143 | perm = np.random.permutation(l) 144 | 145 | f_list = np.zeros_like(fbr_list) 146 | 147 | for fold in range(nr_fold): 148 | mask = np.zeros_like(perm, dtype='?') 149 | mask[np.arange(int(fold*l/nr_fold), int((fold+1)*l/nr_fold))] = 1 150 | val_idx = perm[mask] 151 | train_idx = perm[mask != True] 152 | 153 | scutfbr_w, scutfbr_b_list = scutfbr( 154 | y[train_idx], x[train_idx], fbr_list, options) 155 | wTx = (x[val_idx] * scutfbr_w).A1 156 | 157 | for i in range(fbr_list.size): 158 | F = fmeasure(y[val_idx], 2*(wTx > -scutfbr_b_list[i]) - 1) 159 | f_list[i] += F 160 | 161 | best_fbr = fbr_list[::-1][np.argmax(f_list[::-1])] # last largest 162 | if np.max(f_list) == 0: 163 | best_fbr = np.min(fbr_list) 164 | 165 | # final model 166 | w, b_list = scutfbr(y, x, np.array([best_fbr]), options) 167 | 168 | return w, b_list[0] 169 | 170 | 171 | def scutfbr(y: np.ndarray, 172 | x: sparse.csr_matrix, 173 | fbr_list: 'list[float]', 174 | options: str 175 | ) -> 'tuple[np.matrix, np.ndarray]': 176 | """Inner cross-validation for SCutfbr heuristic. 177 | 178 | Args: 179 | y (np.ndarray): A +1/-1 array with dimensions number of instances * 1. 180 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features. 181 | fbr_list (list[float]): list of fbr values. 182 | options (str): The option string passed to liblinear. 183 | 184 | Returns: 185 | tuple[np.matrix, np.ndarray]: tuple of weights and threshold candidates. 
186 | """ 187 | 188 | b_list = np.zeros_like(fbr_list) 189 | 190 | nr_fold = 3 191 | 192 | l = y.shape[0] 193 | 194 | perm = np.random.permutation(l) 195 | 196 | for fold in range(nr_fold): 197 | mask = np.zeros_like(perm, dtype='?') 198 | mask[np.arange(int(fold*l/nr_fold), int((fold+1)*l/nr_fold))] = 1 199 | val_idx = perm[mask] 200 | train_idx = perm[mask != True] 201 | 202 | w = do_train(y[train_idx], x[train_idx], options) 203 | 204 | wTx = (x[val_idx] * w).A1 205 | scut_b = 0. 206 | start_F = fmeasure(y[val_idx], 2*(wTx > -scut_b) - 1) 207 | 208 | # stableness to match the MATLAB implementation 209 | sorted_wTx_index = np.argsort(wTx, kind='stable') 210 | sorted_wTx = wTx[sorted_wTx_index] 211 | 212 | tp = np.sum(y[val_idx] == 1) 213 | fp = val_idx.size - tp 214 | fn = 0 215 | cut = -1 216 | best_F = 2*tp / (2*tp + fp + fn) 217 | y_val = y[val_idx] 218 | 219 | # following MATLAB implementation to suppress NaNs 220 | prev_settings = np.seterr('ignore') 221 | for i in range(val_idx.size): 222 | if y_val[sorted_wTx_index[i]] == -1: 223 | fp -= 1 224 | else: 225 | tp -= 1 226 | fn += 1 227 | 228 | # There will be NaNs, but the behaviour is correct 229 | F = 2*tp / (2*tp + fp + fn) 230 | 231 | if F >= best_F: 232 | best_F = F 233 | cut = i 234 | np.seterr(**prev_settings) 235 | 236 | if best_F > start_F: 237 | if cut == -1: # i.e. 
all 1 in scut 238 | scut_b = np.nextafter(-sorted_wTx[0], np.inf) # predict all 1 239 | elif cut == val_idx.size - 1: 240 | scut_b = np.nextafter(-sorted_wTx[-1], np.inf) 241 | else: 242 | scut_b = -(sorted_wTx[cut] + sorted_wTx[cut + 1]) / 2 243 | 244 | F = fmeasure(y_val, 2*(wTx > -scut_b) - 1) 245 | 246 | for i in range(fbr_list.size): 247 | if F > fbr_list[i]: 248 | b_list[i] += scut_b 249 | else: 250 | b_list[i] -= np.max(wTx) 251 | 252 | b_list = b_list / nr_fold 253 | return do_train(y, x, options), b_list 254 | 255 | 256 | def do_train(y: np.ndarray, x: sparse.csr_matrix, options: str) -> np.matrix: 257 | """Wrapper around liblinear.liblinearutil.train. 258 | Forcibly suppresses all IO regardless of options. 259 | 260 | Args: 261 | y (np.ndarray): A +1/-1 array with dimensions number of instances * 1. 262 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features. 263 | options (str): The option string passed to liblinear. 264 | 265 | Returns: 266 | np.matrix: the weights. 267 | """ 268 | with silent_stderr(): 269 | model = train(y, x, options) 270 | 271 | w = np.ctypeslib.as_array(model.w, (x.shape[1], 1)) 272 | w = np.asmatrix(w) 273 | # Liblinear flips +1/-1 labels so +1 is always the first label, 274 | # but not if all labels are -1. 275 | # For our usage, we need +1 to always be the first label, 276 | # so the check is necessary. 277 | if model.get_labels()[0] == -1: 278 | return -w 279 | else: 280 | # The memory is freed on model deletion so we make a copy. 281 | return w.copy() 282 | 283 | 284 | class silent_stderr: 285 | """Context manager that suppresses stderr. 286 | Liblinear emits warnings on missing classes with 287 | specified weight, which may happen during cross-validation. 288 | Since this information is useless to the user, we suppress it. 
289 | """ 290 | 291 | def __init__(self): 292 | self.stderr = os.dup(2) 293 | self.devnull = os.open(os.devnull, os.O_WRONLY) 294 | 295 | def __enter__(self): 296 | os.dup2(self.devnull, 2) 297 | 298 | def __exit__(self, type, value, traceback): 299 | os.dup2(self.stderr, 2) 300 | os.close(self.devnull) 301 | os.close(self.stderr) 302 | 303 | 304 | def fmeasure(y_true: np.ndarray, y_pred: np.ndarray) -> float: 305 | """Calculate F1 score. 306 | 307 | Args: 308 | y_true (np.ndarray): array of +1/-1. 309 | y_pred (np.ndarray): array of +1/-1. 310 | 311 | Returns: 312 | float: the F1 score. 313 | """ 314 | tp = np.sum(np.logical_and(y_true == 1, y_pred == 1)) 315 | fn = np.sum(np.logical_and(y_true == 1, y_pred == -1)) 316 | fp = np.sum(np.logical_and(y_true == -1, y_pred == 1)) 317 | 318 | F = 0 319 | if tp != 0 or fp != 0 or fn != 0: 320 | F = 2*tp / (2*tp + fp + fn) 321 | return F 322 | 323 | 324 | def train_cost_sensitive(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str): 325 | """Trains a linear model for multilabel data using a one-vs-rest strategy 326 | and cross-validation to pick an optimal asymmetric misclassification cost 327 | for Macro-F1. 328 | Outperforms train_1vsrest in most aspects at the cost of higher 329 | time complexity. 330 | See user guide for more details. 331 | 332 | Args: 333 | y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes. 334 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features. 335 | options (str): The option string passed to liblinear. 336 | 337 | Returns: 338 | A model which can be used in predict_values. 
339 | """ 340 | # Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/ 341 | x, options, bias = prepare_options(x, options) 342 | 343 | y = y.tocsc() 344 | num_class = y.shape[1] 345 | num_feature = x.shape[1] 346 | weights = np.zeros((num_feature, num_class), order='F') 347 | 348 | logging.info( 349 | f'Training cost-sensitive model for Macro-F1 on {num_class} labels') 350 | for i in tqdm(range(num_class)): 351 | yi = y[:, i].toarray().reshape(-1) 352 | w = cost_sensitive_one_label(2*yi - 1, x, options) 353 | weights[:, i] = w.ravel() 354 | 355 | return {'weights': np.asmatrix(weights), '-B': bias, 'threshold': 0} 356 | 357 | 358 | def cost_sensitive_one_label(y: np.ndarray, 359 | x: sparse.csr_matrix, 360 | options: str 361 | ) -> np.ndarray: 362 | """Loop over parameter space for cost-sensitive on a single label. 363 | 364 | Args: 365 | y (np.ndarray): A +1/-1 array with dimensions number of instances * 1. 366 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features. 367 | options (str): The option string passed to liblinear. 368 | 369 | Returns: 370 | np.ndarray: the weights. 371 | """ 372 | 373 | l = y.shape[0] 374 | perm = np.random.permutation(l) 375 | 376 | param_space = [1, 1.33, 1.8, 2.5, 3.67, 6, 13] 377 | 378 | bestScore = -np.Inf 379 | for a in param_space: 380 | cv_options = f'{options} -w1 {a}' 381 | pred = cross_validate(y, x, cv_options, perm) 382 | score = fmeasure(y, pred) 383 | if bestScore < score: 384 | bestScore = score 385 | bestA = a 386 | 387 | final_options = f'{options} -w1 {bestA}' 388 | return do_train(y, x, final_options) 389 | 390 | 391 | def cross_validate(y: np.ndarray, 392 | x: sparse.csr_matrix, 393 | options: str, 394 | perm: np.ndarray 395 | ) -> np.ndarray: 396 | """Cross-validation for cost-sensitive. 397 | 398 | Args: 399 | y (np.ndarray): A +1/-1 array with dimensions number of instances * 1. 
400 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features. 401 | options (str): The option string passed to liblinear. 402 | 403 | Returns: 404 | np.ndarray: Cross-validation result as a +1/-1 array. 405 | """ 406 | l = y.shape[0] 407 | nr_fold = 3 408 | pred = np.zeros_like(y) 409 | for fold in range(nr_fold): 410 | mask = np.zeros_like(perm, dtype='?') 411 | mask[np.arange(int(fold*l/nr_fold), int((fold+1)*l/nr_fold))] = 1 412 | val_idx = perm[mask] 413 | train_idx = perm[mask != True] 414 | 415 | w = do_train(y[train_idx], x[train_idx], options) 416 | pred[val_idx] = (x[val_idx] * w).A1 > 0 417 | 418 | return 2*pred - 1 419 | 420 | 421 | def train_cost_sensitive_micro(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str): 422 | """Trains a linear model for multilabel data using a one-vs-rest strategy 423 | and cross-validation to pick an optimal asymmetric misclassification cost 424 | for Micro-F1. 425 | Outperforms train_1vsrest in most aspects at the cost of higher 426 | time complexity. 427 | See user guide for more details. 428 | 429 | Args: 430 | y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes. 431 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features. 432 | options (str): The option string passed to liblinear. 433 | 434 | Returns: 435 | A model which can be used in predict_values. 
436 | """ 437 | # Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/ 438 | x, options, bias = prepare_options(x, options) 439 | 440 | y = y.tocsc() 441 | num_class = y.shape[1] 442 | num_feature = x.shape[1] 443 | weights = np.zeros((num_feature, num_class), order='F') 444 | 445 | l = y.shape[0] 446 | perm = np.random.permutation(l) 447 | param_space = [1, 1.33, 1.8, 2.5, 3.67, 6, 13] 448 | bestScore = -np.Inf 449 | 450 | logging.info( 451 | f'Training cost-sensitive model for Micro-F1 on {num_class} labels') 452 | for a in param_space: 453 | tp = fn = fp = 0 454 | for i in tqdm(range(num_class)): 455 | yi = y[:, i].toarray().reshape(-1) 456 | yi = 2*yi - 1 457 | 458 | cv_options = f'{options} -w1 {a}' 459 | pred = cross_validate(yi, x, cv_options, perm) 460 | tp = tp + np.sum(np.logical_and(yi == 1, pred == 1)) 461 | fn = fn + np.sum(np.logical_and(yi == 1, pred == -1)) 462 | fp = fp + np.sum(np.logical_and(yi == -1, pred == 1)) 463 | 464 | score = 2*tp / (2*tp + fn + fp) 465 | if bestScore < score: 466 | bestScore = score 467 | bestA = a 468 | 469 | final_options = f'{options} -w1 {bestA}' 470 | for i in range(num_class): 471 | yi = y[:, i].toarray().reshape(-1) 472 | w = do_train(2*yi - 1, x, final_options) 473 | weights[:, i] = w.ravel() 474 | 475 | return {'weights': np.asmatrix(weights), '-B': bias, 'threshold': 0} 476 | 477 | 478 | def predict_values(model, x: sparse.csr_matrix) -> np.ndarray: 479 | """Calculates the decision values associated with x. 480 | 481 | Args: 482 | model: A model returned from a training function. 483 | x (sparse.csr_matrix): A matrix with dimension number of instances * number of features. 484 | 485 | Returns: 486 | np.ndarray: A matrix with dimension number of instances * number of classes. 
487 | """ 488 | bias = model['-B'] 489 | bias_col = np.full((x.shape[0], 1 if bias > 0 else 0), bias) 490 | num_feature = model['weights'].shape[0] 491 | num_feature -= 1 if bias > 0 else 0 492 | if x.shape[1] < num_feature: 493 | x = sparse.hstack([ 494 | x, 495 | np.zeros((x.shape[0], num_feature - x.shape[1])), 496 | bias_col, 497 | ], 'csr') 498 | else: 499 | x = sparse.hstack([ 500 | x[:, :num_feature], 501 | bias_col, 502 | ], 'csr') 503 | 504 | return (x * model['weights']).A + model['threshold'] 505 | -------------------------------------------------------------------------------- /libmultilabel/linear/metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | 5 | __all__ = ['RPrecision', 6 | 'Precision', 7 | 'F1', 8 | 'MetricCollection', 9 | 'get_metrics', 10 | 'tabulate_metrics'] 11 | 12 | 13 | class RPrecision: 14 | def __init__(self, top_k: int) -> None: 15 | self.top_k = top_k 16 | self.score = 0 17 | self.num_sample = 0 18 | 19 | def update(self, preds: np.ndarray, target: np.ndarray) -> None: 20 | assert preds.shape == target.shape # (batch_size, num_classes) 21 | top_k_ind = np.argpartition(preds, -self.top_k)[:, -self.top_k:] 22 | num_relevant = np.take_along_axis( 23 | target, top_k_ind, axis=-1).sum(axis=-1) # (batch_size, top_k) 24 | self.score += np.nan_to_num( 25 | num_relevant / np.minimum(self.top_k, target.sum(axis=-1)), 26 | posinf=0. 
27 | ).sum() 28 | self.num_sample += preds.shape[0] 29 | 30 | def compute(self) -> float: 31 | return self.score / self.num_sample 32 | 33 | 34 | class Precision: 35 | def __init__(self, num_classes: int, average: str, top_k: int) -> None: 36 | self.top_k = top_k 37 | self.score = 0 38 | self.num_sample = 0 39 | 40 | def update(self, preds: np.ndarray, target: np.ndarray) -> None: 41 | assert preds.shape == target.shape # (batch_size, num_classes) 42 | top_k_ind = np.argpartition(preds, -self.top_k)[:, -self.top_k:] 43 | num_relevant = np.take_along_axis(target, top_k_ind, -1).sum() 44 | self.score += num_relevant / self.top_k 45 | self.num_sample += preds.shape[0] 46 | 47 | def compute(self) -> float: 48 | return self.score / self.num_sample 49 | 50 | 51 | def add_zero_class(labels): 52 | augmented_labels = np.zeros((len(labels), len(labels[0]) + 1), dtype=np.int32) 53 | augmented_labels[:, :-1] = labels 54 | augmented_labels[:, -1] = (np.sum(labels, axis=1) == 0).astype('int32') 55 | return augmented_labels 56 | 57 | 58 | class F1: 59 | def __init__(self, num_classes: int, metric_threshold: float, average: str, 60 | zero: bool, multi_class: bool) -> None: 61 | self.num_classes = num_classes 62 | self.metric_threshold = metric_threshold 63 | if average not in {'macro', 'micro', 'another-macro'}: 64 | raise ValueError('unsupported average') 65 | self.average = average 66 | self.tp = self.fp = self.fn = 0 67 | self.zero = zero 68 | if self.zero: 69 | self.num_classes += 1 70 | self.multi_class = multi_class 71 | 72 | def update(self, preds: np.ndarray, target: np.ndarray) -> None: 73 | if self.multi_class: 74 | preds = np.eye(preds.shape[1])[preds.argmax(1)] 75 | else: 76 | preds = preds > self.metric_threshold 77 | if self.zero: 78 | preds = add_zero_class(preds) 79 | target = add_zero_class(target) 80 | assert preds.shape == target.shape # (batch_size, num_classes) 81 | self.tp += np.logical_and(target == 1, preds == 1).sum(axis=0) 82 | self.fn += 
np.logical_and(target == 1, preds == 0).sum(axis=0) 83 | self.fp += np.logical_and(target == 0, preds == 1).sum(axis=0) 84 | 85 | def compute(self) -> float: 86 | prev_settings = np.seterr('ignore') 87 | 88 | if self.average == 'macro': 89 | score = np.nansum( 90 | 2*self.tp / (2*self.tp + self.fp + self.fn)) / self.num_classes 91 | elif self.average == 'micro': 92 | score = np.nan_to_num(2*np.sum(self.tp) / 93 | np.sum(2*self.tp + self.fp + self.fn)) 94 | elif self.average == 'another-macro': 95 | macro_prec = np.nansum( 96 | self.tp / (self.tp + self.fp)) / self.num_classes 97 | macro_recall = np.nansum( 98 | self.tp / (self.tp + self.fn)) / self.num_classes 99 | score = np.nan_to_num( 100 | 2 * macro_prec * macro_recall / (macro_prec + macro_recall)) 101 | 102 | np.seterr(**prev_settings) 103 | return score 104 | 105 | 106 | class MetricCollection(dict): 107 | def __init__(self, metrics) -> None: 108 | self.metrics = metrics 109 | 110 | def update(self, preds: np.ndarray, target: np.ndarray) -> None: 111 | assert preds.shape == target.shape # (batch_size, num_classes) 112 | for metric in self.metrics.values(): 113 | metric.update(preds, target) 114 | 115 | def compute(self) -> "dict[str, float]": 116 | ret = {} 117 | for name, metric in self.metrics.items(): 118 | ret[name] = metric.compute() 119 | return ret 120 | 121 | 122 | def get_metrics(metric_threshold: float, monitor_metrics: list, num_classes: int, 123 | zero: bool, multi_class: bool): 124 | """Get a collection of metrics by their names. 125 | Args: 126 | metric_threshold (float): The decision value threshold over which a label 127 | is predicted as positive. 128 | monitor_metrics (list): A list of strings naming the metrics. 129 | num_classes (int): The number of classes. 130 | zero (bool) 131 | multi_class (bool) 132 | Returns: 133 | MetricCollection: A metric collection of the list of metrics. 
134 | """ 135 | if monitor_metrics is None: 136 | monitor_metrics = [] 137 | metrics = {} 138 | for metric in monitor_metrics: 139 | if re.match('P@\d+', metric): 140 | metrics[metric] = Precision( 141 | num_classes, average='samples', top_k=int(metric[2:])) 142 | elif re.match('RP@\d+', metric): 143 | metrics[metric] = RPrecision(top_k=int(metric[3:])) 144 | elif metric in {'Another-Macro-F1', 'Macro-F1', 'Micro-F1'}: 145 | metrics[metric] = F1( 146 | num_classes, metric_threshold, average=metric[:-3].lower(), 147 | zero=zero, multi_class=multi_class) 148 | else: 149 | raise ValueError(f'Invalid metric: {metric}') 150 | 151 | return MetricCollection(metrics) 152 | 153 | 154 | def tabulate_metrics(metric_dict, split): 155 | msg = f'====== {split} dataset evaluation result =======\n' 156 | header = '|'.join([f'{k:^18}' for k in metric_dict.keys()]) 157 | values = '|'.join([f'{x * 100:^18.4f}' if isinstance(x, (np.floating, 158 | float)) else f'{x:^18}' for x in metric_dict.values()]) 159 | msg += f"|{header}|\n|{'-----------------:|' * len(metric_dict)}\n|{values}|\n" 160 | return msg 161 | -------------------------------------------------------------------------------- /libmultilabel/linear/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | import re 6 | from array import array 7 | from collections import defaultdict 8 | 9 | import numpy as np 10 | import pandas as pd 11 | import scipy 12 | import scipy.sparse as sparse 13 | from sklearn.feature_extraction.text import TfidfVectorizer 14 | from sklearn.preprocessing import MultiLabelBinarizer 15 | 16 | __all__ = ['Preprocessor'] 17 | 18 | 19 | class Preprocessor: 20 | """Preprocessor is used to load and preprocess input data in LibSVM and LibMultiLabel formats. 21 | The same Preprocessor has to be used for both training and testing data; 22 | see save_pipeline and load_pipeline. 
23 | """ 24 | 25 | def __init__(self, data_format: str) -> None: 26 | """Initializes the preprocessor. 27 | 28 | Args: 29 | data_format (str): The data format used. 'svm' for LibSVM format and 'txt' for LibMultiLabel format. 30 | """ 31 | if not data_format in {'txt', 'svm'}: 32 | raise ValueError(f'unsupported data format {data_format}') 33 | 34 | self.data_format = data_format 35 | 36 | def load_data(self, train_path: str = '', 37 | test_path: str = '', 38 | eval: bool = False, 39 | label_file: str = None, 40 | include_test_labels: bool = False, 41 | remove_no_label_data: bool = False) -> 'dict[str, dict]': 42 | """Loads and preprocesses data. 43 | 44 | Args: 45 | train_path (str): Training data path. Ignored if eval is True. Defaults to ''. 46 | test_path (str): Test data path. Ignored if test_path doesn't exist. Defaults to ''. 47 | eval (bool): If True, ignores training data and uses previously loaded state to preprocess test data. 48 | label_file (str, optional): Path to a file holding all labels. 49 | include_test_labels (bool, optional): Whether to include labels in the test dataset. Defaults to False. 50 | remove_no_label_data (bool, optional): Whether to remove training instances that have no labels. 51 | 52 | Returns: 53 | dict[str, dict]: The training and test data, with keys 'train' and 'test' respectively. The data 54 | has keys 'x' for input features and 'y' for labels. 
55 | """ 56 | if label_file is not None: 57 | logging.info(f'Load labels from {label_file}.') 58 | with open(label_file, 'r') as fp: 59 | self.classes = sorted([s.strip() for s in fp.readlines()]) 60 | else: 61 | if not os.path.exists(test_path) and include_test_labels: 62 | raise ValueError( 63 | f'Specified the inclusion of test labels but test file does not exist') 64 | self.classes = None 65 | self.include_test_labels = include_test_labels 66 | 67 | if self.data_format == 'txt': 68 | data = self._load_txt(train_path, test_path, eval) 69 | elif self.data_format == 'svm': 70 | data = self._load_svm(train_path, test_path, eval) 71 | 72 | if 'train' in data: 73 | num_labels = data['train']['y'].getnnz(axis=1) 74 | num_no_label_data = np.count_nonzero(num_labels == 0) 75 | if num_no_label_data > 0: 76 | if remove_no_label_data: 77 | logging.info(f'Remove {num_no_label_data} instances that have no labels from {train_path}.') 78 | data['train']['x'] = data['train']['x'][num_labels > 0] 79 | data['train']['y'] = data['train']['y'][num_labels > 0] 80 | else: 81 | logging.info(f'Keep {num_no_label_data} instances that have no labels from {train_path}.') 82 | 83 | return data 84 | 85 | def _load_txt(self, train_path, test_path, eval) -> 'dict[str, dict]': 86 | datasets = defaultdict(dict) 87 | if os.path.exists(test_path): 88 | test = read_libmultilabel_format(test_path) 89 | 90 | if not eval: 91 | train = read_libmultilabel_format(train_path) 92 | self._generate_tfidf(train['text']) 93 | 94 | if self.classes or not self.include_test_labels: 95 | self._generate_label_mapping(train['label'], self.classes) 96 | else: 97 | self._generate_label_mapping(train['label'] + test['label']) 98 | datasets['train']['x'] = self.vectorizer.transform(train['text']) 99 | datasets['train']['y'] = self.binarizer.transform( 100 | train['label']).astype('d') 101 | 102 | if os.path.exists(test_path): 103 | datasets['test']['x'] = self.vectorizer.transform(test['text']) 104 | 
datasets['test']['y'] = self.binarizer.transform( 105 | test['label']).astype('d') 106 | 107 | return dict(datasets) 108 | 109 | def _load_svm(self, train_path, test_path, eval) -> 'dict[str, dict]': 110 | datasets = defaultdict(dict) 111 | if os.path.exists(test_path): 112 | ty, tx = read_libsvm_format(test_path) 113 | 114 | if not eval: 115 | y, x = read_libsvm_format(train_path) 116 | if self.classes or not self.include_test_labels: 117 | self._generate_label_mapping(y, self.classes) 118 | else: 119 | self._generate_label_mapping(y + ty) 120 | datasets['train']['x'] = x 121 | datasets['train']['y'] = self.binarizer.transform(y).astype('d') 122 | 123 | if os.path.exists(test_path): 124 | datasets['test']['x'] = tx 125 | datasets['test']['y'] = self.binarizer.transform(ty).astype('d') 126 | return dict(datasets) 127 | 128 | def _generate_tfidf(self, texts): 129 | self.vectorizer = TfidfVectorizer() 130 | self.vectorizer.fit(texts) 131 | 132 | def _generate_label_mapping(self, labels, classes=None): 133 | self.binarizer = MultiLabelBinarizer( 134 | sparse_output=True, classes=classes) 135 | self.binarizer.fit(labels) 136 | 137 | 138 | def read_libmultilabel_format(path: str) -> 'dict[str,list[str]]': 139 | data = pd.read_csv(path, sep='\t', header=None, 140 | dtype=str, 141 | on_bad_lines='skip').fillna('') 142 | if data.shape[1] == 2: 143 | data.columns = ['label', 'text'] 144 | data = data.reset_index() 145 | elif data.shape[1] == 3: 146 | data.columns = ['index', 'label', 'text'] 147 | else: 148 | raise ValueError(f'Expected 2 or 3 columns, got {data.shape[1]}.') 149 | data['label'] = data['label'].map(lambda s: s.split()) 150 | return data.to_dict('list') 151 | 152 | 153 | def read_libsvm_format(file_path: str) -> 'tuple[list[list[int]], sparse.csr_matrix]': 154 | """Read multi-label LIBSVM-format data. 155 | 156 | Args: 157 | file_path (str): Path to file. 158 | 159 | Returns: 160 | tuple[list[list[int]], sparse.csr_matrix]: A tuple of labels and features. 
161 | """ 162 | def as_ints(str): 163 | return [int(s) for s in str.split(',')] 164 | 165 | prob_y = [] 166 | prob_x = array('d') 167 | row_ptr = array('l', [0]) 168 | col_idx = array('l') 169 | 170 | pattern = re.compile(r'(?!^$)([+\-0-9,]+\s+)?(.*\n?)') 171 | for i, line in enumerate(open(file_path)): 172 | m = pattern.fullmatch(line) 173 | try: 174 | labels = m[1] 175 | prob_y.append(as_ints(labels) if labels else []) 176 | features = m[2] or '' 177 | nz = 0 178 | for e in features.split(): 179 | ind, val = e.split(':') 180 | ind, val = int(ind), float(val) 181 | if ind < 1: 182 | raise IndexError(f'invalid svm format at line {i+1} of the file \'{file_path}\' --> Indices should start from one.') 183 | if val != 0: 184 | col_idx.append(ind - 1) 185 | prob_x.append(val) 186 | nz += 1 187 | row_ptr.append(row_ptr[-1]+nz) 188 | except IndexError: 189 | raise 190 | except: 191 | raise ValueError(f'invalid svm format at line {i+1} of the file \'{file_path}\'') 192 | 193 | prob_x = scipy.frombuffer(prob_x, dtype='d') 194 | col_idx = scipy.frombuffer(col_idx, dtype='l') 195 | row_ptr = scipy.frombuffer(row_ptr, dtype='l') 196 | prob_x = sparse.csr_matrix((prob_x, col_idx, row_ptr)) 197 | 198 | return (prob_y, prob_x) 199 | -------------------------------------------------------------------------------- /libmultilabel/linear/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | from pathlib import Path 4 | 5 | from .preprocessor import Preprocessor 6 | 7 | __all__ = ['save_pipeline', 'load_pipeline'] 8 | 9 | 10 | def save_pipeline(checkpoint_dir: str, preprocessor: Preprocessor, model): 11 | """Saves preprocessor and model to checkpoint_dir/linear_pipline.pickle. 12 | 13 | Args: 14 | checkpoint_dir (str): The directory to save to. 15 | preprocessor (Preprocessor): A Preprocessor. 16 | model: A model returned from one of the training functions. 
17 | """ 18 | Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) 19 | checkpoint_path = os.path.join(checkpoint_dir, 'linear_pipeline.pickle') 20 | 21 | with open(checkpoint_path, 'wb') as f: 22 | pickle.dump({ 23 | 'preprocessor': preprocessor, 24 | 'model': model, 25 | }, f) 26 | 27 | 28 | def load_pipeline(checkpoint_path: str) -> tuple: 29 | """Loads preprocessor and model from checkpoint_path. 30 | 31 | Args: 32 | checkpoint_path (str): The path to a previously saved pipeline. 33 | 34 | Returns: 35 | tuple: A tuple of the preprocessor and model. 36 | """ 37 | with open(checkpoint_path, 'rb') as f: 38 | pipeline = pickle.load(f) 39 | return (pipeline['preprocessor'], pipeline['model']) 40 | -------------------------------------------------------------------------------- /libmultilabel/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JamesLYC88/text_classification_baseline_code/6436f567372f3c36547cd52da79516900e5a6148/libmultilabel/nn/__init__.py -------------------------------------------------------------------------------- /libmultilabel/nn/metrics.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | import torch 5 | import torchmetrics.classification 6 | from torchmetrics import Metric, MetricCollection, Precision, Recall, RetrievalNormalizedDCG 7 | from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg 8 | from torchmetrics.utilities.data import select_topk 9 | 10 | 11 | class NDCG(Metric): 12 | """NDCG (Normalized Discounted Cumulative Gain) sums the true scores 13 | ranked in the order induced by the predicted scores after applying a logarithmic discount, 14 | and then divides by the best possible score (Ideal DCG, obtained for a perfect ranking) 15 | to obtain a score between 0 and 1. 
16 | The definition is quoted from: 17 | https://scikit-learn.org/stable/modules/generated/sklearn.metrics.ndcg_score.html 18 | Please find the formal definition here: 19 | https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-ranked-retrieval-results-1.html 20 | 21 | Args: 22 | top_k (int): the top k relevant labels to evaluate. 23 | """ 24 | # If the metric state of one batch is independent of the state of other batches, 25 | # full_state_update can be set to False, 26 | # which leads to more efficient computation with calling update() only once. 27 | # Please find the detailed explanation here: 28 | # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html 29 | full_state_update = False 30 | 31 | def __init__( 32 | self, 33 | top_k 34 | ): 35 | super().__init__() 36 | self.top_k = top_k 37 | self.add_state("ndcg", default=[], dist_reduce_fx="cat") 38 | 39 | def update(self, preds, target): 40 | assert preds.shape == target.shape 41 | # implement batch-wise calculations instead of storing results of all batches 42 | self.ndcg += [self._metric(p, t) for p, t in zip(preds, target)] 43 | 44 | def compute(self): 45 | return torch.stack(self.ndcg).mean() 46 | 47 | def _metric(self, preds, target): 48 | return retrieval_normalized_dcg(preds, target, k=self.top_k) 49 | 50 | 51 | class RPrecision(Metric): 52 | """R-precision calculates precision at k by adjusting k to the minimum value of the number of 53 | relevant labels and k. The definition is given at Appendix C equation (3) of 54 | https://aclanthology.org/P19-1636.pdf 55 | 56 | Args: 57 | top_k (int): the top k relevant labels to evaluate. 58 | """ 59 | # If the metric state of one batch is independent of the state of other batches, 60 | # full_state_update can be set to False, 61 | # which leads to more efficient computation with calling update() only once. 
62 | # Please find the detailed explanation here: 63 | # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html 64 | full_state_update = False 65 | 66 | def __init__( 67 | self, 68 | top_k 69 | ): 70 | super().__init__() 71 | self.top_k = top_k 72 | self.add_state("score", default=torch.tensor(0., dtype=torch.double), dist_reduce_fx="sum") 73 | self.add_state("num_sample", default=torch.tensor(0), dist_reduce_fx="sum") 74 | 75 | def update(self, preds, target): 76 | assert preds.shape == target.shape 77 | binary_topk_preds = select_topk(preds, self.top_k) 78 | target = target.to(dtype=torch.int) 79 | num_relevant = torch.sum(binary_topk_preds & target, dim=-1) 80 | top_ks = torch.tensor([self.top_k]*preds.shape[0]).to(preds.device) 81 | self.score += torch.nan_to_num( 82 | num_relevant / torch.min(top_ks, target.sum(dim=-1)), 83 | posinf=0. 84 | ).sum() 85 | self.num_sample += len(preds) 86 | 87 | def compute(self): 88 | return self.score / self.num_sample 89 | 90 | 91 | class MacroF1(Metric): 92 | """The macro-f1 score computes the average f1 scores of all labels in the dataset. 93 | 94 | Args: 95 | num_classes (int): Total number of classes. 96 | metric_threshold (float): Threshold to monitor for metrics. 97 | another_macro_f1 (bool, optional): Whether to compute the 'Another-Macro-F1' score. 98 | The 'Another-Macro-F1' is the f1 value of macro-precision and macro-recall. 99 | This variant of macro-f1 is less preferred but is used in some works. 100 | Please refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf]. 101 | Defaults to False. 102 | """ 103 | # If the metric state of one batch is independent of the state of other batches, 104 | # full_state_update can be set to False, 105 | # which leads to more efficient computation with calling update() only once. 
106 | # Please find the detailed explanation here: 107 | # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html 108 | full_state_update = False 109 | 110 | def __init__( 111 | self, 112 | num_classes, 113 | metric_threshold, 114 | another_macro_f1=False 115 | ): 116 | super().__init__() 117 | self.metric_threshold = metric_threshold 118 | self.another_macro_f1 = another_macro_f1 119 | self.add_state("preds_sum", default=torch.zeros(num_classes, dtype=torch.double)) 120 | self.add_state("target_sum", default=torch.zeros(num_classes, dtype=torch.double)) 121 | self.add_state("tp_sum", default=torch.zeros(num_classes, dtype=torch.double)) 122 | 123 | def update(self, preds, target): 124 | assert preds.shape == target.shape 125 | preds = torch.where(preds > self.metric_threshold, 1, 0) 126 | self.preds_sum = torch.add(self.preds_sum, preds.sum(dim=0)) 127 | self.target_sum = torch.add(self.target_sum, target.sum(dim=0)) 128 | self.tp_sum = torch.add(self.tp_sum, (preds & target).sum(dim=0)) 129 | 130 | def compute(self): 131 | if self.another_macro_f1: 132 | macro_prec = torch.mean(torch.nan_to_num(self.tp_sum / self.preds_sum, posinf=0.)) 133 | macro_recall = torch.mean(torch.nan_to_num(self.tp_sum / self.target_sum, posinf=0.)) 134 | return 2 * (macro_prec * macro_recall) / (macro_prec + macro_recall + 1e-10) 135 | else: 136 | label_f1 = 2 * self.tp_sum / (self.preds_sum + self.target_sum + 1e-10) 137 | return torch.mean(label_f1) 138 | 139 | 140 | def add_zero_class(labels): 141 | augmented_labels = torch.zeros((len(labels), len(labels[0]) + 1), dtype=torch.int32).to(labels.device) 142 | augmented_labels[:, :-1] = labels 143 | augmented_labels[:, -1] = (torch.sum(labels, axis=1) == 0).type(torch.int32) 144 | return augmented_labels 145 | 146 | 147 | class F1(Metric): 148 | full_state_update = False 149 | 150 | def __init__( 151 | self, 152 | num_classes, 153 | metric_threshold, 154 | average, 155 | zero, 156 | multi_class 157 | ): 158 | 
super().__init__() 159 | self.metric_threshold = metric_threshold 160 | if average not in {'macro', 'micro', 'another-macro'}: 161 | raise ValueError('unsupported average') 162 | self.average = average 163 | self.zero = zero 164 | if self.zero: 165 | num_classes += 1 166 | self.multi_class = multi_class 167 | self.add_state("preds_sum", default=torch.zeros(num_classes, dtype=torch.double)) 168 | self.add_state("target_sum", default=torch.zeros(num_classes, dtype=torch.double)) 169 | self.add_state("tp_sum", default=torch.zeros(num_classes, dtype=torch.double)) 170 | 171 | def update(self, preds, target): 172 | if self.multi_class: 173 | preds = torch.eye(preds.shape[1])[preds.argmax(1)].type(torch.int32).to(preds.device) 174 | else: 175 | preds = torch.where(preds > self.metric_threshold, 1, 0) 176 | if self.zero: 177 | preds = add_zero_class(preds) 178 | target = add_zero_class(target) 179 | assert preds.shape == target.shape 180 | self.preds_sum = torch.add(self.preds_sum, preds.sum(dim=0)) 181 | self.target_sum = torch.add(self.target_sum, target.sum(dim=0)) 182 | self.tp_sum = torch.add(self.tp_sum, (preds & target).sum(dim=0)) 183 | 184 | def compute(self): 185 | if self.average == 'another-macro': 186 | macro_prec = torch.mean(torch.nan_to_num(self.tp_sum / self.preds_sum, posinf=0.)) 187 | macro_recall = torch.mean(torch.nan_to_num(self.tp_sum / self.target_sum, posinf=0.)) 188 | return 2 * (macro_prec * macro_recall) / (macro_prec + macro_recall + 1e-10) 189 | elif self.average == 'macro': 190 | label_f1 = 2 * self.tp_sum / (self.preds_sum + self.target_sum + 1e-10) 191 | return torch.mean(label_f1) 192 | elif self.average == 'micro': 193 | return 2 * torch.sum(self.tp_sum) / (torch.sum(self.preds_sum + self.target_sum) + 1e-10) 194 | 195 | 196 | def get_metrics(metric_threshold, monitor_metrics, num_classes, zero, multi_class): 197 | """Map monitor metrics to the corresponding classes defined in `torchmetrics.Metric` 198 | 
(https://torchmetrics.readthedocs.io/en/latest/references/modules.html). 199 | 200 | Args: 201 | metric_threshold (float): Threshold to monitor for metrics. 202 | monitor_metrics (list): Metrics to monitor while validating. 203 | num_classes (int): Total number of classes. 204 | zero (bool) 205 | multi_class (bool) 206 | 207 | Raises: 208 | ValueError: The metric is invalid if: 209 | (1) It is not one of 'P@k', 'R@k', 'RP@k', 'nDCG@k', 'Micro-Precision', 210 | 'Micro-Recall', 'Micro-F1', 'Macro-F1', 'Another-Macro-F1', or a 211 | `torchmetrics.Metric`. 212 | (2) Metric@k: k is greater than `num_classes`. 213 | 214 | Returns: 215 | torchmetrics.MetricCollection: A collections of `torchmetrics.Metric` for evaluation. 216 | """ 217 | if monitor_metrics is None: 218 | monitor_metrics = [] 219 | 220 | metrics = dict() 221 | for metric in monitor_metrics: 222 | if isinstance(metric, Metric): # customized metric 223 | metrics[type(metric).__name__] = metric 224 | continue 225 | 226 | match_top_k = re.match(r'\b(P|R|RP|nDCG)\b@(\d+)', metric) 227 | match_metric = re.match(r'\b(Micro|Macro)\b-\b(Precision|Recall|F1)\b', metric) 228 | 229 | if match_top_k: 230 | metric_abbr = match_top_k.group(1) # P, R, PR, or nDCG 231 | top_k = int(match_top_k.group(2)) 232 | if top_k >= num_classes: 233 | raise ValueError( 234 | f'Invalid metric: {metric}. {top_k} is greater than {num_classes}.') 235 | if metric_abbr == 'P': 236 | metrics[metric] = Precision(num_classes, average='samples', top_k=top_k) 237 | elif metric_abbr == 'R': 238 | metrics[metric] = Recall(num_classes, average='samples', top_k=top_k) 239 | elif metric_abbr == 'RP': 240 | metrics[metric] = RPrecision(top_k=top_k) 241 | elif metric_abbr == 'nDCG': 242 | metrics[metric] = NDCG(top_k=top_k) 243 | # The implementation in torchmetrics stores the prediction/target of all batches, 244 | # which can lead to CUDA out of memory. 
245 | # metrics[metric] = RetrievalNormalizedDCG(k=top_k) 246 | elif metric == 'Another-Macro-F1': 247 | metrics[metric] = F1(num_classes, metric_threshold, average='another-macro', 248 | zero=zero, multi_class=multi_class) 249 | elif metric == 'Macro-F1': 250 | metrics[metric] = F1(num_classes, metric_threshold, average='macro', 251 | zero=zero, multi_class=multi_class) 252 | elif metric == 'Micro-F1': 253 | metrics[metric] = F1(num_classes, metric_threshold, average='micro', 254 | zero=zero, multi_class=multi_class) 255 | elif match_metric: 256 | average_type = match_metric.group(1).lower() # Micro 257 | metric_type = match_metric.group(2) # Precision, Recall, or F1 258 | metric_type = metric_type.replace('F1', 'F1Score') # to be determined 259 | metrics[metric] = getattr(torchmetrics.classification, metric_type)( 260 | num_classes, metric_threshold, average=average_type) 261 | else: 262 | raise ValueError( 263 | f'Invalid metric: {metric}. Make sure the metric is in the right format: Macro/Micro-Precision/Recall/F1 (ex. Micro-F1)') 264 | 265 | # If compute_groups is set to True (default), incorrect results may be calculated. 266 | # Please refer to https://github.com/Lightning-AI/metrics/issues/746 for more details. 
267 | return MetricCollection(metrics, compute_groups=False) 268 | 269 | 270 | def tabulate_metrics(metric_dict, split): 271 | msg = f'====== {split} dataset evaluation result =======\n' 272 | header = '|'.join([f'{k:^18}' for k in metric_dict.keys()]) 273 | values = '|'.join([f'{x * 100:^18.4f}' if isinstance(x, (np.floating, float)) else f'{x:^18}' for x in metric_dict.values()]) 274 | msg += f"|{header}|\n|{'-----------------:|' * len(metric_dict)}\n|{values}|\n" 275 | return msg 276 | -------------------------------------------------------------------------------- /libmultilabel/nn/model.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | import numpy as np 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | 9 | from ..common_utils import dump_log, argsort_top_k 10 | from ..nn.metrics import get_metrics, tabulate_metrics 11 | 12 | 13 | class MultiLabelModel(pl.LightningModule): 14 | """Abstract class handling Pytorch Lightning training flow 15 | 16 | Args: 17 | num_classes (int): Total number of classes. 18 | learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001. 19 | optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'. 20 | momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9. 21 | weight_decay (int, optional): Weight decay factor. Defaults to 0. 22 | metric_threshold (float, optional): Threshold to monitor for metrics. Defaults to 0.5. 23 | monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None. 24 | log_path (str): Path to a directory holding the log files and models. 25 | silent (bool, optional): Enable silent mode. Defaults to False. 26 | save_k_predictions (int, optional): Save top k predictions on test set. Defaults to 0. 
27 | zero (bool, optional) 28 | multi_class (bool, optional) 29 | """ 30 | 31 | def __init__( 32 | self, 33 | num_classes, 34 | learning_rate=0.0001, 35 | optimizer='adam', 36 | momentum=0.9, 37 | weight_decay=0, 38 | metric_threshold=0.5, 39 | monitor_metrics=None, 40 | log_path=None, 41 | silent=False, 42 | save_k_predictions=0, 43 | zero=False, 44 | multi_class=False, 45 | **kwargs 46 | ): 47 | super().__init__() 48 | 49 | # optimizer 50 | self.learning_rate = learning_rate 51 | self.optimizer = optimizer 52 | self.momentum = momentum 53 | self.weight_decay = weight_decay 54 | 55 | # dump log 56 | self.log_path = log_path 57 | self.silent = silent 58 | self.save_k_predictions = save_k_predictions 59 | 60 | # metrics for evaluation 61 | self.eval_metric = get_metrics(metric_threshold, monitor_metrics, num_classes, 62 | zero, multi_class) 63 | 64 | @abstractmethod 65 | def shared_step(self, batch): 66 | """Return loss and predicted logits""" 67 | return NotImplemented 68 | 69 | def configure_optimizers(self): 70 | """Initialize an optimizer for the free parameters of the network. 
71 | """ 72 | parameters = [p for p in self.parameters() if p.requires_grad] 73 | optimizer_name = self.optimizer 74 | if optimizer_name == 'sgd': 75 | optimizer = optim.SGD(parameters, self.learning_rate, 76 | momentum=self.momentum, 77 | weight_decay=self.weight_decay) 78 | elif optimizer_name == 'adam': 79 | optimizer = optim.Adam(parameters, 80 | weight_decay=self.weight_decay, 81 | lr=self.learning_rate) 82 | elif optimizer_name == 'adamw': 83 | optimizer = optim.AdamW(parameters, 84 | weight_decay=self.weight_decay, 85 | lr=self.learning_rate) 86 | elif optimizer_name == 'adamax': 87 | optimizer = optim.Adamax(parameters, 88 | weight_decay=self.weight_decay, 89 | lr=self.learning_rate) 90 | else: 91 | raise RuntimeError( 92 | 'Unsupported optimizer: {self.optimizer}') 93 | 94 | torch.nn.utils.clip_grad_value_(parameters, 0.5) 95 | 96 | return optimizer 97 | 98 | def training_step(self, batch, batch_idx): 99 | loss, _ = self.shared_step(batch) 100 | return loss 101 | 102 | def validation_step(self, batch, batch_idx): 103 | return self._shared_eval_step(batch, batch_idx) 104 | 105 | def validation_step_end(self, batch_parts): 106 | return self._shared_eval_step_end(batch_parts) 107 | 108 | def validation_epoch_end(self, step_outputs): 109 | return self._shared_eval_epoch_end(step_outputs, 'val') 110 | 111 | def test_step(self, batch, batch_idx): 112 | return self._shared_eval_step(batch, batch_idx) 113 | 114 | def test_step_end(self, batch_parts): 115 | return self._shared_eval_step_end(batch_parts) 116 | 117 | def test_epoch_end(self, step_outputs): 118 | return self._shared_eval_epoch_end(step_outputs, 'test') 119 | 120 | def _shared_eval_step(self, batch, batch_idx): 121 | loss, pred_logits = self.shared_step(batch) 122 | return {'batch_idx': batch_idx, 123 | 'loss': loss, 124 | 'pred_scores': torch.sigmoid(pred_logits), 125 | 'target': batch['labels']} 126 | 127 | def _shared_eval_step_end(self, batch_parts): 128 | batch_size, num_classes = 
batch_parts['target'].shape 129 | # `indexes` indicates which index a prediction belongs. `RetrievalNormalizedDCG` 130 | # will compute the mean of nDCG scores over each prediction. 131 | indexes = torch.arange( 132 | batch_size*batch_parts['batch_idx'], batch_size*(batch_parts['batch_idx']+1)) 133 | indexes = indexes.unsqueeze(1).repeat(1, num_classes) 134 | return self.eval_metric.update( 135 | preds=batch_parts['pred_scores'], 136 | target=batch_parts['target'], 137 | indexes=indexes 138 | ) 139 | 140 | def _shared_eval_epoch_end(self, step_outputs, split): 141 | """Get scores such as `Micro-F1`, `Macro-F1`, and monitor metrics defined 142 | in the configuration file in the end of an epoch. 143 | 144 | Args: 145 | step_outputs (list): List of the return values from the val or test step end. 146 | split (str): One of the `val` or `test`. 147 | 148 | Returns: 149 | metric_dict (dict): Scores for all metrics in the dictionary format. 150 | """ 151 | metric_dict = self.eval_metric.compute() 152 | self.log_dict(metric_dict) 153 | for k, v in metric_dict.items(): 154 | metric_dict[k] = v.item() 155 | if self.log_path: 156 | dump_log(metrics=metric_dict, split=split, log_path=self.log_path) 157 | self.print(tabulate_metrics(metric_dict, split)) 158 | self.eval_metric.reset() 159 | return metric_dict 160 | 161 | def predict_step(self, batch, batch_idx, dataloader_idx): 162 | """`predict_step` is triggered when calling `trainer.predict()`. 163 | This function is used to get the top-k labels and their prediction scores. 164 | 165 | Args: 166 | batch (dict): A batch of text and label. 167 | batch_idx (int): Index of current batch. 168 | dataloader_idx (int): Index of current dataloader. 169 | 170 | Returns: 171 | dict: Top k label indexes and the prediction scores. 
172 | """ 173 | _, pred_logits = self.shared_step(batch) 174 | pred_scores = pred_logits.detach().cpu().numpy() 175 | k = self.save_k_predictions 176 | top_k_idx = argsort_top_k(pred_scores, k, axis=1) 177 | top_k_scores = np.take_along_axis(pred_scores, top_k_idx, axis=1) 178 | 179 | return {'top_k_pred': top_k_idx, 180 | 'top_k_pred_scores': top_k_scores} 181 | 182 | def print(self, *args, **kwargs): 183 | """Prints only from process 0 and not in silent mode. Use this in any 184 | distributed mode to log only once.""" 185 | 186 | if not self.silent: 187 | # print() in LightningModule to print only from process 0 188 | super().print(*args, **kwargs) 189 | 190 | 191 | class Model(MultiLabelModel): 192 | """A class that implements `MultiLabelModel` for initializing and training a neural network. 193 | 194 | Args: 195 | classes (list): List of class names. 196 | word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices. 197 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 198 | network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN). 199 | log_path (str): Path to a directory holding the log files and models. 200 | enable_ce_loss (bool) 201 | """ 202 | def __init__( 203 | self, 204 | classes, 205 | word_dict, 206 | embed_vecs, 207 | network, 208 | log_path=None, 209 | enable_ce_loss=False, 210 | **kwargs 211 | ): 212 | super().__init__(num_classes=len(classes), log_path=log_path, **kwargs) 213 | self.save_hyperparameters() 214 | self.word_dict = word_dict 215 | self.embed_vecs = embed_vecs 216 | self.classes = classes 217 | self.network = network 218 | if not enable_ce_loss: 219 | self.loss_fn = F.binary_cross_entropy_with_logits 220 | else: 221 | self.loss_fn = F.cross_entropy 222 | 223 | def shared_step(self, batch): 224 | """Return loss and predicted logits of the network. 225 | 226 | Args: 227 | batch (dict): A batch of text and label. 
228 | 229 | Returns: 230 | loss (torch.Tensor): Binary cross-entropy between target and predict logits. 231 | pred_logits (torch.Tensor): The predict logits (batch_size, num_classes). 232 | """ 233 | target_labels = batch['labels'] 234 | outputs = self.network(batch) 235 | pred_logits = outputs['logits'] 236 | loss = self.loss_fn(pred_logits, target_labels.float()) 237 | return loss, pred_logits 238 | -------------------------------------------------------------------------------- /libmultilabel/nn/networks/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .bert import BERT 4 | from .bert_attention import BERTAttention 5 | from .caml import CAML 6 | from .kim_cnn import KimCNN 7 | from .xml_cnn import XMLCNN 8 | from .labelwise_attention_networks import BiGRULWAN 9 | from .labelwise_attention_networks import BiLSTMLWAN 10 | from .labelwise_attention_networks import BiLSTMLWMHAN 11 | from .labelwise_attention_networks import CNNLWAN 12 | 13 | 14 | def get_init_weight_func(init_weight): 15 | def init_weight_func(m): 16 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d): 17 | getattr(nn.init, init_weight+ '_')(m.weight) 18 | return init_weight_func 19 | -------------------------------------------------------------------------------- /libmultilabel/nn/networks/bert.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from transformers import AutoModelForSequenceClassification 3 | 4 | from .hierbert import HierarchicalBert 5 | 6 | 7 | class BERT(nn.Module): 8 | """BERT 9 | 10 | Args: 11 | num_classes (int): Total number of classes. 12 | dropout (float): The dropout rate of the word embedding. Defaults to 0.1. 13 | lm_weight (str): Pretrained model name or path. Defaults to 'bert-base-cased'. 
14 | """ 15 | def __init__( 16 | self, 17 | num_classes, 18 | dropout=0.1, 19 | lm_weight='bert-base-cased', 20 | hierarchical=False, 21 | max_segments=64, 22 | max_seg_length=128, 23 | **kwargs 24 | ): 25 | super().__init__() 26 | self.lm = AutoModelForSequenceClassification.from_pretrained(lm_weight, 27 | num_labels=num_classes, 28 | hidden_dropout_prob=dropout, 29 | torchscript=True) 30 | self.hierarchical = hierarchical 31 | if self.hierarchical: 32 | segment_encoder = self.lm.bert 33 | model_encoder = HierarchicalBert(encoder=segment_encoder, 34 | max_segments=max_segments, 35 | max_segment_length=max_seg_length) 36 | self.lm.bert = model_encoder 37 | 38 | def forward(self, input): 39 | input_ids = input['input_ids'] 40 | attention_mask = input['attention_mask'] 41 | x = self.lm(input_ids, attention_mask=attention_mask)[0] # (batch_size, num_classes) 42 | return {'logits': x} 43 | -------------------------------------------------------------------------------- /libmultilabel/nn/networks/bert_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.utils.rnn import pad_sequence 3 | from transformers import AutoModel 4 | 5 | from .modules import LabelwiseAttention, LabelwiseLinearOutput, LabelwiseMultiHeadAttention 6 | 7 | 8 | class BERTAttention(nn.Module): 9 | """BERT + Label-wise Document Attention or Multi-Head Attention 10 | 11 | Args: 12 | num_classes (int): Total number of classes. 13 | dropout (float): The dropout rate of the word embedding. Defaults to 0.2. 14 | lm_weight (str): Pretrained model name or path. Defaults to 'bert-base-cased'. 15 | lm_window (int): Length of the subsequences to be split before feeding them to 16 | the language model. Defaults to 512. 17 | num_heads (int): The number of parallel attention heads. Defaults to 8. 18 | attention_type (str): Type of attention to use (caml or multihead). Defaults to 'multihead'. 
19 | attention_dropout (float): The dropout rate for the attention. Defaults to 0.0. 20 | """ 21 | def __init__( 22 | self, 23 | num_classes, 24 | dropout=0.2, 25 | lm_weight='bert-base-cased', 26 | lm_window=512, 27 | num_heads=8, 28 | attention_type='multihead', 29 | attention_dropout=0.0, 30 | **kwargs 31 | ): 32 | super().__init__() 33 | self.lm_window = lm_window 34 | self.attention_type = attention_type 35 | 36 | self.lm = AutoModel.from_pretrained(lm_weight, torchscript=True) 37 | self.embed_drop = nn.Dropout(p=dropout) 38 | 39 | self.attention_type = attention_type 40 | assert attention_type in ['singlehead', 'multihead'], "attention_type must be 'singlehead' or 'multihead'" 41 | if attention_type == 'singlehead': 42 | self.attention = LabelwiseAttention(self.lm.config.hidden_size, num_classes) 43 | else: 44 | self.attention = LabelwiseMultiHeadAttention( 45 | self.lm.config.hidden_size, num_classes, num_heads, attention_dropout) 46 | 47 | # Final layer: create a matrix to use for the #labels binary classifiers 48 | self.output = LabelwiseLinearOutput(self.lm.config.hidden_size, num_classes) 49 | 50 | def lm_feature(self, input_ids): 51 | """BERT takes an input of a sequence of no more than 512 tokens. 52 | Therefore, long sequence are split into subsequences of size `lm_window`, which is a number no greater than 512. 53 | If the split subsequence is shorter than `lm_window`, pad it with the pad token. 54 | 55 | Args: 56 | input_ids (torch.Tensor): Input ids of the sequence with shape (batch_size, sequence_length). 57 | 58 | Returns: 59 | torch.Tensor: The representation of the sequence. 
60 | """ 61 | if input_ids.size(-1) <= self.lm_window: 62 | return self.lm(input_ids, attention_mask=input_ids != self.lm.config.pad_token_id)[0] 63 | else: 64 | inputs = [] 65 | batch_indexes = [] 66 | seq_lengths = [] 67 | for token_id in input_ids: 68 | indexes = [] 69 | seq_length = (token_id != self.lm.config.pad_token_id).sum() 70 | seq_lengths.append(seq_length) 71 | for i in range(0, seq_length, self.lm_window): 72 | indexes.append(len(inputs)) 73 | inputs.append(token_id[i: i + self.lm_window]) 74 | batch_indexes.append(indexes) 75 | 76 | padded_inputs = pad_sequence(inputs, batch_first=True) 77 | last_hidden_states = self.lm( 78 | padded_inputs, attention_mask=padded_inputs != self.lm.config.pad_token_id)[0] 79 | 80 | x = [] 81 | for seq_l, mapping in zip(seq_lengths, batch_indexes): 82 | last_hidden_state = last_hidden_states[mapping].view( 83 | -1, last_hidden_states.size(-1))[:seq_l, :] 84 | x.append(last_hidden_state) 85 | return pad_sequence(x, batch_first=True) 86 | 87 | def forward(self, input): 88 | input_ids = input['text'] # (batch_size, sequence_length) 89 | attention_mask = input_ids == self.lm.config.pad_token_id 90 | x = self.lm_feature(input_ids) # (batch_size, sequence_length, lm_hidden_size) 91 | x = self.embed_drop(x) 92 | 93 | # Apply per-label attention. 
94 | if self.attention_type == 'singlehead': 95 | logits, attention = self.attention(x) 96 | else: 97 | logits, attention = self.attention(x, attention_mask) 98 | 99 | # Compute a probability for each label 100 | x = self.output(logits) 101 | return {'logits': x, 'attention': attention} 102 | -------------------------------------------------------------------------------- /libmultilabel/nn/networks/caml.py: -------------------------------------------------------------------------------- 1 | from math import floor 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.init import xavier_uniform_ 7 | 8 | 9 | class CAML(nn.Module): 10 | """CAML (Convolutional Attention for Multi-Label classification) 11 | Follows the work of Mullenbach et al. [https://aclanthology.org/N18-1100.pdf] 12 | This class is for reproducing the results in the paper. 13 | Use CNNLWAN instead for better modularization. 14 | 15 | Args: 16 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 17 | num_classes (int): Total number of classes. 18 | filter_sizes (list): Size of convolutional filters. 19 | num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 50. 20 | dropout (float): The dropout rate of the word embedding. Defaults to 0.2. 21 | """ 22 | def __init__( 23 | self, 24 | embed_vecs, 25 | num_classes, 26 | filter_sizes=None, 27 | num_filter_per_size=50, 28 | dropout=0.2, 29 | ): 30 | super(CAML, self).__init__() 31 | if not filter_sizes and len(filter_sizes) != 1: 32 | raise ValueError(f'CAML expect 1 filter size. 
Got filter_sizes={filter_sizes}') 33 | filter_size = filter_sizes[0] 34 | 35 | self.embedding = nn.Embedding(len(embed_vecs), embed_vecs.shape[1], padding_idx=0) 36 | self.embedding.weight.data = embed_vecs.clone() 37 | self.embed_drop = nn.Dropout(p=dropout) 38 | 39 | # Initialize conv layer 40 | self.conv = nn.Conv1d(embed_vecs.shape[1], num_filter_per_size, kernel_size=filter_size, padding=int(floor(filter_size/2))) 41 | xavier_uniform_(self.conv.weight) 42 | 43 | """Context vectors for computing attention with 44 | (in_features, out_features) = (num_filter_per_size, num_classes) 45 | """ 46 | self.Q = nn.Linear(num_filter_per_size, num_classes) 47 | xavier_uniform_(self.Q.weight) 48 | 49 | # Final layer: create a matrix to use for the #labels binary classifiers 50 | self.output = nn.Linear(num_filter_per_size, num_classes) 51 | xavier_uniform_(self.output.weight) 52 | 53 | def forward(self, input): 54 | # Get embeddings and apply dropout 55 | x = self.embedding(input['text']) # (batch_size, length, embed_dim) 56 | x = self.embed_drop(x) 57 | x = x.transpose(1, 2) # (batch_size, embed_dim, length) 58 | 59 | """ Apply convolution and nonlinearity (tanh). The shapes are: 60 | - self.conv(x): (batch_size, num_filte_per_size, length) 61 | - x after transposing the first and the second dimension and applying 62 | the activation function: (batch_size, length, num_filte_per_size) 63 | """ 64 | Z = torch.tanh(self.conv(x).transpose(1, 2)) 65 | 66 | """Apply per-label attention. 
The shapes are: 67 | - Q.weight: (num_classes, num_filte_per_size) 68 | - matrix product of U.weight and x: (batch_size, num_classes, length) 69 | - alpha: (batch_size, num_classes, length) 70 | """ 71 | alpha = torch.softmax(self.Q.weight.matmul(Z.transpose(1, 2)), dim=2) 72 | 73 | # Document representations are weighted sums using the attention 74 | E = alpha.matmul(Z) # (batch_size, num_classes, num_filter_per_size) 75 | 76 | # Compute a probability for each label 77 | logits = self.output.weight.mul(E).sum(dim=2).add(self.output.bias) # (batch_size, num_classes) 78 | 79 | return {'logits': logits, 'attention': alpha} 80 | -------------------------------------------------------------------------------- /libmultilabel/nn/networks/hierbert.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple 3 | 4 | import torch 5 | import numpy as np 6 | from torch import nn 7 | from transformers.file_utils import ModelOutput 8 | 9 | 10 | @dataclass 11 | class SimpleOutput(ModelOutput): 12 | last_hidden_state: torch.FloatTensor = None 13 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 14 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 15 | attentions: Optional[Tuple[torch.FloatTensor]] = None 16 | cross_attentions: Optional[Tuple[torch.FloatTensor]] = None 17 | 18 | 19 | def sinusoidal_init(num_embeddings: int, embedding_dim: int): 20 | # keep dim 0 for padding token position encoding zero vector 21 | position_enc = np.array([ 22 | [pos / np.power(10000, 2 * i / embedding_dim) for i in range(embedding_dim)] 23 | if pos != 0 else np.zeros(embedding_dim) for pos in range(num_embeddings)]) 24 | 25 | position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i 26 | position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1 27 | return torch.from_numpy(position_enc).type(torch.FloatTensor) 28 | 29 | 30 | class 
HierarchicalBert(nn.Module): 31 | 32 | def __init__(self, encoder, max_segments=64, max_segment_length=128): 33 | super(HierarchicalBert, self).__init__() 34 | supported_models = ['bert', 'roberta', 'deberta'] 35 | assert encoder.config.model_type in supported_models # other model types are not supported so far 36 | # Pre-trained segment (token-wise) encoder, e.g., BERT 37 | self.encoder = encoder 38 | # Specs for the segment-wise encoder 39 | self.hidden_size = encoder.config.hidden_size 40 | self.max_segments = max_segments 41 | self.max_segment_length = max_segment_length 42 | # Init sinusoidal positional embeddings 43 | self.seg_pos_embeddings = nn.Embedding(max_segments + 1, encoder.config.hidden_size, 44 | padding_idx=0, 45 | _weight=sinusoidal_init(max_segments + 1, encoder.config.hidden_size)) 46 | # Init segment-wise transformer-based encoder 47 | self.seg_encoder = nn.Transformer(d_model=encoder.config.hidden_size, 48 | nhead=encoder.config.num_attention_heads, 49 | batch_first=True, dim_feedforward=encoder.config.intermediate_size, 50 | activation=encoder.config.hidden_act, 51 | dropout=encoder.config.hidden_dropout_prob, 52 | layer_norm_eps=encoder.config.layer_norm_eps, 53 | num_encoder_layers=2, num_decoder_layers=0).encoder 54 | 55 | def forward(self,  # only input_ids, attention_mask and token_type_ids are read below; the other arguments are accepted but unused (presumably to mirror the HF encoder call signature) 56 | input_ids=None, 57 | attention_mask=None, 58 | token_type_ids=None, 59 | position_ids=None, 60 | head_mask=None, 61 | inputs_embeds=None, 62 | labels=None, 63 | output_attentions=None, 64 | output_hidden_states=None, 65 | return_dict=None, 66 | ): 67 | # Hypothetical Example 68 | # Batch of 4 documents: (batch_size, n_segments, max_segment_length) --> (4, 64, 128) 69 | # BERT-BASE encoder: 768 hidden units 70 | 71 | # Squash samples and segments into a single axis (batch_size * n_segments, max_segment_length) --> (256, 128) 72 | input_ids_reshape = input_ids.contiguous().view(-1, input_ids.size(-1)) 73 | attention_mask_reshape = attention_mask.contiguous().view(-1, attention_mask.size(-1)) 74 | if 
token_type_ids is not None: 75 | token_type_ids_reshape = token_type_ids.contiguous().view(-1, token_type_ids.size(-1)) 76 | else: 77 | token_type_ids_reshape = None 78 | 79 | # Encode segments with BERT --> (256, 128, 768) 80 | encoder_outputs = self.encoder(input_ids=input_ids_reshape, 81 | attention_mask=attention_mask_reshape, 82 | token_type_ids=token_type_ids_reshape)[0] 83 | 84 | # Reshape back to (batch_size, n_segments, max_segment_length, output_size) --> (4, 64, 128, 768); the view requires inputs padded to exactly max_segments x max_segment_length 85 | encoder_outputs = encoder_outputs.contiguous().view(input_ids.size(0), self.max_segments, 86 | self.max_segment_length, 87 | self.hidden_size) 88 | 89 | # Gather CLS outputs per segment --> (4, 64, 768) 90 | encoder_outputs = encoder_outputs[:, :, 0] 91 | 92 | # Infer real segments, i.e., mask paddings 93 | seg_mask = (torch.sum(input_ids, 2) != 0).to(input_ids.dtype)  # NOTE(review): a segment counts as padding only if its ids sum to 0, i.e. padded segments are assumed to be all id 0 — confirm against the data pipeline (RoBERTa's pad id is 1) 94 | # Infer and collect segment positional embeddings 95 | seg_positions = torch.arange(1, self.max_segments + 1).to(input_ids.device) * seg_mask 96 | # Add segment positional embeddings to segment inputs 97 | encoder_outputs += self.seg_pos_embeddings(seg_positions) 98 | 99 | # Encode segments with segment-wise transformer (no padding mask is passed, so padded segments are attended to as well) 100 | seg_encoder_outputs = self.seg_encoder(encoder_outputs) 101 | 102 | # Collect document representation by max-pooling over all segments (padded ones included) 103 | outputs, _ = torch.max(seg_encoder_outputs, 1) 104 | 105 | return SimpleOutput(last_hidden_state=outputs, hidden_states=outputs) 106 | -------------------------------------------------------------------------------- /libmultilabel/nn/networks/kim_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .modules import Embedding, CNNEncoder 5 | 6 | 7 | class KimCNN(nn.Module): 8 | """KimCNN 9 | 10 | Args: 11 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 12 | num_classes (int): Total number of classes. 
13 | filter_sizes (list): The sizes of convolutional filters. Required despite the `None` default, since `len(filter_sizes)` is taken in `__init__`. 14 | num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 128. 15 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2. 16 | encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0. 17 | activation (str): Activation function to be used. Defaults to 'relu'. 18 | """ 19 | def __init__( 20 | self, 21 | embed_vecs, 22 | num_classes, 23 | filter_sizes=None, 24 | num_filter_per_size=128, 25 | embed_dropout=0.2, 26 | encoder_dropout=0, 27 | activation='relu' 28 | ): 29 | super(KimCNN, self).__init__() 30 | self.embedding = Embedding(embed_vecs, embed_dropout) 31 | self.encoder = CNNEncoder(embed_vecs.shape[1], filter_sizes, 32 | num_filter_per_size, activation, 33 | encoder_dropout, num_pool=1) 34 | conv_output_size = num_filter_per_size * len(filter_sizes) 35 | self.linear = nn.Linear(conv_output_size, num_classes) 36 | 37 | def forward(self, input): 38 | x = self.embedding(input['text']) # (batch_size, length, embed_dim) 39 | x = self.encoder(x) # (batch_size, num_filter, 1) 40 | x = torch.squeeze(x, 2) # (batch_size, num_filter) 41 | x = self.linear(x) # (batch_size, num_classes) 42 | return {'logits': x} 43 | -------------------------------------------------------------------------------- /libmultilabel/nn/networks/labelwise_attention_networks.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch.nn as nn 4 | 5 | from .modules import Embedding, GRUEncoder, LSTMEncoder, CNNEncoder, LabelwiseAttention, LabelwiseMultiHeadAttention, LabelwiseLinearOutput 6 | 7 | 8 | class LabelwiseAttentionNetwork(ABC, nn.Module): 9 | """Base class for Labelwise Attention Network 10 | 11 | Args: 12 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 13 | num_classes (int): Total number of classes. 
14 | embed_dropout (float): The dropout rate of the word embedding. 15 | encoder_dropout (float): The dropout rate of the encoder output. 16 | hidden_dim (int): The output dimension of the encoder. 17 | """ 18 | 19 | def __init__(self, embed_vecs, num_classes, embed_dropout, encoder_dropout, hidden_dim): 20 | super(LabelwiseAttentionNetwork, self).__init__() 21 | self.embedding = Embedding(embed_vecs, embed_dropout) 22 | self.encoder = self._get_encoder(embed_vecs.shape[1], encoder_dropout)  # the _get_encoder/_get_attention hooks run here, so subclasses set the attributes they read before calling super().__init__() 23 | self.attention = self._get_attention() 24 | self.output = LabelwiseLinearOutput(hidden_dim, num_classes) 25 | 26 | @abstractmethod 27 | def forward(self, input): 28 | raise NotImplementedError 29 | 30 | @abstractmethod 31 | def _get_encoder(self, input_size, dropout): 32 | raise NotImplementedError 33 | 34 | @abstractmethod 35 | def _get_attention(self): 36 | raise NotImplementedError 37 | 38 | 39 | class RNNLWAN(LabelwiseAttentionNetwork): 40 | """Base class for RNN Labelwise Attention Network 41 | """ 42 | 43 | def forward(self, input): 44 | x = self.embedding(input['text']) # (batch_size, sequence_length, embed_dim) 45 | x = self.encoder(x, input['length']) # (batch_size, sequence_length, hidden_dim) 46 | x, _ = self.attention(x) # (batch_size, num_classes, hidden_dim) 47 | x = self.output(x) # (batch_size, num_classes) 48 | return {'logits': x} 49 | 50 | 51 | class BiGRULWAN(RNNLWAN): 52 | """BiGRU Labelwise Attention Network 53 | 54 | Args: 55 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 56 | num_classes (int): Total number of classes. 57 | rnn_dim (int): The size of bidirectional hidden layers. The hidden size of the GRU network 58 | is set to rnn_dim//2. Defaults to 512. 59 | rnn_layers (int): The number of recurrent layers. Defaults to 1. 60 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2. 61 | encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0. 
62 | """ 63 | 64 | def __init__( 65 | self, 66 | embed_vecs, 67 | num_classes, 68 | rnn_dim=512, 69 | rnn_layers=1, 70 | embed_dropout=0.2, 71 | encoder_dropout=0 72 | ): 73 | self.num_classes = num_classes 74 | self.rnn_dim = rnn_dim 75 | self.rnn_layers = rnn_layers 76 | super(BiGRULWAN, self).__init__(embed_vecs, num_classes, embed_dropout, 77 | encoder_dropout, rnn_dim) 78 | 79 | def _get_encoder(self, input_size, dropout): 80 | assert self.rnn_dim % 2 == 0, """`rnn_dim` should be even.""" 81 | return GRUEncoder(input_size, self.rnn_dim // 2, self.rnn_layers, dropout) 82 | 83 | def _get_attention(self): 84 | return LabelwiseAttention(self.rnn_dim, self.num_classes) 85 | 86 | 87 | class BiLSTMLWAN(RNNLWAN): 88 | """BiLSTM Labelwise Attention Network 89 | 90 | Args: 91 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 92 | num_classes (int): Total number of classes. 93 | rnn_dim (int): The size of bidirectional hidden layers. The hidden size of the LSTM network 94 | is set to rnn_dim//2. Defaults to 512. 95 | rnn_layers (int): The number of recurrent layers. Defaults to 1. 96 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2. 97 | encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0. 
98 | """ 99 | 100 | def __init__( 101 | self, 102 | embed_vecs, 103 | num_classes, 104 | rnn_dim=512, 105 | rnn_layers=1, 106 | embed_dropout=0.2, 107 | encoder_dropout=0 108 | ): 109 | self.num_classes = num_classes 110 | self.rnn_dim = rnn_dim 111 | self.rnn_layers = rnn_layers 112 | super(BiLSTMLWAN, self).__init__(embed_vecs, num_classes, embed_dropout, 113 | encoder_dropout, rnn_dim) 114 | 115 | def _get_encoder(self, input_size, dropout): 116 | assert self.rnn_dim % 2 == 0, """`rnn_dim` should be even.""" 117 | return LSTMEncoder(input_size, self.rnn_dim // 2, self.rnn_layers, dropout) 118 | 119 | def _get_attention(self): 120 | return LabelwiseAttention(self.rnn_dim, self.num_classes) 121 | 122 | 123 | class BiLSTMLWMHAN(LabelwiseAttentionNetwork): 124 | """BiLSTM Labelwise Multihead Attention Network 125 | 126 | Args: 127 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 128 | num_classes (int): Total number of classes. 129 | rnn_dim (int): The size of bidirectional hidden layers. The hidden size of the LSTM network 130 | is set to rnn_dim//2. Defaults to 512. 131 | rnn_layers (int): The number of recurrent layers. Defaults to 1. 132 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2. 133 | encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0. 134 | num_heads (int): The number of parallel attention heads. Defaults to 8. 135 | attention_dropout (float): The dropout rate for the attention. Defaults to 0.0. 
136 | """ 137 | 138 | def __init__( 139 | self, 140 | embed_vecs, 141 | num_classes, 142 | rnn_dim=512, 143 | rnn_layers=1, 144 | embed_dropout=0.2, 145 | encoder_dropout=0, 146 | num_heads=8, 147 | attention_dropout=0.0 148 | ): 149 | self.num_classes = num_classes 150 | self.rnn_dim = rnn_dim 151 | self.rnn_layers = rnn_layers 152 | self.num_heads = num_heads 153 | self.attention_dropout = attention_dropout 154 | super(BiLSTMLWMHAN, self).__init__(embed_vecs, num_classes, embed_dropout, 155 | encoder_dropout, rnn_dim) 156 | 157 | def _get_encoder(self, input_size, dropout): 158 | assert self.rnn_dim % 2 == 0, """`rnn_dim` should be even.""" 159 | return LSTMEncoder(input_size, self.rnn_dim // 2, 160 | self.rnn_layers, dropout) 161 | 162 | def _get_attention(self): 163 | return LabelwiseMultiHeadAttention(self.rnn_dim, self.num_classes, self.num_heads, self.attention_dropout) 164 | 165 | def forward(self, input): 166 | x = self.embedding(input['text']) # (batch_size, sequence_length, embed_dim) 167 | x = self.encoder(x, input['length']) # (batch_size, sequence_length, hidden_dim) 168 | x, _ = self.attention(x, attention_mask=input['text'] == 0) # (batch_size, num_classes, hidden_dim); NOTE(review): assumes token id 0 is the padding index — confirm against the vocab 169 | x = self.output(x) # (batch_size, num_classes) 170 | return {'logits': x} 171 | 172 | 173 | class CNNLWAN(LabelwiseAttentionNetwork): 174 | """CNN Labelwise Attention Network 175 | 176 | Args: 177 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 178 | num_classes (int): Total number of classes. 179 | filter_sizes (list): Sizes of convolutional filters. Required despite the `None` default, since `len(filter_sizes)` is taken in `__init__`. 180 | num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 50. 181 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2. 182 | encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0. 183 | activation (str): Activation function to be used. Defaults to 'tanh'. 
184 | """ 185 | 186 | def __init__( 187 | self, 188 | embed_vecs, 189 | num_classes, 190 | filter_sizes=None, 191 | num_filter_per_size=50, 192 | embed_dropout=0.2, 193 | encoder_dropout=0, 194 | activation='tanh' 195 | ): 196 | self.num_classes = num_classes 197 | self.filter_sizes = filter_sizes 198 | self.num_filter_per_size = num_filter_per_size 199 | self.activation = activation 200 | self.hidden_dim = num_filter_per_size * len(filter_sizes) 201 | super(CNNLWAN, self).__init__(embed_vecs, num_classes, embed_dropout, 202 | encoder_dropout, self.hidden_dim) 203 | 204 | def _get_encoder(self, input_size, dropout): 205 | return CNNEncoder(input_size, self.filter_sizes, 206 | self.num_filter_per_size, self.activation, dropout, 207 | channel_last=True) 208 | 209 | def _get_attention(self): 210 | return LabelwiseAttention(self.hidden_dim, self.num_classes) 211 | 212 | def forward(self, input): 213 | x = self.embedding(input['text']) # (batch_size, sequence_length, embed_dim) 214 | x = self.encoder(x) # (batch_size, sequence_length, hidden_dim) 215 | x, _ = self.attention(x) # (batch_size, num_classes, hidden_dim) 216 | x = self.output(x) # (batch_size, num_classes) 217 | return {'logits': x} 218 | -------------------------------------------------------------------------------- /libmultilabel/nn/networks/modules.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 7 | 8 | 9 | class Embedding(nn.Module): 10 | """Embedding layer with dropout 11 | 12 | Args: 13 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 14 | dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
15 | """ 16 | 17 | def __init__(self, embed_vecs, dropout=0.2): 18 | super(Embedding, self).__init__() 19 | self.embedding = nn.Embedding.from_pretrained( 20 | embed_vecs, freeze=False, padding_idx=0) 21 | self.dropout = nn.Dropout(dropout) 22 | 23 | def forward(self, input): 24 | return self.dropout(self.embedding(input)) 25 | 26 | 27 | class RNNEncoder(ABC, nn.Module): 28 | """Base class of RNN encoder with dropout 29 | 30 | Args: 31 | input_size (int): The number of expected features in the input. 32 | hidden_size (int): The number of features in the hidden state. 33 | num_layers (int): The number of recurrent layers. 34 | dropout (float): The dropout rate of the encoder. Defaults to 0. 35 | """ 36 | 37 | def __init__(self, input_size, hidden_size, num_layers, dropout=0): 38 | super(RNNEncoder, self).__init__() 39 | self.rnn = self._get_rnn(input_size, hidden_size, num_layers) 40 | self.dropout = nn.Dropout(dropout) 41 | 42 | def forward(self, input, length, **kwargs): 43 | self.rnn.flatten_parameters() 44 | idx = torch.argsort(length, descending=True) 45 | length_clamped = length[idx].cpu().clamp(min=1) # avoid the empty text with length 0 46 | packed_input = pack_padded_sequence( 47 | input[idx], length_clamped, batch_first=True) 48 | outputs, _ = pad_packed_sequence( 49 | self.rnn(packed_input)[0], batch_first=True) 50 | return self.dropout(outputs[torch.argsort(idx)]) 51 | 52 | @abstractmethod 53 | def _get_rnn(self, input_size, hidden_size, num_layers): 54 | raise NotImplementedError 55 | 56 | 57 | class GRUEncoder(RNNEncoder): 58 | """Bi-directional GRU encoder with dropout 59 | 60 | Args: 61 | input_size (int): The number of expected features in the input. 62 | hidden_size (int): The number of features in the hidden state. 63 | num_layers (int): The number of recurrent layers. 64 | dropout (float): The dropout rate of the encoder. Defaults to 0. 
65 | """ 66 | 67 | def __init__(self, input_size, hidden_size, num_layers, dropout=0): 68 | super(GRUEncoder, self).__init__(input_size, hidden_size, num_layers, 69 | dropout) 70 | 71 | def _get_rnn(self, input_size, hidden_size, num_layers): 72 | return nn.GRU(input_size, hidden_size, num_layers, 73 | batch_first=True, bidirectional=True) 74 | 75 | 76 | class LSTMEncoder(RNNEncoder): 77 | """Bi-directional LSTM encoder with dropout 78 | 79 | Args: 80 | input_size (int): The number of expected features in the input. 81 | hidden_size (int): The number of features in the hidden state. 82 | num_layers (int): The number of recurrent layers. 83 | dropout (float): The dropout rate of the encoder. Defaults to 0. 84 | """ 85 | 86 | def __init__(self, input_size, hidden_size, num_layers, dropout=0): 87 | super(LSTMEncoder, self).__init__(input_size, hidden_size, num_layers, 88 | dropout) 89 | 90 | def _get_rnn(self, input_size, hidden_size, num_layers): 91 | return nn.LSTM(input_size, hidden_size, num_layers, 92 | batch_first=True, bidirectional=True) 93 | 94 | 95 | class CNNEncoder(nn.Module): 96 | """Multi-filter-size CNN encoder for text classification with max-pooling 97 | 98 | Args: 99 | input_size (int): The number of expected features in the input. 100 | filter_sizes (list): Size of convolutional filters. 101 | num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 128. 102 | activation (str): Activation function to be used. Defaults to 'relu'. 103 | dropout (float): The dropout rate of the encoder. Defaults to 0. 104 | num_pool (int): The number of pools for max-pooling. 105 | If num_pool = 0, do nothing. 106 | If num_pool = 1, do typical max-pooling. 107 | If num_pool > 1, do adaptive max-pooling. 
108 | channel_last (bool): Whether to transpose the dimension from (batch_size, num_channel, length) to (batch_size, length, num_channel) 109 | """ 110 | 111 | def __init__(self, input_size, filter_sizes, num_filter_per_size, 112 | activation, dropout=0, num_pool=0, channel_last=False): 113 | super(CNNEncoder, self).__init__() 114 | if not filter_sizes: 115 | raise ValueError(f'CNNEncoder expect non-empty filter_sizes. ' 116 | f'Got: {filter_sizes}') 117 | self.channel_last = channel_last 118 | self.convs = nn.ModuleList() 119 | for filter_size in filter_sizes: 120 | conv = nn.Conv1d( 121 | in_channels=input_size, 122 | out_channels=num_filter_per_size, 123 | kernel_size=filter_size) 124 | self.convs.append(conv) 125 | self.num_pool = num_pool 126 | if num_pool > 1: 127 | self.pool = nn.AdaptiveMaxPool1d(num_pool) 128 | self.activation = getattr(torch, activation, getattr(F, activation)) 129 | self.dropout = nn.Dropout(dropout) 130 | 131 | def forward(self, input): 132 | h = input.transpose(1, 2) # (batch_size, input_size, length) 133 | h_list = [] 134 | for conv in self.convs: 135 | h_sub = conv(h) # (batch_size, num_filter, length) 136 | if self.num_pool == 1: 137 | h_sub = F.max_pool1d(h_sub, h_sub.shape[2]) # (batch_size, num_filter, 1) 138 | elif self.num_pool > 1: 139 | h_sub = self.pool(h_sub) # (batch_size, num_filter, num_pool) 140 | h_list.append(h_sub) 141 | h = torch.cat(h_list, 1) # (batch_size, total_num_filter, *) 142 | if self.channel_last: 143 | h = h.transpose(1, 2) # (batch_size, *, total_num_filter) 144 | h = self.activation(h) 145 | return self.dropout(h) 146 | 147 | 148 | class LabelwiseAttention(nn.Module): 149 | """Applies attention technique to summarize the sequence for each label 150 | See `Explainable Prediction of Medical Codes from Clinical Text `_ 151 | 152 | Args: 153 | input_size (int): The number of expected features in the input. 154 | num_classes (int): Total number of classes. 
155 | """ 156 | def __init__(self, input_size, num_classes): 157 | super(LabelwiseAttention, self).__init__() 158 | self.attention = nn.Linear(input_size, num_classes, bias=False) 159 | 160 | def forward(self, input): 161 | attention = self.attention(input).transpose(1, 2) # (batch_size, num_classes, sequence_length) 162 | attention = F.softmax(attention, -1) 163 | logits = torch.bmm(attention, input) # (batch_size, num_classes, hidden_dim) 164 | return logits, attention 165 | 166 | 167 | class LabelwiseMultiHeadAttention(nn.Module): 168 | """Labelwise multi-head attention 169 | 170 | Args: 171 | input_size (int): The number of expected features in the input. 172 | num_classes (int): Total number of classes. 173 | num_heads (int): The number of parallel attention heads. 174 | attention_dropout (float): The dropout rate for the attention. Defaults to 0.0. 175 | """ 176 | def __init__(self, input_size, num_classes, num_heads, attention_dropout=0.0): 177 | super(LabelwiseMultiHeadAttention, self).__init__() 178 | self.attention = nn.MultiheadAttention(embed_dim=input_size, num_heads=num_heads, dropout=attention_dropout) 179 | self.Q = nn.Linear(input_size, num_classes) 180 | 181 | def forward(self, input, attention_mask=None): 182 | key = value = input.permute(1, 0, 2) # (sequence_length, batch_size, hidden_dim) 183 | query = self.Q.weight.repeat(input.size(0), 1, 1).transpose( 184 | 0, 1) # (num_classes, batch_size, hidden_dim) 185 | 186 | logits, attention = self.attention(query, key, value, key_padding_mask=attention_mask) 187 | logits = logits.permute(1, 0, 2) # (batch_size, num_classes, hidden_dim) 188 | return logits, attention 189 | 190 | 191 | class LabelwiseLinearOutput(nn.Module): 192 | """Applies a linear transformation to the incoming data for each label 193 | 194 | Args: 195 | input_size (int): The number of expected features in the input. 196 | num_classes (int): Total number of classes. 
197 | """ 198 | 199 | def __init__(self, input_size, num_classes): 200 | super(LabelwiseLinearOutput, self).__init__() 201 | self.output = nn.Linear(input_size, num_classes) 202 | 203 | def forward(self, input): 204 | return (self.output.weight * input).sum(dim=-1) + self.output.bias 205 | -------------------------------------------------------------------------------- /libmultilabel/nn/networks/xml_cnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .modules import Embedding, CNNEncoder 6 | 7 | 8 | class XMLCNN(nn.Module): 9 | """XML-CNN 10 | 11 | Args: 12 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 13 | num_classes (int): Total number of classes. 14 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2. 15 | hidden_dropout (float): The dropout rate of the hidden layer output. Defaults to 0. 16 | filter_sizes (list): Size of convolutional filters. 17 | hidden_dim (int): Dimension of the hidden layer. Defaults to 512. 18 | num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 256. 19 | num_pool (int): The number of pool for dynamic max-pooling. Defaults to 2. 20 | activation (str): Activation function to be used. Defaults to 'relu'. 
21 | """ 22 | def __init__( 23 | self, 24 | embed_vecs, 25 | num_classes, 26 | embed_dropout=0.2, 27 | hidden_dropout=0, 28 | filter_sizes=None, 29 | hidden_dim=512, 30 | num_filter_per_size=256, 31 | num_pool=2, 32 | activation='relu' 33 | ): 34 | super(XMLCNN, self).__init__() 35 | self.embedding = Embedding(embed_vecs, embed_dropout) 36 | self.encoder = CNNEncoder(embed_vecs.shape[1], filter_sizes, 37 | num_filter_per_size, activation, 38 | num_pool=num_pool) 39 | total_output_size = len(filter_sizes) * num_filter_per_size * num_pool 40 | self.dropout = nn.Dropout(hidden_dropout) 41 | self.linear1 = nn.Linear(total_output_size, hidden_dim) 42 | self.linear2 = nn.Linear(hidden_dim, num_classes) 43 | self.activation = getattr(torch, activation, getattr(F, activation)) 44 | 45 | def forward(self, input): 46 | x = self.embedding(input['text']) # (batch_size, length, embed_dim) 47 | x = self.encoder(x) # (batch_size, num_filter, num_pool) 48 | x = x.view(x.shape[0], -1) # (batch_size, num_filter * num_pool) 49 | x = self.activation(self.linear1(x)) 50 | x = self.dropout(x) 51 | x = self.linear2(x) 52 | return {'logits': x} 53 | -------------------------------------------------------------------------------- /libmultilabel/nn/nn_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 7 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 8 | from pytorch_lightning.utilities.seed import seed_everything 9 | 10 | from ..nn import networks 11 | from ..nn.model import Model 12 | 13 | 14 | def init_device(use_cpu=False): 15 | """Initialize device to CPU if `use_cpu` is set to True otherwise GPU. 16 | 17 | Args: 18 | use_cpu (bool, optional): Whether to use CPU or not. Defaults to False. 19 | 20 | Returns: 21 | torch.device: One of cuda or cpu. 
22 | """ 23 | 24 | if not use_cpu and torch.cuda.is_available(): 25 | # Set a debug environment variable CUBLAS_WORKSPACE_CONFIG to ":16:8" (may limit overall performance) or ":4096:8" (will increase library footprint in GPU memory by approximately 24MiB). 26 | # https://docs.nvidia.com/cuda/cublas/index.html 27 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" 28 | device = torch.device('cuda') 29 | else: 30 | device = torch.device('cpu') 31 | # https://github.com/pytorch/pytorch/issues/11201 32 | torch.multiprocessing.set_sharing_strategy('file_system') 33 | logging.info(f'Using device: {device}') 34 | return device 35 | 36 | 37 | def init_model(model_name, 38 | network_config, 39 | classes, 40 | word_dict, 41 | embed_vecs, 42 | init_weight=None, 43 | log_path=None, 44 | learning_rate=0.0001, 45 | optimizer='adam', 46 | momentum=0.9, 47 | weight_decay=0, 48 | metric_threshold=0.5, 49 | monitor_metrics=None, 50 | silent=False, 51 | save_k_predictions=0, 52 | zero=False, 53 | multi_class=False, 54 | enable_ce_loss=False, 55 | hierarchical=False): 56 | """Initialize a `Model` class for initializing and training a neural network. 57 | 58 | Args: 59 | model_name (str): Model to be used such as KimCNN. 60 | network_config (dict): Configuration for defining the network. 61 | classes (list): List of class names. 62 | word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices. 63 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 64 | init_weight (str): Weight initialization method from `torch.nn.init`. 65 | For example, the `init_weight` of `torch.nn.init.kaiming_uniform_` 66 | is `kaiming_uniform`. Defaults to None. 67 | log_path (str): Path to a directory holding the log files and models. 68 | learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001. 69 | optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'. 
70 | momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9. 71 | weight_decay (int, optional): Weight decay factor. Defaults to 0. 72 | metric_threshold (float, optional): Threshold to monitor for metrics. Defaults to 0.5. 73 | monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None. 74 | silent (bool, optional): Enable silent mode. Defaults to False. 75 | save_k_predictions (int, optional): Save top k predictions on test set. Defaults to 0. 76 | zero (bool, optional) 77 | multi_class (bool, optional) 78 | enable_ce_loss (bool, optional) 79 | hierarchical (bool, optional) 80 | 81 | Returns: 82 | Model: A class that implements `MultiLabelModel` for initializing and training a neural network. 83 | """ 84 | 85 | network = getattr(networks, model_name)( 86 | embed_vecs=embed_vecs, 87 | num_classes=len(classes), 88 | hierarchical=hierarchical, 89 | **dict(network_config) 90 | ) 91 | 92 | if init_weight is not None: 93 | init_weight = networks.get_init_weight_func( 94 | init_weight=init_weight) 95 | network.apply(init_weight) 96 | 97 | model = Model( 98 | classes=classes, 99 | word_dict=word_dict, 100 | embed_vecs=embed_vecs, 101 | network=network, 102 | log_path=log_path, 103 | learning_rate=learning_rate, 104 | optimizer=optimizer, 105 | momentum=momentum, 106 | weight_decay=weight_decay, 107 | metric_threshold=metric_threshold, 108 | monitor_metrics=monitor_metrics, 109 | silent=silent, 110 | save_k_predictions=save_k_predictions, 111 | zero=zero, 112 | multi_class=multi_class, 113 | enable_ce_loss=enable_ce_loss 114 | ) 115 | return model 116 | 117 | 118 | def init_trainer(checkpoint_dir, 119 | epochs=10000, 120 | patience=5, 121 | mode='max', 122 | val_metric='P@1', 123 | silent=False, 124 | use_cpu=False, 125 | limit_train_batches=1.0, 126 | limit_val_batches=1.0, 127 | limit_test_batches=1.0, 128 | search_params=False, 129 | save_checkpoints=True, 130 | accumulate_grad_batches=1): 131 | """Initialize a torch 
lightning trainer. 132 | 133 | Args: 134 | checkpoint_dir (str): Directory for saving models and log. 135 | epochs (int): Number of epochs to train. Defaults to 10000. 136 | patience (int): Number of epochs to wait for improvement before early stopping. Defaults to 5. 137 | mode (str): One of [min, max]. Decides whether the val_metric is minimizing or maximizing. 138 | val_metric (str): The metric to monitor for early stopping. Defaults to 'P@1'. 139 | silent (bool): Enable silent mode. Defaults to False. 140 | use_cpu (bool): Disable CUDA. Defaults to False. 141 | limit_train_batches (Union[int, float]): Percentage of training dataset to use. Defaults to 1.0. 142 | limit_val_batches (Union[int, float]): Percentage of validation dataset to use. Defaults to 1.0. 143 | limit_test_batches (Union[int, float]): Percentage of test dataset to use. Defaults to 1.0. 144 | search_params (bool): Enable pytorch-lightning trainer to report the results to ray tune 145 | on validation end during hyperparameter search. Defaults to False. 146 | save_checkpoints (bool): Whether to save the last and the best checkpoint or not. Defaults to True. 147 | accumulate_grad_batches (int) 148 | 149 | Returns: 150 | pl.Trainer: A torch lightning trainer. 
151 | """ 152 | 153 | callbacks = [] 154 | if save_checkpoints: 155 | callbacks += [ModelCheckpoint(dirpath=checkpoint_dir, filename='best_model', 156 | save_last=True, save_top_k=1, 157 | monitor=val_metric, mode=mode)] 158 | if search_params: 159 | from ray.tune.integration.pytorch_lightning import TuneReportCallback 160 | callbacks += [TuneReportCallback({f'val_{val_metric}': val_metric}, on="validation_end")] 161 | 162 | trainer = pl.Trainer(logger=False, num_sanity_val_steps=0, 163 | gpus=0 if use_cpu else 1, 164 | enable_progress_bar=False if silent else True, 165 | max_epochs=epochs, 166 | callbacks=callbacks, 167 | limit_train_batches=limit_train_batches, 168 | limit_val_batches=limit_val_batches, 169 | limit_test_batches=limit_test_batches, 170 | accumulate_grad_batches=accumulate_grad_batches, 171 | deterministic=True) 172 | return trainer 173 | 174 | 175 | def set_seed(seed): 176 | """Set seeds for numpy and pytorch. 177 | 178 | Args: 179 | seed (int): Random seed. 180 | """ 181 | 182 | if seed is not None: 183 | if seed >= 0: 184 | seed_everything(seed=seed, workers=True) 185 | else: 186 | logging.warning('the random seed should be a non-negative integer') 187 | 188 | 189 | from transformers import TrainingArguments 190 | def init_training_args(config): 191 | """Initialize huggingface training arguments. 192 | 193 | Args: 194 | config (dict) 195 | 196 | Returns: 197 | dict: huggingface training arguments. 
198 | """ 199 | 200 | training_args = TrainingArguments( 201 | output_dir = \ 202 | f"{config.result_dir}/{config.data_name}/{config.network_config['lm_weight']}/seed_{config.seed}", 203 | ) 204 | training_args.model_name_or_path = config.network_config['lm_weight'] 205 | training_args.do_lower_case = True 206 | training_args.do_train = True 207 | training_args.do_eval = True 208 | training_args.do_pred = True 209 | training_args.overwrite_output_dir = True 210 | training_args.load_best_model_at_end = True 211 | training_args.metric_for_best_model = config.val_metric 212 | training_args.greater_is_better = True 213 | training_args.evaluation_strategy = 'epoch' 214 | training_args.save_strategy = 'epoch' 215 | training_args.save_total_limit = 5 216 | training_args.num_train_epochs = config.epochs 217 | training_args.learning_rate = config.learning_rate 218 | training_args.per_device_train_batch_size = config.batch_size 219 | training_args.per_device_eval_batch_size = config.batch_size 220 | training_args.seed = config.seed 221 | if not config.hierarchical: # hierarchical methods will lead to errors 222 | training_args.fp16 = True 223 | training_args.fp16_full_eval = True 224 | training_args.gradient_accumulation_steps = config.accumulate_grad_batches 225 | training_args.eval_accumulation_steps = config.accumulate_grad_batches 226 | return training_args 227 | 228 | 229 | import numpy as np 230 | from sklearn.metrics import f1_score 231 | from transformers import EvalPrediction 232 | def compute_metrics(p: EvalPrediction): 233 | logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 234 | preds = np.argmax(logits, axis=1) 235 | macro_f1 = f1_score(y_true=p.label_ids, y_pred=preds, average='macro', zero_division=0) 236 | micro_f1 = f1_score(y_true=p.label_ids, y_pred=preds, average='micro', zero_division=0) 237 | return {'Macro-F1': macro_f1, 'Micro-F1': micro_f1} 238 | 
-------------------------------------------------------------------------------- /linear_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from math import ceil 4 | 5 | import numpy as np 6 | 7 | import libmultilabel.linear as linear 8 | from libmultilabel.common_utils import dump_log, argsort_top_k 9 | 10 | 11 | def linear_test(config, model, datasets): 12 | metrics = linear.get_metrics( 13 | config.metric_threshold, 14 | config.monitor_metrics, 15 | datasets['test']['y'].shape[1], 16 | config.zero, 17 | config.multi_class 18 | ) 19 | num_instance = datasets['test']['x'].shape[0] 20 | 21 | k = config.save_k_predictions 22 | top_k_idx = np.zeros((num_instance, k), dtype='i') 23 | top_k_scores = np.zeros((num_instance, k), dtype='d') 24 | 25 | for i in range(ceil(num_instance / config.eval_batch_size)): 26 | slice = np.s_[i*config.eval_batch_size:(i+1)*config.eval_batch_size] 27 | preds = linear.predict_values(model, datasets['test']['x'][slice]) 28 | target = datasets['test']['y'][slice].toarray() 29 | metrics.update(preds, target) 30 | 31 | if k > 0: 32 | top_k_idx[slice] = argsort_top_k(preds, k, axis=1) 33 | top_k_scores[slice] = np.take_along_axis( 34 | preds, top_k_idx[slice], axis=1) 35 | 36 | metric_dict = metrics.compute() 37 | return (metric_dict, top_k_idx, top_k_scores) 38 | 39 | 40 | def linear_train(datasets, config): 41 | techniques = {'1vsrest': linear.train_1vsrest, 42 | 'thresholding': linear.train_thresholding, 43 | 'cost_sensitive': linear.train_cost_sensitive, 44 | 'cost_sensitive_micro': linear.train_cost_sensitive_micro} 45 | model = techniques[config.linear_technique]( 46 | datasets['train']['y'], 47 | datasets['train']['x'], 48 | config.liblinear_options, 49 | ) 50 | return model 51 | 52 | 53 | def linear_run(config): 54 | if config.seed is not None: 55 | np.random.seed(config.seed) 56 | 57 | if config.eval: 58 | preprocessor, model = 
linear.load_pipeline(config.checkpoint_path) 59 | datasets = preprocessor.load_data( 60 | config.train_path, config.test_path, config.eval) 61 | else: 62 | preprocessor = linear.Preprocessor(data_format=config.data_format) 63 | datasets = preprocessor.load_data( 64 | config.train_path, 65 | config.test_path, 66 | config.eval, 67 | config.label_file, 68 | config.include_test_labels, 69 | config.remove_no_label_data) 70 | model = linear_train(datasets, config) 71 | linear.save_pipeline(config.checkpoint_dir, preprocessor, model) 72 | 73 | if os.path.exists(config.test_path): 74 | metric_dict, top_k_idx, top_k_scores = linear_test( 75 | config, model, datasets) 76 | 77 | dump_log(config=config, metrics=metric_dict, 78 | split='test', log_path=config.log_path) 79 | print(linear.tabulate_metrics(metric_dict, 'test')) 80 | 81 | if config.save_k_predictions > 0: 82 | classes = preprocessor.binarizer.classes_ 83 | with open(config.predict_out_path, 'w') as fp: 84 | for idx, score in zip(top_k_idx, top_k_scores): 85 | out_str = ' '.join([f'{classes[i]}:{s:.4}' for i, s in zip( 86 | idx, score)]) 87 | fp.write(out_str+'\n') 88 | logging.info(f'Saved predictions to: {config.predict_out_path}') 89 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from datetime import datetime 5 | from pathlib import Path 6 | 7 | import yaml 8 | 9 | from libmultilabel.common_utils import Timer, AttributeDict 10 | 11 | 12 | def add_all_arguments(parser): 13 | # path / directory 14 | parser.add_argument('--data_dir', default='./data/rcv1', 15 | help='The directory to load data (default: %(default)s)') 16 | parser.add_argument('--result_dir', default='./runs', 17 | help='The directory to save checkpoints and logs (default: %(default)s)') 18 | 19 | # data 20 | parser.add_argument('--data_name', default='rcv1', 21 | 
help='Dataset name (default: %(default)s)') 22 | parser.add_argument('--train_path', 23 | help='Path to training data (default: [data_dir]/train.txt)') 24 | parser.add_argument('--val_path', 25 | help='Path to validation data (default: [data_dir]/valid.txt)') 26 | parser.add_argument('--test_path', 27 | help='Path to test data (default: [data_dir]/test.txt)') 28 | parser.add_argument('--val_size', type=float, default=0.2, 29 | help='Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set (default: %(default)s).') 30 | parser.add_argument('--min_vocab_freq', type=int, default=1, 31 | help='The minimum frequency needed to include a token in the vocabulary (default: %(default)s)') 32 | parser.add_argument('--max_seq_length', type=int, default=500, 33 | help='The maximum number of tokens of a sample (default: %(default)s)') 34 | parser.add_argument('--lm_weight', type=str, 35 | help='Pretrained model name or path (default: %(default)s)') 36 | parser.add_argument('--shuffle', type=bool, default=True, 37 | help='Whether to shuffle training data before each epoch (default: %(default)s)') 38 | parser.add_argument('--merge_train_val', action='store_true', 39 | help='Whether to merge the training and validation data. (default: %(default)s)') 40 | parser.add_argument('--include_test_labels', action='store_true', 41 | help='Whether to include labels in the test dataset. 
(default: %(default)s)') 42 | parser.add_argument('--remove_no_label_data', action='store_true', 43 | help='Whether to remove training and validation instances that have no labels.') 44 | 45 | # train 46 | parser.add_argument('--seed', type=int, 47 | help='Random seed (default: %(default)s)') 48 | parser.add_argument('--epochs', type=int, default=10000, 49 | help='The number of epochs to train (default: %(default)s)') 50 | parser.add_argument('--batch_size', type=int, default=16, 51 | help='Size of training batches (default: %(default)s)') 52 | parser.add_argument('--optimizer', default='adam', choices=['adam', 'adamw', 'adamax', 'sgd'], 53 | help='Optimizer (default: %(default)s)') 54 | parser.add_argument('--learning_rate', type=float, default=0.0001, 55 | help='Learning rate for optimizer (default: %(default)s)') 56 | parser.add_argument('--weight_decay', type=float, default=0, 57 | help='Weight decay factor (default: %(default)s)') 58 | parser.add_argument('--momentum', type=float, default=0.9, 59 | help='Momentum factor for SGD only (default: %(default)s)') 60 | parser.add_argument('--patience', type=int, default=5, 61 | help='The number of epochs to wait for improvement before early stopping (default: %(default)s)') 62 | parser.add_argument('--normalize_embed', action='store_true', 63 | help='Whether the embeddings of each word is normalized to a unit vector (default: %(default)s)') 64 | 65 | # model 66 | parser.add_argument('--model_name', default='KimCNN', 67 | help='Model to be used (default: %(default)s)') 68 | parser.add_argument('--init_weight', default='kaiming_uniform', 69 | help='Weight initialization to be used (default: %(default)s)') 70 | 71 | # eval 72 | parser.add_argument('--eval_batch_size', type=int, default=256, 73 | help='Size of evaluating batches (default: %(default)s)') 74 | parser.add_argument('--metric_threshold', type=float, default=0.5, 75 | help='Thresholds to monitor for metrics (default: %(default)s)') 76 | 
parser.add_argument('--monitor_metrics', nargs='+', default=['P@1', 'P@3', 'P@5'], 77 | help='Metrics to monitor while validating (default: %(default)s)') 78 | parser.add_argument('--val_metric', default='P@1', 79 | help='The metric to monitor for early stopping (default: %(default)s)') 80 | 81 | # pretrained vocab / embeddings 82 | parser.add_argument('--vocab_file', type=str, 83 | help='Path to a file holding vocabuaries (default: %(default)s)') 84 | parser.add_argument('--embed_file', type=str, 85 | help='Path to a file holding pre-trained embeddings (default: %(default)s)') 86 | parser.add_argument('--label_file', type=str, 87 | help='Path to a file holding all labels (default: %(default)s)') 88 | 89 | # log 90 | parser.add_argument('--save_k_predictions', type=int, nargs='?', const=100, default=0, 91 | help='Save top k predictions on test set. k=%(const)s if not specified. (default: %(default)s)') 92 | parser.add_argument('--predict_out_path', 93 | help='Path to the an output file holding top k label results (default: %(default)s)') 94 | 95 | # auto-test 96 | parser.add_argument('--limit_train_batches', type=float, default=1.0, 97 | help='Percentage of train dataset to use for auto-testing (default: %(default)s)') 98 | parser.add_argument('--limit_val_batches', type=float, default=1.0, 99 | help='Percentage of validation dataset to use for auto-testing (default: %(default)s)') 100 | parser.add_argument('--limit_test_batches', type=float, default=1.0, 101 | help='Percentage of test dataset to use for auto-testing (default: %(default)s)') 102 | 103 | # others 104 | parser.add_argument('--cpu', action='store_true', 105 | help='Disable CUDA') 106 | parser.add_argument('--silent', action='store_true', 107 | help='Enable silent mode') 108 | parser.add_argument('--data_workers', type=int, default=4, 109 | help='Use multi-cpu core for data pre-processing (default: %(default)s)') 110 | parser.add_argument('--embed_cache_dir', type=str, 111 | help='For parameter search 
only: path to a directory for storing embeddings for multiple runs. (default: %(default)s)') 112 | parser.add_argument('--eval', action='store_true', 113 | help='Only run evaluation on the test set (default: %(default)s)') 114 | parser.add_argument('--checkpoint_path', 115 | help='The checkpoint to warm-up with (default: %(default)s)') 116 | 117 | # linear options 118 | parser.add_argument('--linear', action='store_true', 119 | help='Train linear model') 120 | parser.add_argument('--data_format', type=str, default='txt', 121 | help='\'svm\' for SVM format or \'txt\' for LibMultiLabel format (default: %(default)s)') 122 | parser.add_argument('--liblinear_options', type=str, 123 | help='Options passed to liblinear (default: %(default)s)') 124 | parser.add_argument('--linear_technique', type=str, default='1vsrest', 125 | choices=['1vsrest', 'thresholding', 'cost_sensitive', 'cost_sensitive_micro'], 126 | help='Technique for linear classification (default: %(default)s)') 127 | 128 | parser.add_argument('-h', '--help', action='help', 129 | help="If you are trying to specify network config such as dropout or activation, use a yaml file instead. 
" 130 | "See example configs in example_config") 131 | 132 | # LexGLUE 133 | parser.add_argument('--zero', action='store_true') 134 | parser.add_argument('--multi_class', action='store_true') 135 | parser.add_argument('--add_special_tokens', action='store_true') 136 | parser.add_argument('--enable_ce_loss', action='store_true') 137 | parser.add_argument('--hierarchical', action='store_true') 138 | parser.add_argument('--accumulate_grad_batches', type=int, default=1) 139 | parser.add_argument('--enable_transformer_trainer', action='store_true') 140 | 141 | 142 | def get_config(): 143 | parser = argparse.ArgumentParser( 144 | add_help=False, 145 | description='multi-label learning for text classification') 146 | 147 | # load params from config file 148 | parser.add_argument('-c', '--config', help='Path to configuration file') 149 | args, _ = parser.parse_known_args() 150 | config = {} 151 | if args.config: 152 | with open(args.config) as fp: 153 | config = yaml.load(fp, Loader=yaml.SafeLoader) 154 | 155 | add_all_arguments(parser) 156 | 157 | parser.set_defaults(**config) 158 | args = parser.parse_args() 159 | config = AttributeDict(vars(args)) 160 | 161 | config.run_name = '{}_{}_{}'.format( 162 | config.data_name, 163 | Path(config.config).stem if config.config else config.model_name, 164 | datetime.now().strftime('%Y%m%d%H%M%S'), 165 | ) 166 | config.checkpoint_dir = os.path.join(config.result_dir, config.run_name) 167 | config.log_path = os.path.join(config.checkpoint_dir, 'logs.json') 168 | config.predict_out_path = config.predict_out_path or os.path.join( 169 | config.checkpoint_dir, 'predictions.txt') 170 | 171 | config.train_path = config.train_path or os.path.join( 172 | config.data_dir, 'train.txt') 173 | config.val_path = config.val_path or os.path.join( 174 | config.data_dir, 'valid.txt') 175 | config.test_path = config.test_path or os.path.join( 176 | config.data_dir, 'test.txt') 177 | 178 | return config 179 | 180 | 181 | def check_config(config): 182 | 
"""Check if the configuration has invalid arguments. 183 | 184 | Args: 185 | config (AttributeDict): Config of the experiment from `get_args`. 186 | """ 187 | if config.model_name == 'XMLCNN' and config.seed is not None: 188 | raise ValueError("nn.AdaptiveMaxPool1d doesn't have a deterministic implementation but seed is" 189 | "specified. Please do not specify seed.") 190 | 191 | if config.eval and not os.path.exists(config.test_path): 192 | raise ValueError('--eval is specified but there is no test data set') 193 | 194 | 195 | def main(): 196 | # Get config 197 | config = get_config() 198 | check_config(config) 199 | 200 | # Set up logger 201 | log_level = logging.WARNING if config.silent else logging.INFO 202 | logging.basicConfig( 203 | level=log_level, format='%(asctime)s %(levelname)s:%(message)s') 204 | 205 | logging.info(f'Run name: {config.run_name}') 206 | 207 | if config.linear: 208 | from linear_trainer import linear_run 209 | linear_run(config) 210 | else: 211 | from torch_trainer import TorchTrainer 212 | trainer = TorchTrainer(config) # initialize trainer 213 | # train 214 | if not config.eval: 215 | trainer.train() 216 | # test 217 | if 'test' in trainer.datasets: 218 | trainer.test() 219 | 220 | return config 221 | 222 | 223 | def dump_time_to_log(log_path, time): 224 | """Write time to log. 
225 | Args: 226 | log_path(str): path to log path 227 | time (str): time in seconds 228 | """ 229 | assert os.path.isfile(log_path) 230 | with open(log_path) as fp: 231 | import json 232 | result = json.load(fp) 233 | 234 | if time >= 60 * 60: 235 | h, s = divmod(time, 60 * 60) 236 | formatted_time = f'{h}h {s/60:.0f}m' 237 | elif time >= 60: 238 | m, s = divmod(time, 60) 239 | formatted_time = f'{m}m {s}s' 240 | else: 241 | formatted_time = f'{time}s' 242 | result['time'] = formatted_time 243 | 244 | with open(log_path, 'w') as fp: 245 | json.dump(result, fp) 246 | 247 | print(f'Wall time: {formatted_time}') 248 | 249 | if __name__ == '__main__': 250 | wall_time = Timer() 251 | config = main() 252 | dump_time_to_log(config.log_path, round(wall_time.time())) 253 | # print(f'Wall time: {wall_time.time():.2f} (s)') 254 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.18.0 # or any version > 2.15.0 2 | nltk==3.7 3 | pandas==1.5.0 4 | PyYAML==6.0 5 | scikit-learn==1.1.2 6 | torch==1.12.0 7 | torchmetrics==0.9.2 8 | torchtext==0.13.0 9 | pytorch-lightning==1.6.5 10 | tqdm==4.64.1 11 | liblinear-multicore==2.45.1 12 | numba==0.56.3 13 | scipy==1.9.2 14 | transformers==4.23.0 15 | -------------------------------------------------------------------------------- /requirements_parameter_search.txt: -------------------------------------------------------------------------------- 1 | bayesian-optimization==1.2.0 # ray[tune] 2 | optuna==2.10.1 # ray[tune] 3 | ray==1.13.0 4 | ray[tune]==1.13.0 5 | grpcio==1.43.0 # Fix issue: https://github.com/ray-project/ray/issues/22518 6 | -------------------------------------------------------------------------------- /run_experiments.sh: -------------------------------------------------------------------------------- 1 | data_list=(ecthr_a ecthr_b scotus eurlex ledgar unfair_tos) 2 | 
algo_list=(1vsrest thresholding cost_sensitive bert_default bert_tuned bert_reproduced) 3 | 4 | data=$1 5 | algo=$2 6 | 7 | if [[ ! " ${data_list[*]} " =~ " ${data} " ]]; then 8 | echo "Invalid argument! Data ${data} is not in (${data_list[*]})." 9 | exit 10 | fi 11 | 12 | if [[ ! " ${algo_list[*]} " =~ " ${algo} " ]]; then 13 | echo "Invalid argument! Algorithm ${algo} is not in (${algo_list[*]})." 14 | exit 15 | fi 16 | 17 | linear_algo_list=(1vsrest thresholding cost_sensitive) 18 | bert_algo_list=(bert_default bert_tuned bert_reproduced) 19 | 20 | multilabel_unlabeled_data_list=(ecthr_a ecthr_b unfair_tos) 21 | multilabel_labeled_data_list=(eurlex) 22 | multiclass_labeled_data_list=(scotus ledgar) 23 | 24 | if [[ " ${linear_algo_list[*]} " =~ " ${algo} " ]]; then 25 | if [[ " ${multilabel_unlabeled_data_list[*]} " =~ " ${data} " ]]; then 26 | python3 main.py --config config/${data}/l2svm.yml --linear_technique ${algo} --zero 27 | elif [[ " ${multilabel_labeled_data_list[*]} " =~ " ${data} " ]]; then 28 | python3 main.py --config config/${data}/l2svm.yml --linear_technique ${algo} 29 | elif [[ " ${multiclass_labeled_data_list[*]} " =~ " ${data} " ]]; then 30 | python3 main.py --config config/${data}/l2svm.yml --linear_technique ${algo} --multi_class 31 | else 32 | echo "Should never reach here..." 33 | exit 34 | fi 35 | elif [[ " ${bert_algo_list[*]} " =~ " ${algo} " ]]; then 36 | if [[ " ${multilabel_unlabeled_data_list[*]} " =~ " ${data} " ]]; then 37 | python3 main.py --config config/${data}/${algo}.yml --zero --seed 1 38 | elif [[ " ${multilabel_labeled_data_list[*]} " =~ " ${data} " ]]; then 39 | python3 main.py --config config/${data}/${algo}.yml --seed 1 40 | elif [[ " ${multiclass_labeled_data_list[*]} " =~ " ${data} " ]]; then 41 | huggingface_trainer_algo_list=(bert_reproduced) 42 | if [[ ! 
" ${huggingface_trainer_algo_list[*]} " =~ " ${algo} " ]]; then 43 | python3 main.py --config config/${data}/${algo}.yml --multi_class --enable_ce_loss --seed 1 44 | else 45 | python3 main.py --config config/${data}/${algo}.yml --multi_class --enable_ce_loss --seed 1 --enable_transformer_trainer 46 | fi 47 | else 48 | echo "Should never reach here..." 49 | exit 50 | fi 51 | fi 52 | -------------------------------------------------------------------------------- /search_params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import logging 4 | import os 5 | import time 6 | from datetime import datetime 7 | from pathlib import Path 8 | 9 | import yaml 10 | from ray import tune 11 | from ray.tune.schedulers import ASHAScheduler 12 | 13 | import numpy as np 14 | 15 | from libmultilabel.nn import data_utils 16 | from libmultilabel.nn.nn_utils import set_seed 17 | from libmultilabel.common_utils import AttributeDict, Timer 18 | from torch_trainer import TorchTrainer 19 | 20 | logging.basicConfig(level=logging.INFO, 21 | format='%(asctime)s %(levelname)s:%(message)s') 22 | 23 | 24 | def train_libmultilabel_tune(config, datasets, classes, word_dict): 25 | """The training function for ray tune. 26 | 27 | Args: 28 | config (AttributeDict): Config of the experiment. 29 | datasets (dict): A dictionary of datasets. 30 | classes(list): List of class names. 31 | word_dict(torchtext.vocab.Vocab): A vocab object which maps tokens to indices. 
32 | """ 33 | set_seed(seed=config.seed) 34 | config.run_name = tune.get_trial_dir() 35 | logging.info(f'Run name: {config.run_name}') 36 | config.checkpoint_dir = os.path.join(config.result_dir, config.run_name) 37 | config.log_path = os.path.join(config.checkpoint_dir, 'logs.json') 38 | 39 | trainer = TorchTrainer(config=config, 40 | datasets=datasets, 41 | classes=classes, 42 | word_dict=word_dict, 43 | search_params=True, 44 | save_checkpoints=False) 45 | trainer.train() 46 | 47 | 48 | def load_config_from_file(config_path): 49 | """Initialize the model config. 50 | 51 | Args: 52 | config_path (str): Path to the config file. 53 | 54 | Returns: 55 | AttributeDict: Config of the experiment. 56 | """ 57 | with open(config_path) as fp: 58 | config = yaml.safe_load(fp) 59 | 60 | # create directories that hold the shared data 61 | os.makedirs(config['result_dir'], exist_ok=True) 62 | if config['embed_cache_dir']: 63 | os.makedirs(config['embed_cache_dir'], exist_ok=True) 64 | 65 | # set relative path to absolute path (_path, _file, _dir) 66 | for k, v in config.items(): 67 | if isinstance(v, str) and os.path.exists(v): 68 | config[k] = os.path.abspath(v) 69 | 70 | # find `train.txt`, `val.txt`, and `test.txt` from the data directory if not specified. 71 | config['train_path'] = config['train_path'] or os.path.join(config['data_dir'], 'train.txt') 72 | config['val_path'] = config['val_path'] or os.path.join(config['data_dir'], 'valid.txt') 73 | config['test_path'] = config['test_path'] or os.path.join(config['data_dir'], 'test.txt') 74 | 75 | return config 76 | 77 | 78 | def init_search_params_spaces(config, parameter_columns, prefix): 79 | """Initialize the sample space defined in ray tune. 80 | See the random distributions API listed here: https://docs.ray.io/en/master/tune/api_docs/search_space.html#random-distributions-api 81 | 82 | Args: 83 | config (dict): Config of the experiment. 84 | parameter_columns (dict): Names of parameters to include in the CLIReporter. 
85 | The keys are parameter names and the values are displayed names. 86 | prefix(str): The prefix of a nested parameter such as network_config/dropout. 87 | 88 | Returns: 89 | dict: Config with parsed sample spaces. 90 | """ 91 | search_spaces = ['choice', 'grid_search', 'uniform', 'quniform', 'loguniform', 92 | 'qloguniform', 'randn', 'qrandn', 'randint', 'qrandint'] 93 | for key, value in config.items(): 94 | if isinstance(value, list) and len(value) >= 2 and value[0] in search_spaces: 95 | search_space, search_args = value[0], value[1:] 96 | if isinstance(search_args[0], list) and any(isinstance(x, list) for x in search_args[0]) and search_space != 'grid_search': 97 | raise ValueError( 98 | """If the search values are lists, the search space must be `grid_search`. 99 | Take `filter_sizes: ['grid_search', [[2,4,8], [4,6]]]` for example, the program will grid search over 100 | [2,4,8] and [4,6]. This is the same as assigning `filter_sizes` to either [2,4,8] or [4,6] in two runs. 101 | """) 102 | else: 103 | config[key] = getattr(tune, search_space)(*search_args) 104 | parameter_columns[prefix+key] = key 105 | elif isinstance(value, dict): 106 | config[key] = init_search_params_spaces(value, parameter_columns, f'{prefix}{key}/') 107 | 108 | return config 109 | 110 | 111 | def init_search_algorithm(search_alg, metric=None, mode=None): 112 | """Specify a search algorithm and you must pip install it first. 113 | If no search algorithm is specified, the default search algorithm is BasicVariantGenerator. 114 | See more details here: https://docs.ray.io/en/master/tune/api_docs/suggestion.html 115 | 116 | Args: 117 | search_alg (str): One of 'basic_variant', 'bayesopt', or 'optuna'. 118 | metric (str): The metric to monitor for early stopping. 119 | mode (str): One of 'min' or 'max' to determine whether to minimize or maximize the metric. 120 | """ 121 | if search_alg == 'optuna': 122 | assert metric and mode, "Metric and mode cannot be None for optuna." 
123 | from ray.tune.suggest.optuna import OptunaSearch 124 | return OptunaSearch(metric=metric, mode=mode) 125 | elif search_alg == 'bayesopt': 126 | assert metric and mode, "Metric and mode cannot be None for bayesian optimization." 127 | from ray.tune.suggest.bayesopt import BayesOptSearch 128 | return BayesOptSearch(metric=metric, mode=mode) 129 | logging.info(f'{search_alg} search is found, run BasicVariantGenerator().') 130 | 131 | 132 | def prepare_retrain_config(best_config, best_log_dir, merge_train_val): 133 | """Prepare the configuration for re-training. 134 | 135 | Args: 136 | best_config (AttributeDict): The best hyper-parameter configuration. 137 | best_log_dir (str): The directory of the best trial of the experiment. 138 | merge_train_val (bool): Whether to merge the training and validation data. 139 | """ 140 | if merge_train_val: 141 | best_config.merge_train_val = True 142 | 143 | log_path = os.path.join(best_log_dir, 'logs.json') 144 | if os.path.isfile(log_path): 145 | with open(log_path) as fp: 146 | log = json.load(fp) 147 | else: 148 | raise FileNotFoundError("The log directory does not contain a log.") 149 | 150 | # For re-training with validation data, 151 | # we use the number of epochs at the point of optimal validation performance. 152 | log_metric = np.array([l[best_config.val_metric] for l in log['val']]) 153 | optimal_idx = log_metric.argmax() if best_config.mode == 'max' else log_metric.argmin() 154 | best_config.epochs = optimal_idx.item() + 1 # plus 1 for epochs 155 | else: 156 | best_config.merge_train_val = False 157 | 158 | 159 | def load_static_data(config, merge_train_val=False): 160 | """Preload static data once for multiple trials. 161 | 162 | Args: 163 | config (AttributeDict): Config of the experiment. 164 | merge_train_val (bool, optional): Whether to merge the training and validation data. 165 | Defaults to False. 166 | 167 | Returns: 168 | dict: A dict of static data containing datasets, classes, and word_dict. 
169 | """ 170 | datasets = data_utils.load_datasets(train_path=config.train_path, 171 | test_path=config.test_path, 172 | val_path=config.val_path, 173 | val_size=config.val_size, 174 | merge_train_val=merge_train_val, 175 | tokenize_text='lm_weight' not in config['network_config'], 176 | remove_no_label_data=config.remove_no_label_data 177 | ) 178 | return { 179 | "datasets": datasets, 180 | "word_dict": None if config.embed_file is None else data_utils.load_or_build_text_dict( 181 | dataset=datasets['train'], 182 | vocab_file=config.vocab_file, 183 | min_vocab_freq=config.min_vocab_freq, 184 | embed_file=config.embed_file, 185 | embed_cache_dir=config.embed_cache_dir, 186 | silent=config.silent, 187 | normalize_embed=config.normalize_embed 188 | ), 189 | "classes": data_utils.load_or_build_label(datasets, config.label_file, config.include_test_labels) 190 | } 191 | 192 | 193 | def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val): 194 | """Re-train the model with the best hyper-parameters. 195 | A new model is trained on the combined training and validation data if `merge_train_val` is True. 196 | If a test set is provided, it will be evaluated by the obtained model. 197 | 198 | Args: 199 | exp_name (str): The directory to save trials generated by ray tune. 200 | best_config (AttributeDict): The best hyper-parameter configuration. 201 | best_log_dir (str): The directory of the best trial of the experiment. 202 | merge_train_val (bool): Whether to merge the training and validation data. 
203 | """ 204 | best_config.silent = False 205 | checkpoint_dir = os.path.join(best_config.result_dir, exp_name, 'trial_best_params') 206 | os.makedirs(checkpoint_dir, exist_ok=True) 207 | with open(os.path.join(checkpoint_dir, 'params.yml'), 'w') as fp: 208 | yaml.dump(dict(best_config), fp) 209 | quit() # do not need re-training 210 | best_config.run_name = '_'.join(exp_name.split('_')[:-1]) + '_best' 211 | best_config.checkpoint_dir = checkpoint_dir 212 | best_config.log_path = os.path.join(best_config.checkpoint_dir, 'logs.json') 213 | prepare_retrain_config(best_config, best_log_dir, merge_train_val) 214 | set_seed(seed=best_config.seed) 215 | 216 | data = load_static_data(best_config, merge_train_val=best_config.merge_train_val) 217 | logging.info(f'Re-training with best config: \n{best_config}') 218 | trainer = TorchTrainer(config=best_config, **data) 219 | trainer.train() 220 | 221 | if 'test' in data['datasets']: 222 | test_results = trainer.test() 223 | logging.info(f'Test results after re-training: {test_results}') 224 | logging.info(f'Best model saved to {trainer.checkpoint_callback.best_model_path or trainer.checkpoint_callback.last_model_path}.') 225 | 226 | 227 | def main(): 228 | parser = argparse.ArgumentParser() 229 | parser.add_argument( 230 | '--config', help='Path to configuration file (default: %(default)s). Please specify a config with all arguments in LibMultiLabel/main.py::get_config.') 231 | parser.add_argument('--cpu_count', type=int, default=4, 232 | help='Number of CPU per trial (default: %(default)s)') 233 | parser.add_argument('--gpu_count', type=int, default=1, 234 | help='Number of GPU per trial (default: %(default)s)') 235 | parser.add_argument('--num_samples', type=int, default=50, 236 | help='Number of running trials. If the search space is `grid_search`, the same grid will be repeated `num_samples` times. 
(default: %(default)s)') 237 | parser.add_argument('--mode', default='max', choices=['min', 'max'], 238 | help='Determines whether objective is minimizing or maximizing the metric attribute. (default: %(default)s)') 239 | parser.add_argument('--search_alg', default=None, choices=['basic_variant', 'bayesopt', 'optuna'], 240 | help='Search algorithms (default: %(default)s)') 241 | parser.add_argument('--no_merge_train_val', action='store_true', 242 | help='Do not add the validation set in re-training the final model after hyper-parameter search.') 243 | args, _ = parser.parse_known_args() 244 | 245 | # Load config from the config file and overwrite values specified in CLI. 246 | parameter_columns = dict() # parameters to include in progress table of CLIReporter 247 | config = load_config_from_file(args.config) 248 | config = init_search_params_spaces(config, parameter_columns, prefix='') 249 | parser.set_defaults(**config) 250 | config = AttributeDict(vars(parser.parse_args())) 251 | config.merge_train_val = False # no need to include validation during parameter search 252 | 253 | # Check if the validation set is provided. 254 | val_path = config.val_path or os.path.join(config.data_dir, 'valid.txt') 255 | assert config.val_size > 0 or os.path.exists(val_path), \ 256 | "You should specify either a positive `val_size` or a `val_path` defaults to `data_dir/valid.txt` for parameter search." 257 | 258 | """Run tune analysis. 259 | - If no search algorithm is specified, the default search algorighm is BasicVariantGenerator. 260 | https://docs.ray.io/en/master/tune/api_docs/suggestion.html#tune-basicvariant 261 | - Arguments without search spaces will be ignored by `tune.run` 262 | (https://github.com/ray-project/ray/blob/34d3d9294c50aea4005b7367404f6a5d9e0c2698/python/ray/tune/suggest/variant_generator.py#L333), 263 | so we parse the whole config to `tune.run` here for simplicity. 
264 | """ 265 | data = load_static_data(config) 266 | reporter = tune.CLIReporter(metric_columns=[f'val_{metric}' for metric in config.monitor_metrics], 267 | parameter_columns=parameter_columns, 268 | metric=f'val_{config.val_metric}', 269 | mode=args.mode, 270 | sort_by_metric=True) 271 | if config.scheduler is not None: 272 | scheduler = ASHAScheduler(metric=f'val_{config.val_metric}', 273 | mode=args.mode, 274 | **config.scheduler) 275 | else: 276 | scheduler = None 277 | 278 | exp_name = '{}_{}_{}'.format( 279 | config.data_name, 280 | Path(config.config).stem if config.config else config.model_name, 281 | datetime.now().strftime('%Y%m%d%H%M%S'), 282 | ) 283 | analysis = tune.run( 284 | tune.with_parameters( 285 | train_libmultilabel_tune, 286 | **data), 287 | search_alg=init_search_algorithm( 288 | config.search_alg, metric=config.val_metric, mode=args.mode), 289 | scheduler=scheduler, 290 | local_dir=config.result_dir, 291 | num_samples=config.num_samples, 292 | resources_per_trial={ 293 | 'cpu': args.cpu_count, 'gpu': args.gpu_count}, 294 | progress_reporter=reporter, 295 | config=config, 296 | name=exp_name, 297 | ) 298 | 299 | # Save best model after parameter search. 300 | best_config = analysis.get_best_config(f'val_{config.val_metric}', args.mode, scope='all') 301 | best_log_dir = analysis.get_best_logdir(f'val_{config.val_metric}', args.mode, scope='all') 302 | retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not args.no_merge_train_val) 303 | 304 | 305 | if __name__ == '__main__': 306 | # calculate wall time. 307 | wall_time = Timer() 308 | main() 309 | print(f'Wall time: {wall_time.time():.2f} (s)') 310 | -------------------------------------------------------------------------------- /search_params.sh: -------------------------------------------------------------------------------- 1 | data_list=(ecthr_a ecthr_b scotus eurlex ledgar unfair_tos) 2 | 3 | data=$1 4 | 5 | if [[ ! 
" ${data_list[*]} " =~ " ${data} " ]]; then 6 | echo "Invalid argument! Data ${data} is not in (${data_list[*]})." 7 | exit 8 | fi 9 | 10 | python3 search_params.py --config config/${data}/bert_tune.yml 11 | -------------------------------------------------------------------------------- /torch_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 6 | from transformers import AutoTokenizer 7 | 8 | from libmultilabel.nn import data_utils 9 | from libmultilabel.nn.model import Model 10 | from libmultilabel.nn.nn_utils import init_device, init_model, init_trainer 11 | from libmultilabel.common_utils import dump_log 12 | 13 | 14 | class TorchTrainer: 15 | """A wrapper for training neural network models with pytorch lightning trainer. 16 | 17 | Args: 18 | config (AttributeDict): Config of the experiment. 19 | datasets (dict, optional): Datasets for training, validation, and test. Defaults to None. 20 | classes(list, optional): List of class names. 21 | word_dict(torchtext.vocab.Vocab, optional): A vocab object which maps tokens to indices. 22 | embed_vecs (torch.Tensor, optional): The pre-trained word vectors of shape (vocab_size, embed_dim). 23 | search_params (bool, optional): Enable pytorch-lightning trainer to report the results to ray tune 24 | on validation end during hyperparameter search. Defaults to False. 25 | save_checkpoints (bool, optional): Whether to save the last and the best checkpoint or not. 26 | Defaults to True. 
27 | """ 28 | def __init__( 29 | self, 30 | config: dict, 31 | datasets: dict = None, 32 | classes: list = None, 33 | word_dict: dict = None, 34 | embed_vecs = None, 35 | search_params: bool = False, 36 | save_checkpoints: bool = True 37 | ): 38 | self.run_name = config.run_name 39 | self.checkpoint_dir = config.checkpoint_dir 40 | self.log_path = config.log_path 41 | os.makedirs(self.checkpoint_dir, exist_ok=True) 42 | 43 | # Set up seed & device 44 | if not config.enable_transformer_trainer: 45 | from libmultilabel.nn.nn_utils import set_seed 46 | set_seed(seed=config.seed) 47 | self.device = init_device(use_cpu=config.cpu) 48 | self.config = config 49 | 50 | # Load pretrained tokenizer for dataset loader 51 | self.tokenizer = None 52 | tokenize_text = 'lm_weight' not in config.network_config 53 | if not tokenize_text: 54 | self.tokenizer = AutoTokenizer.from_pretrained(config.network_config['lm_weight'], use_fast=True) 55 | # Load dataset 56 | if datasets is None: 57 | self.datasets = data_utils.load_datasets( 58 | train_path=config.train_path, 59 | test_path=config.test_path, 60 | val_path=config.val_path, 61 | val_size=config.val_size, 62 | merge_train_val=config.merge_train_val, 63 | tokenize_text=tokenize_text, 64 | remove_no_label_data=config.remove_no_label_data 65 | ) 66 | else: 67 | self.datasets = datasets 68 | 69 | self._setup_model(classes=classes, 70 | word_dict=word_dict, 71 | embed_vecs=embed_vecs, 72 | log_path=self.log_path, 73 | checkpoint_path=config.checkpoint_path) 74 | if config.enable_transformer_trainer: 75 | from transformers import EarlyStoppingCallback, set_seed, Trainer 76 | 77 | from libmultilabel.nn.data_utils import generate_transformer_batch 78 | from libmultilabel.nn.nn_utils import init_training_args, compute_metrics 79 | 80 | training_args = init_training_args(config) 81 | 82 | set_seed(training_args.seed) 83 | 84 | self.train_dataset = self._get_dataset_loader(split='train', shuffle=config.shuffle).dataset 85 | self.val_dataset 
= self._get_dataset_loader(split='val').dataset 86 | self.test_dataset = self._get_dataset_loader(split='test').dataset 87 | self.trainer = Trainer( 88 | model=self.model.network.lm, 89 | args=training_args, 90 | train_dataset=self.train_dataset, 91 | eval_dataset=self.val_dataset, 92 | compute_metrics=compute_metrics, 93 | tokenizer=self.tokenizer, 94 | data_collator=generate_transformer_batch, 95 | callbacks=[EarlyStoppingCallback(early_stopping_patience=config.patience)] 96 | ) 97 | else: 98 | self.trainer = init_trainer(checkpoint_dir=self.checkpoint_dir, 99 | epochs=config.epochs, 100 | patience=config.patience, 101 | val_metric=config.val_metric, 102 | silent=config.silent, 103 | use_cpu=config.cpu, 104 | limit_train_batches=config.limit_train_batches, 105 | limit_val_batches=config.limit_val_batches, 106 | limit_test_batches=config.limit_test_batches, 107 | search_params=search_params, 108 | save_checkpoints=save_checkpoints, 109 | accumulate_grad_batches=config.accumulate_grad_batches) 110 | callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)] 111 | self.checkpoint_callback = callbacks[0] if callbacks else None 112 | 113 | # Dump config to log 114 | dump_log(self.log_path, config=config) 115 | 116 | def _setup_model( 117 | self, 118 | classes: list = None, 119 | word_dict: dict = None, 120 | embed_vecs = None, 121 | log_path: str = None, 122 | checkpoint_path: str = None 123 | ): 124 | """Setup model from checkpoint if a checkpoint path is passed in or specified in the config. 125 | Otherwise, initialize model from scratch. 126 | 127 | Args: 128 | classes(list): List of class names. 129 | word_dict(torchtext.vocab.Vocab): A vocab object which maps tokens to indices. 130 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim). 131 | log_path (str): Path to the log file. The log file contains the validation 132 | results for each epoch and the test results. 
If the `log_path` is None, no performance 133 | results will be logged. 134 | checkpoint_path (str): The checkpoint to warm-up with. 135 | """ 136 | if 'checkpoint_path' in self.config and self.config.checkpoint_path is not None: 137 | checkpoint_path = self.config.checkpoint_path 138 | 139 | if checkpoint_path is not None: 140 | logging.info(f'Loading model from `{checkpoint_path}`...') 141 | self.model = Model.load_from_checkpoint(checkpoint_path) 142 | else: 143 | logging.info('Initialize model from scratch.') 144 | if self.config.embed_file is not None: 145 | logging.info('Load word dictionary ') 146 | word_dict, embed_vecs = data_utils.load_or_build_text_dict( 147 | dataset=self.datasets['train'], 148 | vocab_file=self.config.vocab_file, 149 | min_vocab_freq=self.config.min_vocab_freq, 150 | embed_file=self.config.embed_file, 151 | silent=self.config.silent, 152 | normalize_embed=self.config.normalize_embed, 153 | embed_cache_dir=self.config.embed_cache_dir 154 | ) 155 | if not classes: 156 | classes = data_utils.load_or_build_label( 157 | self.datasets, self.config.label_file, self.config.include_test_labels) 158 | 159 | if self.config.val_metric not in self.config.monitor_metrics: 160 | logging.warn( 161 | f'{self.config.val_metric} is not in `monitor_metrics`. 
Add {self.config.val_metric} to `monitor_metrics`.') 162 | self.config.monitor_metrics += [self.config.val_metric] 163 | 164 | self.model = init_model(model_name=self.config.model_name, 165 | network_config=dict(self.config.network_config), 166 | classes=classes, 167 | word_dict=word_dict, 168 | embed_vecs=embed_vecs, 169 | init_weight=self.config.init_weight, 170 | log_path=log_path, 171 | learning_rate=self.config.learning_rate, 172 | optimizer=self.config.optimizer, 173 | momentum=self.config.momentum, 174 | weight_decay=self.config.weight_decay, 175 | metric_threshold=self.config.metric_threshold, 176 | monitor_metrics=self.config.monitor_metrics, 177 | silent=self.config.silent, 178 | save_k_predictions=self.config.save_k_predictions, 179 | zero=self.config.zero, 180 | multi_class=self.config.multi_class, 181 | enable_ce_loss=self.config.enable_ce_loss, 182 | hierarchical=self.config.hierarchical 183 | ) 184 | 185 | def _get_dataset_loader(self, split, shuffle=False): 186 | """Get dataset loader. 187 | 188 | Args: 189 | split (str): One of 'train', 'test', or 'val'. 190 | shuffle (bool): Whether to shuffle training data before each epoch. Defaults to False. 191 | 192 | Returns: 193 | torch.utils.data.DataLoader: Dataloader for the train, test, or valid dataset. 
194 | """ 195 | return data_utils.get_dataset_loader( 196 | data=self.datasets[split], 197 | word_dict=self.model.word_dict, 198 | classes=self.model.classes, 199 | device=self.device, 200 | max_seq_length=self.config.max_seq_length, 201 | batch_size=self.config.batch_size if split == 'train' else self.config.eval_batch_size, 202 | shuffle=shuffle, 203 | data_workers=self.config.data_workers, 204 | tokenizer=self.tokenizer, 205 | add_special_tokens=self.config.add_special_tokens, 206 | hierarchical=self.config.hierarchical, 207 | enable_transformer_trainer=self.config.enable_transformer_trainer, 208 | multi_class=self.config.multi_class 209 | ) 210 | 211 | def train(self): 212 | """Train model with pytorch lightning trainer. Set model to the best model after the training 213 | process is finished. 214 | """ 215 | assert self.trainer is not None, "Please make sure the trainer is successfully initialized by `self._setup_trainer()`." 216 | if self.config.enable_transformer_trainer: 217 | # Training 218 | train_result = self.trainer.train() 219 | metrics = train_result.metrics 220 | metrics['train_samples'] = len(self.train_dataset) 221 | self.trainer.save_model() 222 | self.trainer.log_metrics('train', metrics) 223 | self.trainer.save_metrics('train', metrics) 224 | self.trainer.save_state() 225 | return 226 | 227 | train_loader = self._get_dataset_loader(split='train', shuffle=self.config.shuffle) 228 | 229 | if 'val' not in self.datasets: 230 | logging.info('No validation dataset is provided. Train without vaildation.') 231 | self.trainer.fit(self.model, train_loader) 232 | else: 233 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 234 | self.trainer.callbacks += [EarlyStopping(patience=self.config.patience, 235 | monitor=self.config.val_metric, 236 | mode='max')] # tentative hard code 237 | val_loader = self._get_dataset_loader(split='val') 238 | self.trainer.fit(self.model, train_loader, val_loader) 239 | 240 | # Set model to the best model. 
If the validation process is skipped during 241 | # training (i.e., val_size=0), the model is set to the last model. 242 | model_path = self.checkpoint_callback.best_model_path or self.checkpoint_callback.last_model_path 243 | if model_path: 244 | logging.info(f'Finished training. Load best model from {model_path}.') 245 | self._setup_model(checkpoint_path=model_path) 246 | else: 247 | logging.info('No model is saved during training. \ 248 | If you want to save the best and the last model, please set `save_checkpoints` to True.') 249 | 250 | def test(self, split='test'): 251 | """Test model with pytorch lightning trainer. Top-k predictions are saved 252 | if `save_k_predictions` > 0. 253 | 254 | Args: 255 | split (str, optional): One of 'train', 'test', or 'val'. Defaults to 'test'. 256 | 257 | Returns: 258 | dict: Scores for all metrics in the dictionary format. 259 | """ 260 | assert 'test' in self.datasets and self.trainer is not None 261 | 262 | if self.config.enable_transformer_trainer: 263 | # Validation 264 | metrics = self.trainer.evaluate(eval_dataset=self.val_dataset) 265 | metrics['val_samples'] = len(self.val_dataset) 266 | self.trainer.log_metrics('val', metrics) 267 | self.trainer.save_metrics('val', metrics) 268 | # Testing 269 | predictions, labels, metrics = self.trainer.predict(self.test_dataset, metric_key_prefix='test') 270 | metrics['test_samples'] = len(self.test_dataset) 271 | self.trainer.log_metrics('test', metrics) 272 | self.trainer.save_metrics('test', metrics) 273 | return 274 | 275 | logging.info(f'Testing on {split} set.') 276 | test_loader = self._get_dataset_loader(split=split) 277 | metric_dict = self.trainer.test(self.model, dataloaders=test_loader)[0] 278 | 279 | if self.config.save_k_predictions > 0: 280 | self._save_predictions(test_loader, self.config.predict_out_path) 281 | 282 | return metric_dict 283 | 284 | def _save_predictions(self, dataloader, predict_out_path): 285 | """Save top k label results. 
286 | 287 | Args: 288 | dataloader (torch.utils.data.DataLoader): Dataloader for the test or valid dataset. 289 | predict_out_path (str): Path to the an output file holding top k label results. 290 | """ 291 | batch_predictions = self.trainer.predict(self.model, dataloaders=dataloader) 292 | pred_labels = np.vstack([batch['top_k_pred'] 293 | for batch in batch_predictions]) 294 | pred_scores = np.vstack([batch['top_k_pred_scores'] 295 | for batch in batch_predictions]) 296 | with open(predict_out_path, 'w') as fp: 297 | for pred_label, pred_score in zip(pred_labels, pred_scores): 298 | out_str = ' '.join([f'{self.model.classes[label]}:{score:.4}' for label, score in zip( 299 | pred_label, pred_score)]) 300 | fp.write(out_str+'\n') 301 | logging.info(f'Saved predictions to: {predict_out_path}') 302 | --------------------------------------------------------------------------------