├── .gitignore
├── README.md
├── __init__.py
├── config
│   ├── ecthr_a
│   │   ├── bert_default.yml
│   │   ├── bert_reproduced.yml
│   │   ├── bert_tune.yml
│   │   ├── bert_tuned.yml
│   │   └── l2svm.yml
│   ├── ecthr_b
│   │   ├── bert_default.yml
│   │   ├── bert_reproduced.yml
│   │   ├── bert_tune.yml
│   │   ├── bert_tuned.yml
│   │   └── l2svm.yml
│   ├── eurlex
│   │   ├── bert_default.yml
│   │   ├── bert_reproduced.yml
│   │   ├── bert_tune.yml
│   │   ├── bert_tuned.yml
│   │   └── l2svm.yml
│   ├── ledgar
│   │   ├── bert_default.yml
│   │   ├── bert_reproduced.yml
│   │   ├── bert_tune.yml
│   │   ├── bert_tuned.yml
│   │   └── l2svm.yml
│   ├── scotus
│   │   ├── bert_default.yml
│   │   ├── bert_reproduced.yml
│   │   ├── bert_tune.yml
│   │   ├── bert_tuned.yml
│   │   └── l2svm.yml
│   └── unfair_tos
│       ├── bert_default.yml
│       ├── bert_reproduced.yml
│       ├── bert_tune.yml
│       ├── bert_tuned.yml
│       └── l2svm.yml
├── generate_data.py
├── generate_data.sh
├── libmultilabel
│   ├── __init__.py
│   ├── common_utils.py
│   ├── linear
│   │   ├── __init__.py
│   │   ├── linear.py
│   │   ├── metrics.py
│   │   ├── preprocessor.py
│   │   └── utils.py
│   └── nn
│       ├── __init__.py
│       ├── data_utils.py
│       ├── metrics.py
│       ├── model.py
│       ├── networks
│       │   ├── __init__.py
│       │   ├── bert.py
│       │   ├── bert_attention.py
│       │   ├── caml.py
│       │   ├── hierbert.py
│       │   ├── kim_cnn.py
│       │   ├── labelwise_attention_networks.py
│       │   ├── modules.py
│       │   └── xml_cnn.py
│       └── nn_utils.py
├── linear_trainer.py
├── main.py
├── requirements.txt
├── requirements_parameter_search.txt
├── run_experiments.sh
├── search_params.py
├── search_params.sh
└── torch_trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | runs/
23 | wheels/
24 | prof/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # celery beat schedule file
97 | celerybeat-schedule
98 |
99 | # SageMath parsed files
100 | *.sage.py
101 |
102 | # Environments
103 | .env
104 | .venv
105 | env/
106 | venv*/
107 | ENV/
108 | env.bak/
109 | venv.bak/
110 |
111 | # Spyder project settings
112 | .spyderproject
113 | .spyproject
114 |
115 | # Rope project settings
116 | .ropeproject
117 |
118 | # mkdocs documentation
119 | /site
120 |
121 | # mypy
122 | .mypy_cache/
123 | .dmypy.json
124 | dmypy.json
125 |
126 | # Pyre type checker
127 | .pyre/
128 |
129 | # jupyter
130 | *.ipynb
131 |
132 | *.swp
133 | .editorconfig
134 | .vscode/
135 |
136 | dataset/*
137 | data/*
138 | output/*
139 | model_output/*
140 | *.xlsx
141 | tmp*
142 | *.csv
143 | *report*.txt
144 | *.pkl
145 | *.txt
146 | *.json
147 | *.bin
148 | *.lock
149 | *.pt
150 | *.zip
151 | *.tar.gz
152 | *.ckpt
153 | *.tfevents.*
154 |
155 | # Sphinx build
156 | docs/_build/*
157 | docs/cli/*.include
158 | !docs/requirements.txt
159 | !requirements.txt
160 | !requirements_parameter_search.txt
161 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Linear Classifier: An Often-Forgotten Baseline for Text Classification
2 |
3 | This is the code for the ACL 2023 paper "[Linear Classifier: An Often-Forgotten Baseline for Text Classification](https://www.csie.ntu.edu.tw/~cjlin/papers/text_classification_baseline/text_classification_baseline.pdf)". The repository is used to reproduce the experimental results in our paper. If you find our work useful, please consider citing the following paper:
4 | ```bib
5 | @InProceedings{YCL22a,
6 |   author = {Yu-Chen Lin and Si-An Chen and Jie-Jyun Liu and Chih-Jen Lin},
7 |   title = {Linear Classifier: An Often-Forgotten Baseline for Text Classification},
8 |   booktitle = {Proceedings of the 61st Annual Meeting of the Association for Computational Linguistics (ACL)},
9 |   year = {2023},
10 |   url = {https://www.csie.ntu.edu.tw/~cjlin/papers/text_classification_baseline/text_classification_baseline.pdf},
11 |   note = {Short paper}
12 | }
13 | ```
14 | Please feel free to contact [Yu-Chen Lin](mailto:b06504025@csie.ntu.edu.tw) if you have any questions about the code/paper.
15 |
16 | ## Setup Environment
17 |
18 | It is optional but highly recommended to create a virtual environment. For example, you can first refer to the [Miniconda installation guide](https://docs.conda.io/en/latest/miniconda.html) and then create a virtual environment as follows.
19 | ```bash
20 | conda create -n baseline python=3.8
21 | conda activate baseline
22 | ```
23 |
24 | We recommend using **Python 3.8** because that is the version we used in our experiments. Please also install the following packages.
25 | ```bash
26 | pip install -r requirements.txt
27 | pip install -r requirements_parameter_search.txt
28 | ```
29 |
30 | If you have a different version of CUDA, follow the installation instructions for PyTorch LTS on their [website](https://pytorch.org/get-started/previous-versions/).
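For illustration only, the command for PyTorch LTS (1.8.2) with CUDA 11.1 looked like the one below at the time of writing; please copy the exact command matching your CUDA version from the page linked above.
```bash
pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 torchaudio==0.8.2 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
```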
31 |
32 | ## Generate Data
33 |
34 | You can simply run the following script to generate all the needed data.
35 | ```bash
36 | bash generate_data.sh
37 | ```
38 | The data come in three formats. For all formats, we use the same splits as [LexGLUE](https://github.com/coastalcph/lex-glue): we convert them to the [LibMultiLabel](https://github.com/ASUS-AICS/LibMultiLabel) format and then modify them so that they are runnable with LibMultiLabel. We briefly explain each format below; a sample of the generated file format follows the list.
39 | * linear: We combine training and validation subsets as the new training set, so only training and test sets are available.
40 | * nn: Training, validation, and test sets are available.
41 | * hier (hierarchical): To reproduce the results in Chalkidis et al. (2022), we need to employ the hierarchical setting of BERT. Specifically, we add a special symbol to the data and modify the LibMultiLabel code to conduct the hierarchical BERT experiments. Training, validation, and test sets are available.
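For reference, each generated `train.txt`, `valid.txt`, and `test.txt` uses the LibMultiLabel txt format written by `save_data` in [generate_data.py](generate_data.py): one instance per line, with space-separated label IDs, a tab, and then the whitespace-normalized text. In the hier format, the parts of a document are additionally joined with the special ` [HIER] ` symbol. A made-up illustration, where `<TAB>` stands for the tab character:
```
0 5<TAB>the applicant complained about the length of the proceedings ...
3<TAB>the court considered the merits of the case ...
```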
42 |
43 | ## Experimental Results
44 |
45 | In Table 2, we present our results for two types of approaches: Linear SVM and BERT.
46 |
47 | | | ECtHR (A) | ECtHR (B) | SCOTUS | EUR-LEX | LEDGAR | UNFAIR-ToS |
48 | |:------------------------|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|:----------:|
49 | | **Method** | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 |
50 | | **Linear** |
51 | | one-vs-rest | 64.0 / 53.1 | 72.8 / 63.9 | 78.1 / 68.9 | 72.0 / 55.4 | 86.4 / 80.0 | 94.9 / 75.1 |
52 | | thresholding | 68.6 / 64.9 | 76.1 / 68.7 | 78.9 / 71.5 | 74.7 / 62.7 | 86.2 / 79.9 | 95.1 / 79.9 |
53 | | cost-sensitive | 67.4 / 60.5 | 75.5 / 67.3 | 78.3 / 71.5 | 73.4 / 60.5 | 86.2 / 80.1 | 95.3 / 77.9 |
54 | | Chalkidis et al. (2022) | 64.5 / 51.7 | 74.6 / 65.1 | 78.2 / 69.5 | 71.3 / 51.4 | 87.2 / 82.4 | 95.4 / 78.8 |
55 | | **BERT** |
56 | | Ours | 61.9 / 55.6 | 69.8 / 60.5 | 67.1 / 55.9 | 70.8 / 55.3 | 87.0 / 80.7 | 95.4 / 80.3 |
57 | | Chalkidis et al. (2022) | 71.2 / 63.6 | 79.7 / 73.4 | 68.3 / 58.3 | 71.4 / 57.2 | 87.6 / 81.8 | 95.6 / 81.3 |
58 |
59 | In Table 9, we present additional results from BERT. Note that the **Ours** setting (BERT) in Table 2 is the same as the **tuned** setting (BERT in LibMultiLabel) in Table 9.
60 |
61 | | | ECtHR (A) | ECtHR (B) | SCOTUS | EUR-LEX | LEDGAR | UNFAIR-ToS |
62 | |:------------|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|:----------:|
63 | | **Method** | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 | μ-F1 / m-F1 |
64 | | **BERT in LibMultiLabel** |
65 | | default | 60.5 / 53.4 | 68.9 / 60.8 | 66.3 / 54.8 | 70.8 / 55.3 | 85.2 / 77.9 | 95.2 / 78.2 |
66 | | tuned | 61.9 / 55.6 | 69.8 / 60.5 | 67.1 / 55.9 | 70.8 / 55.3 | 87.0 / 80.7 | 95.4 / 80.3 |
67 | | reproduced | 70.2 / 63.7 | 78.8 / 73.1 | 70.8 / 62.6 | 71.6 / 56.1 | 88.1 / 82.6 | 95.3 / 80.6 |
68 | | **BERT in Chalkidis et al. (2022)** |
69 | | paper | 71.2 / 63.6 | 79.7 / 73.4 | 68.3 / 58.3 | 71.4 / 57.2 | 87.6 / 81.8 | 95.6 / 81.3 |
70 | | reproduced | 70.8 / 64.8 | 78.7 / 72.5 | 70.9 / 61.9 | 71.7 / 57.9 | 87.7 / 82.1 | 95.6 / 80.3 |
71 |
72 | ## Run Experiments
73 |
74 | We show how to conduct the experiments for each method as follows.
75 | ```bash
76 | bash run_experiments.sh [DATA] [METHOD]
77 | ```
78 |
79 | First, you need to determine which data set and method you want to try. Then, refer to the following lookup tables for the values of the **\[DATA\]** and **\[METHOD\]** arguments.
80 |
81 |
82 |
83 |
84 |
85 | | Dataset    | \[DATA\]   |
86 | |:-----------|:----------:|
87 | | ECtHR (A) | ecthr_a |
88 | | ECtHR (B) | ecthr_b |
89 | | SCOTUS | scotus |
90 | | EUR-LEX | eurlex |
91 | | LEDGAR | ledgar |
92 | | UNFAIR-ToS | unfair_tos |
93 |
94 |
95 |
96 |
97 | | Method | \[METHOD\] |
98 | |:--------------------------------|:----------------:|
99 | | Linear_one-vs-rest (Table 2) | 1vsrest |
100 | | Linear_thresholding (Table 2) | thresholding |
101 | | Linear_cost-sensitive (Table 2) | cost_sensitive |
102 | | BERT_Ours (Table 2) | bert_tuned |
103 | | BERT_default (Table 9) | bert_default |
104 | | BERT_tuned (Table 9) | bert_tuned |
105 | | BERT_reproduced (Table 9) | bert_reproduced |
106 |
107 |
108 |
109 |
110 |
111 | For example, if you want to apply the **thresholding** technique to the data set **UNFAIR-ToS**, run the following command.
112 | ```bash
113 | bash run_experiments.sh unfair_tos thresholding
114 | ```
115 | If you want to run the **BERT_default** setting on the data set **ECtHR (B)**, pass the arguments as follows.
116 | ```bash
117 | bash run_experiments.sh ecthr_b bert_default
118 | ```
119 |
120 | Some additional notes follow.
121 |
122 | * For the **BERT_tuned** setting, we have already tuned the parameters for you, so the script only runs the experiments with the tuned parameters. The running time will therefore differ from Table 10, which also includes the time for the parameter search. If you want to tune the parameters yourself, we also provide a script and a search space configuration; see the following commands.
123 | ```bash
124 | # Conduct the hyper-parameter search
125 | bash search_params.sh [DATA]
126 |
127 | # Replace the given tuned configuration with the searched parameters
128 | mv runs/[DATA]_bert_tune_XXX/trial_best_params/params.yml config/[DATA]/bert_tuned.yml
129 |
130 | # Run the BERT_tuned setting
131 | bash run_experiments.sh [DATA] bert_tuned
132 | ```
133 | * To run the **BERT_reproduced** method on the data sets **ECtHR (A)**, **ECtHR (B)**, and **SCOTUS**, you need a GPU with more than 16GB of memory.
134 |
135 | ## Evaluation
136 |
137 | For comparison purposes, we followed Chalkidis et al. (2022) in how data instances without any labels are handled during evaluation,
138 | although this setting is neither a standard practice in multi-label classification nor supported by [LibMultiLabel](https://github.com/ASUS-AICS/LibMultiLabel).
139 |
140 | ## Reproducibility
141 |
142 | For our experimental results, linear methods were run on an **Intel Xeon E5-2690** CPU, while BERT was run on an **Nvidia V100** GPU. You may notice minor differences between the results of your runs of our scripts and those reported in our paper, especially for BERT. To fully reproduce our results, carefully follow the items below.
143 | * Make sure you install our suggested package version.
144 | * Use the same device as ours.
145 | * Because our BERT results are averaged over five runs with different seeds (1, 2, 3, 4, 5), you should modify [run_experiments.sh](run_experiments.sh) and likewise do five runs; a sketch is given below.
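  A minimal sketch of such a loop (assuming the underlying `main.py` accepts a `--seed` override, as in LibMultiLabel; adapt it to how the script actually passes arguments):
  ```bash
  # Sketch only: run the tuned BERT setting five times with seeds 1-5.
  for seed in 1 2 3 4 5; do
      python3 main.py --config config/[DATA]/bert_tuned.yml --seed ${seed}
  done
  ```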
146 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JamesLYC88/text_classification_baseline_code/6436f567372f3c36547cd52da79516900e5a6148/__init__.py
--------------------------------------------------------------------------------
/config/ecthr_a/bert_default.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/ecthr_a
3 | data_name: ecthr_a
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 0.00005
15 | weight_decay: 0.001
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |
29 | # pretrained vocab / embeddings
30 | embed_file: null
31 |
--------------------------------------------------------------------------------
/config/ecthr_a/bert_reproduced.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_hier/ecthr_a
3 | data_name: ecthr_a
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 | hierarchical: true
9 | accumulate_grad_batches: 4
10 |
11 | # train
12 | seed: 1337
13 | epochs: 15
14 | batch_size: 2
15 | optimizer: adamw
16 | learning_rate: 0.00003
17 | weight_decay: 0
18 | patience: 5
19 |
20 | # eval
21 | eval_batch_size: 2
22 | monitor_metrics: ['Micro-F1', 'Macro-F1']
23 | val_metric: Micro-F1
24 |
25 | # model
26 | model_name: BERT
27 | init_weight: null
28 | network_config:
29 |   lm_weight: bert-base-uncased
30 |
31 | # pretrained vocab / embeddings
32 | embed_file: null
33 |
--------------------------------------------------------------------------------
/config/ecthr_a/bert_tune.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/ecthr_a
3 | data_name: ecthr_a
4 | min_vocab_freq: 1
5 | max_seq_length: ['grid_search', [128, 512]]
6 | include_test_labels: true
7 | remove_no_label_data: false
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]]
15 | momentum: 0
16 | weight_decay: ['grid_search', [0, 0.001]]
17 | patience: 5
18 | shuffle: true
19 |
20 | # eval
21 | eval_batch_size: 8
22 | monitor_metrics: ['Micro-F1', 'Macro-F1']
23 | val_metric: Micro-F1
24 |
25 | # model
26 | model_name: BERT
27 | init_weight: null
28 | network_config:
29 |   dropout: ['grid_search', [0.1, 0.2]]
30 |   lm_weight: bert-base-uncased
31 |
32 | # pretrained vocab / embeddings
33 | vocab_file: null
34 | embed_file: null
35 | normalize_embed: false
36 |
37 | # hyperparameter search
38 | search_alg: basic_variant
39 | embed_cache_dir: null
40 | num_samples: 1
41 | scheduler: null
42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler
44 | #scheduler:
45 | #time_attr: training_iteration
46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper)
47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper)
48 | #reduction_factor: 3 # reduce the number of configurations to floor(1/reduction_factor) each round of successive halving (called a rung in the ASHA paper)
49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used)
50 |
51 | # other parameters specified in main.py::get_args
52 | checkpoint_path: null
53 | cpu: false
54 | data_workers: 4
55 | eval: false
56 | label_file: null
57 | limit_train_batches: 1.0
58 | limit_val_batches: 1.0
59 | limit_test_batches: 1.0
60 | metric_threshold: 0.5
61 | result_dir: runs
62 | save_k_predictions: 0
63 | silent: true
64 | test_path: null
65 | train_path: null
66 | val_path: null
67 | val_size: 0.2
68 |
69 | # LexGLUE
70 | zero: true
71 | multi_class: false
72 | add_special_tokens: true
73 | enable_ce_loss: false
74 | hierarchical: false
75 | accumulate_grad_batches: 1
76 | enable_transformer_trainer: false
77 |
--------------------------------------------------------------------------------
/config/ecthr_a/bert_tuned.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/ecthr_a
3 | data_name: ecthr_a
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 2e-05
15 | weight_decay: 0
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |   dropout: 0.1
29 |
30 | # pretrained vocab / embeddings
31 | embed_file: null
32 |
--------------------------------------------------------------------------------
/config/ecthr_a/l2svm.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_linear/ecthr_a
3 | data_name: ecthr_a
4 |
5 | # train
6 | seed: 1337
7 | linear: true
8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q"
9 | linear_technique: 1vsrest
10 |
11 | # eval
12 | eval_batch_size: 256
13 | monitor_metrics: [Micro-F1, Macro-F1]
14 | metric_threshold: 0
15 |
16 | data_format: txt
17 |
--------------------------------------------------------------------------------
/config/ecthr_b/bert_default.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/ecthr_b
3 | data_name: ecthr_b
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 0.00005
15 | weight_decay: 0.001
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |
29 | # pretrained vocab / embeddings
30 | embed_file: null
31 |
--------------------------------------------------------------------------------
/config/ecthr_b/bert_reproduced.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_hier/ecthr_b
3 | data_name: ecthr_b
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 | hierarchical: true
9 | accumulate_grad_batches: 4
10 |
11 | # train
12 | seed: 1337
13 | epochs: 15
14 | batch_size: 2
15 | optimizer: adamw
16 | learning_rate: 0.00003
17 | weight_decay: 0
18 | patience: 5
19 |
20 | # eval
21 | eval_batch_size: 2
22 | monitor_metrics: ['Micro-F1', 'Macro-F1']
23 | val_metric: Micro-F1
24 |
25 | # model
26 | model_name: BERT
27 | init_weight: null
28 | network_config:
29 |   lm_weight: bert-base-uncased
30 |
31 | # pretrained vocab / embeddings
32 | embed_file: null
33 |
--------------------------------------------------------------------------------
/config/ecthr_b/bert_tune.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/ecthr_b
3 | data_name: ecthr_b
4 | min_vocab_freq: 1
5 | max_seq_length: ['grid_search', [128, 512]]
6 | include_test_labels: true
7 | remove_no_label_data: false
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]]
15 | momentum: 0
16 | weight_decay: ['grid_search', [0, 0.001]]
17 | patience: 5
18 | shuffle: true
19 |
20 | # eval
21 | eval_batch_size: 8
22 | monitor_metrics: ['Micro-F1', 'Macro-F1']
23 | val_metric: Micro-F1
24 |
25 | # model
26 | model_name: BERT
27 | init_weight: null
28 | network_config:
29 |   dropout: ['grid_search', [0.1, 0.2]]
30 |   lm_weight: bert-base-uncased
31 |
32 | # pretrained vocab / embeddings
33 | vocab_file: null
34 | embed_file: null
35 | normalize_embed: false
36 |
37 | # hyperparameter search
38 | search_alg: basic_variant
39 | embed_cache_dir: null
40 | num_samples: 1
41 | scheduler: null
42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler
44 | #scheduler:
45 | #time_attr: training_iteration
46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper)
47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper)
48 | #reduction_factor: 3 # reduce the number of configurations to floor(1/reduction_factor) each round of successive halving (called a rung in the ASHA paper)
49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used)
50 |
51 | # other parameters specified in main.py::get_args
52 | checkpoint_path: null
53 | cpu: false
54 | data_workers: 4
55 | eval: false
56 | label_file: null
57 | limit_train_batches: 1.0
58 | limit_val_batches: 1.0
59 | limit_test_batches: 1.0
60 | metric_threshold: 0.5
61 | result_dir: runs
62 | save_k_predictions: 0
63 | silent: true
64 | test_path: null
65 | train_path: null
66 | val_path: null
67 | val_size: 0.2
68 |
69 | # LexGLUE
70 | zero: true
71 | multi_class: false
72 | add_special_tokens: true
73 | enable_ce_loss: false
74 | hierarchical: false
75 | accumulate_grad_batches: 1
76 | enable_transformer_trainer: false
77 |
--------------------------------------------------------------------------------
/config/ecthr_b/bert_tuned.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/ecthr_b
3 | data_name: ecthr_b
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 3e-05
15 | weight_decay: 0.001
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |   dropout: 0.2
29 |
30 | # pretrained vocab / embeddings
31 | embed_file: null
32 |
--------------------------------------------------------------------------------
/config/ecthr_b/l2svm.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_linear/ecthr_b
3 | data_name: ecthr_b
4 |
5 | # train
6 | seed: 1337
7 | linear: true
8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q"
9 | linear_technique: 1vsrest
10 |
11 | # eval
12 | eval_batch_size: 256
13 | monitor_metrics: [Micro-F1, Macro-F1]
14 | metric_threshold: 0
15 |
16 | data_format: txt
17 |
--------------------------------------------------------------------------------
/config/eurlex/bert_default.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/eurlex
3 | data_name: eurlex
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 0.00005
15 | weight_decay: 0.001
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |
29 | # pretrained vocab / embeddings
30 | embed_file: null
31 |
--------------------------------------------------------------------------------
/config/eurlex/bert_reproduced.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/eurlex
3 | data_name: eurlex
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 20
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 0.00003
15 | weight_decay: 0
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |
29 | # pretrained vocab / embeddings
30 | embed_file: null
31 |
--------------------------------------------------------------------------------
/config/eurlex/bert_tune.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/eurlex
3 | data_name: eurlex
4 | min_vocab_freq: 1
5 | max_seq_length: ['grid_search', [128, 512]]
6 | include_test_labels: true
7 | remove_no_label_data: false
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]]
15 | momentum: 0
16 | weight_decay: ['grid_search', [0, 0.001]]
17 | patience: 5
18 | shuffle: true
19 |
20 | # eval
21 | eval_batch_size: 8
22 | monitor_metrics: ['Micro-F1', 'Macro-F1']
23 | val_metric: Micro-F1
24 |
25 | # model
26 | model_name: BERT
27 | init_weight: null
28 | network_config:
29 |   dropout: ['grid_search', [0.1, 0.2]]
30 |   lm_weight: bert-base-uncased
31 |
32 | # pretrained vocab / embeddings
33 | vocab_file: null
34 | embed_file: null
35 | normalize_embed: false
36 |
37 | # hyperparameter search
38 | search_alg: basic_variant
39 | embed_cache_dir: null
40 | num_samples: 1
41 | scheduler: null
42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler
44 | #scheduler:
45 | #time_attr: training_iteration
46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper)
47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper)
48 | #reduction_factor: 3 # reduce the number of configurations to floor(1/reduction_factor) each round of successive halving (called a rung in the ASHA paper)
49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used)
50 |
51 | # other parameters specified in main.py::get_args
52 | checkpoint_path: null
53 | cpu: false
54 | data_workers: 4
55 | eval: false
56 | label_file: null
57 | limit_train_batches: 1.0
58 | limit_val_batches: 1.0
59 | limit_test_batches: 1.0
60 | metric_threshold: 0.5
61 | result_dir: runs
62 | save_k_predictions: 0
63 | silent: true
64 | test_path: null
65 | train_path: null
66 | val_path: null
67 | val_size: 0.2
68 |
69 | # LexGLUE
70 | zero: false
71 | multi_class: false
72 | add_special_tokens: true
73 | enable_ce_loss: false
74 | hierarchical: false
75 | accumulate_grad_batches: 1
76 | enable_transformer_trainer: false
77 |
--------------------------------------------------------------------------------
/config/eurlex/bert_tuned.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/eurlex
3 | data_name: eurlex
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 5e-05
15 | weight_decay: 0.001
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |   dropout: 0.1
29 |
30 | # pretrained vocab / embeddings
31 | embed_file: null
32 |
--------------------------------------------------------------------------------
/config/eurlex/l2svm.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_linear/eurlex
3 | data_name: eurlex
4 |
5 | # train
6 | seed: 1337
7 | linear: true
8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q"
9 | linear_technique: 1vsrest
10 |
11 | # eval
12 | eval_batch_size: 256
13 | monitor_metrics: [Micro-F1, Macro-F1]
14 | metric_threshold: 0
15 |
16 | data_format: txt
17 |
--------------------------------------------------------------------------------
/config/ledgar/bert_default.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/ledgar
3 | data_name: ledgar
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 0.00005
15 | weight_decay: 0.001
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |
29 | # pretrained vocab / embeddings
30 | embed_file: null
31 |
--------------------------------------------------------------------------------
/config/ledgar/bert_reproduced.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/ledgar
3 | data_name: ledgar
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 20
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 0.00003
15 | weight_decay: 0
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |
29 | # pretrained vocab / embeddings
30 | embed_file: null
31 |
--------------------------------------------------------------------------------
/config/ledgar/bert_tune.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/ledgar
3 | data_name: ledgar
4 | min_vocab_freq: 1
5 | max_seq_length: ['grid_search', [128, 512]]
6 | include_test_labels: true
7 | remove_no_label_data: false
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]]
15 | momentum: 0
16 | weight_decay: ['grid_search', [0, 0.001]]
17 | patience: 5
18 | shuffle: true
19 |
20 | # eval
21 | eval_batch_size: 8
22 | monitor_metrics: ['Micro-F1', 'Macro-F1']
23 | val_metric: Micro-F1
24 |
25 | # model
26 | model_name: BERT
27 | init_weight: null
28 | network_config:
29 |   dropout: ['grid_search', [0.1, 0.2]]
30 |   lm_weight: bert-base-uncased
31 |
32 | # pretrained vocab / embeddings
33 | vocab_file: null
34 | embed_file: null
35 | normalize_embed: false
36 |
37 | # hyperparameter search
38 | search_alg: basic_variant
39 | embed_cache_dir: null
40 | num_samples: 1
41 | scheduler: null
42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler
44 | #scheduler:
45 | #time_attr: training_iteration
46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper)
47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper)
48 | #reduction_factor: 3 # reduce the number of configurations to floor(1/reduction_factor) each round of successive halving (called a rung in the ASHA paper)
49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used)
50 |
51 | # other parameters specified in main.py::get_args
52 | checkpoint_path: null
53 | cpu: false
54 | data_workers: 4
55 | eval: false
56 | label_file: null
57 | limit_train_batches: 1.0
58 | limit_val_batches: 1.0
59 | limit_test_batches: 1.0
60 | metric_threshold: 0.5
61 | result_dir: runs
62 | save_k_predictions: 0
63 | silent: true
64 | test_path: null
65 | train_path: null
66 | val_path: null
67 | val_size: 0.2
68 |
69 | # LexGLUE
70 | zero: false
71 | multi_class: true
72 | add_special_tokens: true
73 | enable_ce_loss: true
74 | hierarchical: false
75 | accumulate_grad_batches: 1
76 | enable_transformer_trainer: false
77 |
--------------------------------------------------------------------------------
/config/ledgar/bert_tuned.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/ledgar
3 | data_name: ledgar
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 2e-05
15 | weight_decay: 0
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |   dropout: 0.2
29 |
30 | # pretrained vocab / embeddings
31 | embed_file: null
32 |
--------------------------------------------------------------------------------
/config/ledgar/l2svm.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_linear/ledgar
3 | data_name: ledgar
4 |
5 | # train
6 | seed: 1337
7 | linear: true
8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q"
9 | linear_technique: 1vsrest
10 |
11 | # eval
12 | eval_batch_size: 256
13 | monitor_metrics: [Micro-F1, Macro-F1]
14 | metric_threshold: 0
15 |
16 | data_format: txt
17 |
--------------------------------------------------------------------------------
/config/scotus/bert_default.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/scotus
3 | data_name: scotus
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 0.00005
15 | weight_decay: 0.001
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |
29 | # pretrained vocab / embeddings
30 | embed_file: null
31 |
--------------------------------------------------------------------------------
/config/scotus/bert_reproduced.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_hier/scotus
3 | data_name: scotus
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 | hierarchical: true
9 | accumulate_grad_batches: 4
10 |
11 | # train
12 | seed: 1337
13 | epochs: 20
14 | batch_size: 2
15 | optimizer: adamw
16 | learning_rate: 0.00003
17 | weight_decay: 0
18 | patience: 5
19 |
20 | # eval
21 | eval_batch_size: 2
22 | monitor_metrics: ['Micro-F1', 'Macro-F1']
23 | val_metric: Micro-F1
24 |
25 | # model
26 | model_name: BERT
27 | init_weight: null
28 | network_config:
29 |   lm_weight: bert-base-uncased
30 |
31 | # pretrained vocab / embeddings
32 | embed_file: null
33 |
--------------------------------------------------------------------------------
/config/scotus/bert_tune.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/scotus
3 | data_name: scotus
4 | min_vocab_freq: 1
5 | max_seq_length: ['grid_search', [128, 512]]
6 | include_test_labels: true
7 | remove_no_label_data: false
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]]
15 | momentum: 0
16 | weight_decay: ['grid_search', [0, 0.001]]
17 | patience: 5
18 | shuffle: true
19 |
20 | # eval
21 | eval_batch_size: 8
22 | monitor_metrics: ['Micro-F1', 'Macro-F1']
23 | val_metric: Micro-F1
24 |
25 | # model
26 | model_name: BERT
27 | init_weight: null
28 | network_config:
29 |   dropout: ['grid_search', [0.1, 0.2]]
30 |   lm_weight: bert-base-uncased
31 |
32 | # pretrained vocab / embeddings
33 | vocab_file: null
34 | embed_file: null
35 | normalize_embed: false
36 |
37 | # hyperparameter search
38 | search_alg: basic_variant
39 | embed_cache_dir: null
40 | num_samples: 1
41 | scheduler: null
42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler
44 | #scheduler:
45 | #time_attr: training_iteration
46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper)
47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper)
48 | #reduction_factor: 3 # reduce the number of configurations to floor(1/reduction_factor) each round of successive halving (called a rung in the ASHA paper)
49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used)
50 |
51 | # other parameters specified in main.py::get_args
52 | checkpoint_path: null
53 | cpu: false
54 | data_workers: 4
55 | eval: false
56 | label_file: null
57 | limit_train_batches: 1.0
58 | limit_val_batches: 1.0
59 | limit_test_batches: 1.0
60 | metric_threshold: 0.5
61 | result_dir: runs
62 | save_k_predictions: 0
63 | silent: true
64 | test_path: null
65 | train_path: null
66 | val_path: null
67 | val_size: 0.2
68 |
69 | # LexGLUE
70 | zero: false
71 | multi_class: true
72 | add_special_tokens: true
73 | enable_ce_loss: true
74 | hierarchical: false
75 | accumulate_grad_batches: 1
76 | enable_transformer_trainer: false
77 |
--------------------------------------------------------------------------------
/config/scotus/bert_tuned.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/scotus
3 | data_name: scotus
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 2e-05
15 | weight_decay: 0
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |   dropout: 0.1
29 |
30 | # pretrained vocab / embeddings
31 | embed_file: null
32 |
--------------------------------------------------------------------------------
/config/scotus/l2svm.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_linear/scotus
3 | data_name: scotus
4 |
5 | # train
6 | seed: 1337
7 | linear: true
8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q"
9 | linear_technique: 1vsrest
10 |
11 | # eval
12 | eval_batch_size: 256
13 | monitor_metrics: [Micro-F1, Macro-F1]
14 | metric_threshold: 0
15 |
16 | data_format: txt
17 |
--------------------------------------------------------------------------------
/config/unfair_tos/bert_default.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/unfair_tos
3 | data_name: unfair_tos
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 0.00005
15 | weight_decay: 0.001
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |
29 | # pretrained vocab / embeddings
30 | embed_file: null
31 |
--------------------------------------------------------------------------------
/config/unfair_tos/bert_reproduced.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/unfair_tos
3 | data_name: unfair_tos
4 | min_vocab_freq: 1
5 | max_seq_length: 128
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 0.00003
15 | weight_decay: 0
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |
29 | # pretrained vocab / embeddings
30 | embed_file: null
31 |
--------------------------------------------------------------------------------
/config/unfair_tos/bert_tune.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/unfair_tos
3 | data_name: unfair_tos
4 | min_vocab_freq: 1
5 | max_seq_length: ['grid_search', [128, 512]]
6 | include_test_labels: true
7 | remove_no_label_data: false
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: ['grid_search', [0.00002, 0.00003, 0.00005]]
15 | momentum: 0
16 | weight_decay: ['grid_search', [0, 0.001]]
17 | patience: 5
18 | shuffle: true
19 |
20 | # eval
21 | eval_batch_size: 8
22 | monitor_metrics: ['Micro-F1', 'Macro-F1']
23 | val_metric: Micro-F1
24 |
25 | # model
26 | model_name: BERT
27 | init_weight: null
28 | network_config:
29 |   dropout: ['grid_search', [0.1, 0.2]]
30 |   lm_weight: bert-base-uncased
31 |
32 | # pretrained vocab / embeddings
33 | vocab_file: null
34 | embed_file: null
35 | normalize_embed: false
36 |
37 | # hyperparameter search
38 | search_alg: basic_variant
39 | embed_cache_dir: null
40 | num_samples: 1
41 | scheduler: null
42 | # Uncomment the following lines to enable the ASHAScheduler.
43 | # See the documentation here: https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-ashascheduler
44 | #scheduler:
45 | #time_attr: training_iteration
46 | #max_t: 50 # the maximum epochs to run for each config (parameter R in the ASHA paper)
47 | #grace_period: 10 # the minimum epochs to run for each config (parameter r in the ASHA paper)
48 | #reduction_factor: 3 # reduce the number of configurations to floor(1/reduction_factor) each round of successive halving (called a rung in the ASHA paper)
49 | #brackets: 1 # number of brackets. A smaller bracket index (parameter s in the ASHA paper) means earlier stopping (i.e., less total resources used)
50 |
51 | # other parameters specified in main.py::get_args
52 | checkpoint_path: null
53 | cpu: false
54 | data_workers: 4
55 | eval: false
56 | label_file: null
57 | limit_train_batches: 1.0
58 | limit_val_batches: 1.0
59 | limit_test_batches: 1.0
60 | metric_threshold: 0.5
61 | result_dir: runs
62 | save_k_predictions: 0
63 | silent: true
64 | test_path: null
65 | train_path: null
66 | val_path: null
67 | val_size: 0.2
68 |
69 | # LexGLUE
70 | zero: true
71 | multi_class: false
72 | add_special_tokens: true
73 | enable_ce_loss: false
74 | hierarchical: false
75 | accumulate_grad_batches: 1
76 | enable_transformer_trainer: false
77 |
--------------------------------------------------------------------------------
/config/unfair_tos/bert_tuned.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_nn/unfair_tos
3 | data_name: unfair_tos
4 | min_vocab_freq: 1
5 | max_seq_length: 512
6 | include_test_labels: true
7 | add_special_tokens: true
8 |
9 | # train
10 | seed: 1337
11 | epochs: 15
12 | batch_size: 8
13 | optimizer: adamw
14 | learning_rate: 3e-05
15 | weight_decay: 0
16 | patience: 5
17 |
18 | # eval
19 | eval_batch_size: 8
20 | monitor_metrics: ['Micro-F1', 'Macro-F1']
21 | val_metric: Micro-F1
22 |
23 | # model
24 | model_name: BERT
25 | init_weight: null
26 | network_config:
27 |   lm_weight: bert-base-uncased
28 |   dropout: 0.1
29 |
30 | # pretrained vocab / embeddings
31 | embed_file: null
32 |
--------------------------------------------------------------------------------
/config/unfair_tos/l2svm.yml:
--------------------------------------------------------------------------------
1 | # data
2 | data_dir: data_linear/unfair_tos
3 | data_name: unfair_tos
4 |
5 | # train
6 | seed: 1337
7 | linear: true
8 | liblinear_options: "-s 2 -B 1 -e 0.0001 -q"
9 | linear_technique: 1vsrest
10 |
11 | # eval
12 | eval_batch_size: 256
13 | monitor_metrics: [Micro-F1, Macro-F1]
14 | metric_threshold: 0
15 |
16 | data_format: txt
17 |
--------------------------------------------------------------------------------
/generate_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | from datasets import load_dataset
5 |
6 |
7 | def get_args():
8 |     parser = argparse.ArgumentParser()
9 |     parser.add_argument('-dl', '--data_list', type=str, nargs='+',
10 |                         default=['ecthr_a', 'ecthr_b', 'scotus', 'eurlex', 'ledgar', 'unfair_tos'])
11 |     parser.add_argument('-sl', '--split_list', type=str, nargs='+', default=['train', 'validation', 'test'])
12 |     parser.add_argument('-ddp', '--data_dir_prefix', type=str, default='data')
13 |     parser.add_argument('-f', '--format', type=str, choices=['linear', 'nn', 'hier'], required=True)
14 |     args = parser.parse_args()
15 |     return args
16 |
17 |
18 | def data2task(data):
19 |     return 'multi_label' if data not in ['scotus', 'ledgar'] else 'multi_class'
20 |
21 |
22 | def data2hier(data):
23 |     return data in ['ecthr_a', 'ecthr_b', 'scotus']
24 |
25 |
26 | def split2name(split):
27 |     _split2name = {'train': 'train.txt', 'validation': 'valid.txt', 'test': 'test.txt'}
28 |     return _split2name[split]
29 |
30 |
31 | def get_texts(data, dataset, hier=False):
32 |     if 'ecthr' in data:
33 |         if not hier:
34 |             texts = [' '.join(text) for text in dataset['text']]
35 |         else:
36 |             texts = [' [HIER] '.join(text) for text in dataset['text']]
37 |         return [' '.join(text.split()) for text in texts]
38 |     elif data == 'scotus' and hier:
39 |         texts = [' [HIER] '.join(re.split('\n{2,}', text)) for text in dataset['text']]
40 |         # Huggingface tokenizer ignores newline and tab,
41 |         # so it's okay to replace them with a space here.
42 |         for i in range(len(texts)):
43 |             texts[i] = texts[i].replace('\n', ' ')
44 |             texts[i] = texts[i].replace('\r', ' ')
45 |             texts[i] = texts[i].replace('\t', ' ')
46 |         return texts
47 |     elif data == 'case_hold':
48 |         return [contexts[0] + ' [SEP] '.join(holdings)
49 |                 for contexts, holdings in zip(dataset['contexts'], dataset['endings'])]
50 |     else:
51 |         return [' '.join(text.split()) for text in dataset['text']]
52 |
53 |
54 | def get_labels(data, dataset, task):
55 |     if task == 'multi_class':
56 |         return list(map(str, dataset['label']))
57 |     else:
58 |         if data == 'eurlex':
59 |             return [' '.join(map(str, [l for l in label if l < 100])) for label in dataset['labels']]
60 |         else:
61 |             return [' '.join(map(str, label)) for label in dataset['labels']]
62 |
63 |
64 | def save_data(data_path, data):
65 |     with open(data_path, 'w') as f:
66 |         for text, label in zip(data['text'], data['labels']):
67 |             assert '\n' not in label+text
68 |             assert '\r' not in label+text
69 |             assert '\t' not in label+text
70 |             formatted_instance = '\t'.join([label, text])
71 |             f.write(f'{formatted_instance}\n')
72 |
73 |
74 | def main():
75 |     # args
76 |     args = get_args()
77 |     data_dir = f'{args.data_dir_prefix}_{args.format}'
78 |     os.makedirs(data_dir, exist_ok=True)
79 |
80 |     # generate
81 |     for data in args.data_list:
82 |         if args.format == 'hier' and not data2hier(data):
83 |             continue
84 |         data_path = os.path.join(data_dir, data)
85 |         os.makedirs(data_path, exist_ok=True)
86 |         processed_data = {}
87 |         for split in args.split_list:
88 |             dataset = load_dataset('coastalcph/lex_glue', data, split=split, trust_remote_code=True)
89 |             texts = get_texts(data, dataset, hier=args.format == 'hier')
90 |             labels = get_labels(data, dataset, data2task(data))
91 |             assert len(texts) == len(labels)
92 |             print(f'{data} ({split}): num_instance = {len(texts)}')
93 |             processed_data[split] = {'text': texts, 'labels': labels}
94 |         # format
95 |         if args.format == 'linear':
96 |             # train
97 |             train_path = os.path.join(data_path, split2name('train'))
98 |             train_data = {
99 |                 'text': processed_data['train']['text'] + processed_data['validation']['text'],
100 |                 'labels': processed_data['train']['labels'] + processed_data['validation']['labels']
101 |             }
102 |             save_data(train_path, train_data)
103 |             # test
104 |             test_path = os.path.join(data_path, split2name('test'))
105 |             test_data = processed_data['test']
106 |             save_data(test_path, test_data)
107 |         elif args.format == 'nn' or args.format == 'hier':
108 |             # train/validation/test
109 |             for split in processed_data:
110 |                 split_path = os.path.join(data_path, split2name(split))
111 |                 save_data(split_path, processed_data[split])
112 |
113 |
114 | if __name__ == '__main__':
115 |     main()
116 |
--------------------------------------------------------------------------------
/generate_data.sh:
--------------------------------------------------------------------------------
1 | format_list=(linear nn hier)
2 |
3 | format=$1
4 |
5 | if [ $# -eq 0 ]; then
6 |     for format in "${format_list[@]}"; do
7 |         python3 generate_data.py -f ${format}
8 |     done
9 | else
10 |     if [[ ! " ${format_list[*]} " =~ " ${format} " ]]; then
11 |         echo "Invalid argument! Format ${format} is not in (${format_list[*]})."
12 |         exit 1
13 |     else
14 |         python3 generate_data.py -f ${format}
15 |     fi
16 | fi
17 |
--------------------------------------------------------------------------------
/libmultilabel/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JamesLYC88/text_classification_baseline_code/6436f567372f3c36547cd52da79516900e5a6148/libmultilabel/__init__.py
--------------------------------------------------------------------------------
/libmultilabel/common_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import logging
4 | import os
5 | import time
6 | from typing import Any
7 |
8 | import numpy as np
9 |
10 |
11 | class Timer(object):
12 |     """Computes elapsed time."""
13 |
14 |     def __init__(self):
15 |         self.reset()
16 |
17 |     def reset(self):
18 |         self.running = True
19 |         self.total = 0
20 |         self.start = time.time()
21 |         return self
22 |
23 |     def resume(self):
24 |         if not self.running:
25 |             self.running = True
26 |             self.start = time.time()
27 |         return self
28 |
29 |     def stop(self):
30 |         if self.running:
31 |             self.running = False
32 |             self.total += time.time() - self.start
33 |         return self
34 |
35 |     def time(self):
36 |         if self.running:
37 |             return self.total + time.time() - self.start
38 |         return self.total
39 |
40 |
41 | def dump_log(log_path, metrics=None, split=None, config=None):
42 |     """Write log including config and the evaluation scores.
43 |
44 |     Args:
45 |         log_path (str): path to the log file
46 |         metrics (dict): metric and scores in dictionary format, defaults to None
47 |         split (str): val or test, defaults to None
48 |         config (dict): config to save, defaults to None
49 |     """
50 |     os.makedirs(os.path.dirname(log_path), exist_ok=True)
51 |     if os.path.isfile(log_path):
52 |         with open(log_path) as fp:
53 |             result = json.load(fp)
54 |     else:
55 |         result = dict()
56 |
57 |     if config:
58 |         config_to_save = copy.deepcopy(dict(config))
59 |         config_to_save.pop('device', None)  # delete if device exists
60 |         result['config'] = config_to_save
61 |     if split and metrics:
62 |         if split in result:
63 |             result[split].append(metrics)
64 |         else:
65 |             result[split] = [metrics]
66 |
67 |     with open(log_path, 'w') as fp:
68 |         json.dump(result, fp)
69 |
70 |     logging.info(f'Finish writing log to {log_path}.')
71 |
72 |
73 | def argsort_top_k(vals, k, axis=-1):
74 |     # Partially sort so the k largest entries per row come last, then sort
75 |     # those k entries in descending order of score.
76 |     unsorted_top_k_idx = np.argpartition(vals, -k, axis=axis)[:, -k:]
77 |     unsorted_top_k_scores = np.take_along_axis(
78 |         vals, unsorted_top_k_idx, axis=axis)
79 |     sorted_order = np.argsort(-unsorted_top_k_scores, axis=axis)
80 |     sorted_top_k_idx = np.take_along_axis(
81 |         unsorted_top_k_idx, sorted_order, axis=axis)
82 |     return sorted_top_k_idx
83 |
84 |
85 | class AttributeDict(dict):
86 |     """AttributeDict is an extended dict that can access
87 |     stored items as attributes.
88 |
89 |     >>> ad = AttributeDict({'ans': 42})
90 |     >>> ad.ans
91 |     42
92 |     """
93 |
94 |     def __getattr__(self, key: str) -> Any:
95 |         try:
96 |             return self[key]
97 |         except KeyError:
98 |             raise AttributeError(f'Missing attribute "{key}"')
99 |
100 |     def __setattr__(self, key: str, value: Any) -> None:
101 |         self[key] = value
--------------------------------------------------------------------------------
/libmultilabel/linear/__init__.py:
--------------------------------------------------------------------------------
1 | from .linear import *
2 | from .metrics import get_metrics, tabulate_metrics
3 | from .preprocessor import *
4 | from .utils import *
5 |
--------------------------------------------------------------------------------
/libmultilabel/linear/linear.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import numpy as np
5 | import scipy.sparse as sparse
6 | from liblinear.liblinearutil import train
7 | from tqdm import tqdm
8 |
9 | __all__ = ['train_1vsrest',
10 |            'train_thresholding',
11 |            'train_cost_sensitive',
12 |            'train_cost_sensitive_micro',
13 |            'predict_values']
14 |
15 |
16 | def train_1vsrest(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str):
17 |     """Trains a linear model for multilabel data using a one-vs-rest strategy.
18 |
19 |     Args:
20 |         y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
21 |         x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
22 |         options (str): The option string passed to liblinear.
23 |
24 |     Returns:
25 |         A model which can be used in predict_values.
26 |     """
27 |     # Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/
28 |     x, options, bias = prepare_options(x, options)
29 |
30 |     y = y.tocsc()
31 |     num_class = y.shape[1]
32 |     num_feature = x.shape[1]
33 |     weights = np.zeros((num_feature, num_class), order='F')
34 |
35 |     logging.info(f'Training one-vs-rest model on {num_class} labels')
36 |     for i in tqdm(range(num_class)):
37 |         yi = y[:, i].toarray().reshape(-1)
38 |         modeli = train(2*yi - 1, x, options)
39 |         w = np.ctypeslib.as_array(modeli.w, (num_feature,))
40 |         # Liblinear flips +1/-1 labels so +1 is always the first label,
41 |         # but not if all labels are -1.
42 |         # For our usage, we need +1 to always be the first label,
43 |         # so the check is necessary.
44 |         if modeli.get_labels()[0] == -1:
45 |             weights[:, i] = -w
46 |         else:
47 |             weights[:, i] = w
48 |
49 |     return {'weights': np.asmatrix(weights), '-B': bias, 'threshold': 0}
50 |
51 |
52 | def prepare_options(x: sparse.csr_matrix, options: str) -> 'tuple[sparse.csr_matrix, str, float]':
53 | """Prepare options and x for multi-label training. Called in the first line of
54 | any training function.
55 |
56 | Args:
57 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
58 | options (str): The option string passed to liblinear.
59 |
60 | Returns:
61 | tuple[sparse.csr_matrix, str, float]: Transformed x, transformed options and
62 | bias parsed from options.
63 | """
64 | if options is None:
65 | options = ''
66 | if any(o in options for o in ['-R', '-C', '-v']):
67 | raise ValueError('-R, -C and -v are not supported')
68 |
69 | bias = -1.
70 | if options.find('-B') != -1:
71 | options_split = options.split()
72 | i = options_split.index('-B')
73 | bias = float(options_split[i+1])
74 | options = ' '.join(options_split[:i] + options_split[i+2:])
75 | x = sparse.hstack([
76 | x,
77 | np.full((x.shape[0], 1), bias),
78 | ], 'csr')
79 |
80 |     if '-q' not in options:
81 |         options += ' -q'
82 |
83 |     if '-m' not in options:
84 |         options += f' -m {int(os.cpu_count() / 2)}'
85 |
86 | return x, options, bias
87 |
88 |
89 | def train_thresholding(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str):
90 | """Trains a linear model for multilabel data using a one-vs-rest strategy
91 | and cross-validation to pick an optimal decision threshold for Macro-F1.
92 | Outperforms train_1vsrest in most aspects at the cost of higher
93 | time complexity.
94 | See user guide for more details.
95 |
96 | Args:
97 | y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
98 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
99 | options (str): The option string passed to liblinear.
100 |
101 | Returns:
102 | A model which can be used in predict_values.
103 | """
104 | # Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/
105 | x, options, bias = prepare_options(x, options)
106 |
107 | y = y.tocsc()
108 | num_class = y.shape[1]
109 | num_feature = x.shape[1]
110 | weights = np.zeros((num_feature, num_class), order='F')
111 | thresholds = np.zeros(num_class)
112 |
113 | logging.info(f'Training thresholding model on {num_class} labels')
114 | for i in tqdm(range(num_class)):
115 | yi = y[:, i].toarray().reshape(-1)
116 | w, t = thresholding_one_label(2*yi - 1, x, options)
117 | weights[:, i] = w.ravel()
118 | thresholds[i] = t
119 |
120 | return {'weights': np.asmatrix(weights), '-B': bias, 'threshold': thresholds}
121 |
122 |
123 | def thresholding_one_label(y: np.ndarray,
124 | x: sparse.csr_matrix,
125 | options: str
126 | ) -> 'tuple[np.ndarray, float]':
127 | """Outer cross-validation for thresholding on a single label.
128 |
129 | Args:
130 | y (np.ndarray): A +1/-1 array with dimensions number of instances * 1.
131 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
132 | options (str): The option string passed to liblinear.
133 |
134 | Returns:
135 | tuple[np.ndarray, float]: tuple of the weights and threshold.
136 | """
137 | fbr_list = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
138 |
139 | nr_fold = 3
140 |
141 | l = y.shape[0]
142 |
143 | perm = np.random.permutation(l)
144 |
145 | f_list = np.zeros_like(fbr_list)
146 |
147 | for fold in range(nr_fold):
148 | mask = np.zeros_like(perm, dtype='?')
149 | mask[np.arange(int(fold*l/nr_fold), int((fold+1)*l/nr_fold))] = 1
150 | val_idx = perm[mask]
151 |         train_idx = perm[~mask]
152 |
153 | scutfbr_w, scutfbr_b_list = scutfbr(
154 | y[train_idx], x[train_idx], fbr_list, options)
155 | wTx = (x[val_idx] * scutfbr_w).A1
156 |
157 | for i in range(fbr_list.size):
158 | F = fmeasure(y[val_idx], 2*(wTx > -scutfbr_b_list[i]) - 1)
159 | f_list[i] += F
160 |
161 | best_fbr = fbr_list[::-1][np.argmax(f_list[::-1])] # last largest
162 | if np.max(f_list) == 0:
163 | best_fbr = np.min(fbr_list)
164 |
165 | # final model
166 | w, b_list = scutfbr(y, x, np.array([best_fbr]), options)
167 |
168 | return w, b_list[0]
169 |
170 |
171 | def scutfbr(y: np.ndarray,
172 | x: sparse.csr_matrix,
173 | fbr_list: 'list[float]',
174 | options: str
175 | ) -> 'tuple[np.matrix, np.ndarray]':
176 | """Inner cross-validation for SCutfbr heuristic.
177 |
178 | Args:
179 | y (np.ndarray): A +1/-1 array with dimensions number of instances * 1.
180 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
181 | fbr_list (list[float]): list of fbr values.
182 | options (str): The option string passed to liblinear.
183 |
184 | Returns:
185 | tuple[np.matrix, np.ndarray]: tuple of weights and threshold candidates.
186 | """
187 |
188 | b_list = np.zeros_like(fbr_list)
189 |
190 | nr_fold = 3
191 |
192 | l = y.shape[0]
193 |
194 | perm = np.random.permutation(l)
195 |
196 | for fold in range(nr_fold):
197 | mask = np.zeros_like(perm, dtype='?')
198 | mask[np.arange(int(fold*l/nr_fold), int((fold+1)*l/nr_fold))] = 1
199 | val_idx = perm[mask]
200 |         train_idx = perm[~mask]
201 |
202 | w = do_train(y[train_idx], x[train_idx], options)
203 |
204 | wTx = (x[val_idx] * w).A1
205 | scut_b = 0.
206 | start_F = fmeasure(y[val_idx], 2*(wTx > -scut_b) - 1)
207 |
208 | # stableness to match the MATLAB implementation
209 | sorted_wTx_index = np.argsort(wTx, kind='stable')
210 | sorted_wTx = wTx[sorted_wTx_index]
211 |
212 | tp = np.sum(y[val_idx] == 1)
213 | fp = val_idx.size - tp
214 | fn = 0
215 | cut = -1
216 | best_F = 2*tp / (2*tp + fp + fn)
217 | y_val = y[val_idx]
218 |
219 | # following MATLAB implementation to suppress NaNs
220 | prev_settings = np.seterr('ignore')
221 | for i in range(val_idx.size):
222 | if y_val[sorted_wTx_index[i]] == -1:
223 | fp -= 1
224 | else:
225 | tp -= 1
226 | fn += 1
227 |
228 | # There will be NaNs, but the behaviour is correct
229 | F = 2*tp / (2*tp + fp + fn)
230 |
231 | if F >= best_F:
232 | best_F = F
233 | cut = i
234 | np.seterr(**prev_settings)
235 |
236 | if best_F > start_F:
237 | if cut == -1: # i.e. all 1 in scut
238 | scut_b = np.nextafter(-sorted_wTx[0], np.inf) # predict all 1
239 | elif cut == val_idx.size - 1:
240 | scut_b = np.nextafter(-sorted_wTx[-1], np.inf)
241 | else:
242 | scut_b = -(sorted_wTx[cut] + sorted_wTx[cut + 1]) / 2
243 |
244 | F = fmeasure(y_val, 2*(wTx > -scut_b) - 1)
245 |
246 | for i in range(fbr_list.size):
247 | if F > fbr_list[i]:
248 | b_list[i] += scut_b
249 | else:
250 | b_list[i] -= np.max(wTx)
251 |
252 | b_list = b_list / nr_fold
253 | return do_train(y, x, options), b_list
254 |
255 |
256 | def do_train(y: np.ndarray, x: sparse.csr_matrix, options: str) -> np.matrix:
257 | """Wrapper around liblinear.liblinearutil.train.
258 | Forcibly suppresses all IO regardless of options.
259 |
260 | Args:
261 | y (np.ndarray): A +1/-1 array with dimensions number of instances * 1.
262 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
263 | options (str): The option string passed to liblinear.
264 |
265 | Returns:
266 | np.matrix: the weights.
267 | """
268 | with silent_stderr():
269 | model = train(y, x, options)
270 |
271 | w = np.ctypeslib.as_array(model.w, (x.shape[1], 1))
272 | w = np.asmatrix(w)
273 | # Liblinear flips +1/-1 labels so +1 is always the first label,
274 | # but not if all labels are -1.
275 | # For our usage, we need +1 to always be the first label,
276 | # so the check is necessary.
277 | if model.get_labels()[0] == -1:
278 | return -w
279 | else:
280 | # The memory is freed on model deletion so we make a copy.
281 | return w.copy()
282 |
283 |
284 | class silent_stderr:
285 | """Context manager that suppresses stderr.
286 | Liblinear emits warnings on missing classes with
287 | specified weight, which may happen during cross-validation.
288 | Since this information is useless to the user, we suppress it.
289 | """
290 |
291 | def __init__(self):
292 | self.stderr = os.dup(2)
293 | self.devnull = os.open(os.devnull, os.O_WRONLY)
294 |
295 | def __enter__(self):
296 | os.dup2(self.devnull, 2)
297 |
298 | def __exit__(self, type, value, traceback):
299 | os.dup2(self.stderr, 2)
300 | os.close(self.devnull)
301 | os.close(self.stderr)
302 |
303 |
304 | def fmeasure(y_true: np.ndarray, y_pred: np.ndarray) -> float:
305 | """Calculate F1 score.
306 |
307 | Args:
308 | y_true (np.ndarray): array of +1/-1.
309 | y_pred (np.ndarray): array of +1/-1.
310 |
311 | Returns:
312 | float: the F1 score.
313 | """
314 | tp = np.sum(np.logical_and(y_true == 1, y_pred == 1))
315 | fn = np.sum(np.logical_and(y_true == 1, y_pred == -1))
316 | fp = np.sum(np.logical_and(y_true == -1, y_pred == 1))
317 |
318 | F = 0
319 | if tp != 0 or fp != 0 or fn != 0:
320 | F = 2*tp / (2*tp + fp + fn)
321 | return F
322 |
323 |
324 | def train_cost_sensitive(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str):
325 | """Trains a linear model for multilabel data using a one-vs-rest strategy
326 | and cross-validation to pick an optimal asymmetric misclassification cost
327 | for Macro-F1.
328 | Outperforms train_1vsrest in most aspects at the cost of higher
329 | time complexity.
330 | See user guide for more details.
331 |
332 | Args:
333 | y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
334 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
335 | options (str): The option string passed to liblinear.
336 |
337 | Returns:
338 | A model which can be used in predict_values.
339 | """
340 | # Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/
341 | x, options, bias = prepare_options(x, options)
342 |
343 | y = y.tocsc()
344 | num_class = y.shape[1]
345 | num_feature = x.shape[1]
346 | weights = np.zeros((num_feature, num_class), order='F')
347 |
348 | logging.info(
349 | f'Training cost-sensitive model for Macro-F1 on {num_class} labels')
350 | for i in tqdm(range(num_class)):
351 | yi = y[:, i].toarray().reshape(-1)
352 | w = cost_sensitive_one_label(2*yi - 1, x, options)
353 | weights[:, i] = w.ravel()
354 |
355 | return {'weights': np.asmatrix(weights), '-B': bias, 'threshold': 0}
356 |
357 |
358 | def cost_sensitive_one_label(y: np.ndarray,
359 | x: sparse.csr_matrix,
360 | options: str
361 | ) -> np.ndarray:
362 | """Loop over parameter space for cost-sensitive on a single label.
363 |
364 | Args:
365 | y (np.ndarray): A +1/-1 array with dimensions number of instances * 1.
366 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
367 | options (str): The option string passed to liblinear.
368 |
369 | Returns:
370 | np.ndarray: the weights.
371 | """
372 |
373 | l = y.shape[0]
374 | perm = np.random.permutation(l)
375 |
376 | param_space = [1, 1.33, 1.8, 2.5, 3.67, 6, 13]
377 |
378 | bestScore = -np.Inf
379 | for a in param_space:
380 | cv_options = f'{options} -w1 {a}'
381 | pred = cross_validate(y, x, cv_options, perm)
382 | score = fmeasure(y, pred)
383 | if bestScore < score:
384 | bestScore = score
385 | bestA = a
386 |
387 | final_options = f'{options} -w1 {bestA}'
388 | return do_train(y, x, final_options)
389 |
390 |
391 | def cross_validate(y: np.ndarray,
392 | x: sparse.csr_matrix,
393 | options: str,
394 | perm: np.ndarray
395 | ) -> np.ndarray:
396 | """Cross-validation for cost-sensitive.
397 |
398 | Args:
399 | y (np.ndarray): A +1/-1 array with dimensions number of instances * 1.
400 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
401 | options (str): The option string passed to liblinear.
402 |
403 | Returns:
404 | np.ndarray: Cross-validation result as a +1/-1 array.
405 | """
406 | l = y.shape[0]
407 | nr_fold = 3
408 | pred = np.zeros_like(y)
409 | for fold in range(nr_fold):
410 | mask = np.zeros_like(perm, dtype='?')
411 | mask[np.arange(int(fold*l/nr_fold), int((fold+1)*l/nr_fold))] = 1
412 | val_idx = perm[mask]
413 |         train_idx = perm[~mask]
414 |
415 | w = do_train(y[train_idx], x[train_idx], options)
416 | pred[val_idx] = (x[val_idx] * w).A1 > 0
417 |
418 | return 2*pred - 1
419 |
420 |
421 | def train_cost_sensitive_micro(y: sparse.csr_matrix, x: sparse.csr_matrix, options: str):
422 | """Trains a linear model for multilabel data using a one-vs-rest strategy
423 | and cross-validation to pick an optimal asymmetric misclassification cost
424 | for Micro-F1.
425 | Outperforms train_1vsrest in most aspects at the cost of higher
426 | time complexity.
427 | See user guide for more details.
428 |
429 | Args:
430 | y (sparse.csr_matrix): A 0/1 matrix with dimensions number of instances * number of classes.
431 | x (sparse.csr_matrix): A matrix with dimensions number of instances * number of features.
432 | options (str): The option string passed to liblinear.
433 |
434 | Returns:
435 | A model which can be used in predict_values.
436 | """
437 | # Follows the MATLAB implementation at https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/multilabel/
438 | x, options, bias = prepare_options(x, options)
439 |
440 | y = y.tocsc()
441 | num_class = y.shape[1]
442 | num_feature = x.shape[1]
443 | weights = np.zeros((num_feature, num_class), order='F')
444 |
445 | l = y.shape[0]
446 | perm = np.random.permutation(l)
447 | param_space = [1, 1.33, 1.8, 2.5, 3.67, 6, 13]
448 | bestScore = -np.Inf
449 |
450 | logging.info(
451 | f'Training cost-sensitive model for Micro-F1 on {num_class} labels')
452 | for a in param_space:
453 | tp = fn = fp = 0
454 | for i in tqdm(range(num_class)):
455 | yi = y[:, i].toarray().reshape(-1)
456 | yi = 2*yi - 1
457 |
458 | cv_options = f'{options} -w1 {a}'
459 | pred = cross_validate(yi, x, cv_options, perm)
460 | tp = tp + np.sum(np.logical_and(yi == 1, pred == 1))
461 | fn = fn + np.sum(np.logical_and(yi == 1, pred == -1))
462 | fp = fp + np.sum(np.logical_and(yi == -1, pred == 1))
463 |
464 | score = 2*tp / (2*tp + fn + fp)
465 | if bestScore < score:
466 | bestScore = score
467 | bestA = a
468 |
469 | final_options = f'{options} -w1 {bestA}'
470 | for i in range(num_class):
471 | yi = y[:, i].toarray().reshape(-1)
472 | w = do_train(2*yi - 1, x, final_options)
473 | weights[:, i] = w.ravel()
474 |
475 | return {'weights': np.asmatrix(weights), '-B': bias, 'threshold': 0}
476 |
477 |
478 | def predict_values(model, x: sparse.csr_matrix) -> np.ndarray:
479 | """Calculates the decision values associated with x.
480 |
481 | Args:
482 | model: A model returned from a training function.
483 | x (sparse.csr_matrix): A matrix with dimension number of instances * number of features.
484 |
485 | Returns:
486 | np.ndarray: A matrix with dimension number of instances * number of classes.
487 | """
488 | bias = model['-B']
489 | bias_col = np.full((x.shape[0], 1 if bias > 0 else 0), bias)
490 | num_feature = model['weights'].shape[0]
491 | num_feature -= 1 if bias > 0 else 0
492 | if x.shape[1] < num_feature:
493 | x = sparse.hstack([
494 | x,
495 | np.zeros((x.shape[0], num_feature - x.shape[1])),
496 | bias_col,
497 | ], 'csr')
498 | else:
499 | x = sparse.hstack([
500 | x[:, :num_feature],
501 | bias_col,
502 | ], 'csr')
503 |
504 | return (x * model['weights']).A + model['threshold']
505 |
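Usage note: predict_values already folds the model's threshold into the returned decision values (zero for train_1vsrest and the cost-sensitive variants, per-label values for train_thresholding), so a sketch of binarizing predictions only needs a comparison against zero (model and x_test are assumed to come from the functions above):

decision_values = predict_values(model, x_test)   # (num_instances, num_classes)
predictions = (decision_values > 0).astype(int)   # 0/1 prediction matrix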
--------------------------------------------------------------------------------
/libmultilabel/linear/metrics.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import numpy as np
4 |
5 | __all__ = ['RPrecision',
6 | 'Precision',
7 | 'F1',
8 | 'MetricCollection',
9 | 'get_metrics',
10 | 'tabulate_metrics']
11 |
12 |
13 | class RPrecision:
14 | def __init__(self, top_k: int) -> None:
15 | self.top_k = top_k
16 | self.score = 0
17 | self.num_sample = 0
18 |
19 | def update(self, preds: np.ndarray, target: np.ndarray) -> None:
20 | assert preds.shape == target.shape # (batch_size, num_classes)
21 | top_k_ind = np.argpartition(preds, -self.top_k)[:, -self.top_k:]
22 | num_relevant = np.take_along_axis(
23 |             target, top_k_ind, axis=-1).sum(axis=-1)  # (batch_size,)
24 | self.score += np.nan_to_num(
25 | num_relevant / np.minimum(self.top_k, target.sum(axis=-1)),
26 | posinf=0.
27 | ).sum()
28 | self.num_sample += preds.shape[0]
29 |
30 | def compute(self) -> float:
31 | return self.score / self.num_sample
32 |
33 |
34 | class Precision:
35 | def __init__(self, num_classes: int, average: str, top_k: int) -> None:
36 | self.top_k = top_k
37 | self.score = 0
38 | self.num_sample = 0
39 |
40 | def update(self, preds: np.ndarray, target: np.ndarray) -> None:
41 | assert preds.shape == target.shape # (batch_size, num_classes)
42 | top_k_ind = np.argpartition(preds, -self.top_k)[:, -self.top_k:]
43 | num_relevant = np.take_along_axis(target, top_k_ind, -1).sum()
44 | self.score += num_relevant / self.top_k
45 | self.num_sample += preds.shape[0]
46 |
47 | def compute(self) -> float:
48 | return self.score / self.num_sample
49 |
50 |
51 | def add_zero_class(labels):
52 | augmented_labels = np.zeros((len(labels), len(labels[0]) + 1), dtype=np.int32)
53 | augmented_labels[:, :-1] = labels
54 | augmented_labels[:, -1] = (np.sum(labels, axis=1) == 0).astype('int32')
55 | return augmented_labels
56 |
57 |
58 | class F1:
59 | def __init__(self, num_classes: int, metric_threshold: float, average: str,
60 | zero: bool, multi_class: bool) -> None:
61 | self.num_classes = num_classes
62 | self.metric_threshold = metric_threshold
63 | if average not in {'macro', 'micro', 'another-macro'}:
64 | raise ValueError('unsupported average')
65 | self.average = average
66 | self.tp = self.fp = self.fn = 0
67 | self.zero = zero
68 | if self.zero:
69 | self.num_classes += 1
70 | self.multi_class = multi_class
71 |
72 | def update(self, preds: np.ndarray, target: np.ndarray) -> None:
73 | if self.multi_class:
74 | preds = np.eye(preds.shape[1])[preds.argmax(1)]
75 | else:
76 | preds = preds > self.metric_threshold
77 | if self.zero:
78 | preds = add_zero_class(preds)
79 | target = add_zero_class(target)
80 | assert preds.shape == target.shape # (batch_size, num_classes)
81 | self.tp += np.logical_and(target == 1, preds == 1).sum(axis=0)
82 | self.fn += np.logical_and(target == 1, preds == 0).sum(axis=0)
83 | self.fp += np.logical_and(target == 0, preds == 1).sum(axis=0)
84 |
85 | def compute(self) -> float:
86 | prev_settings = np.seterr('ignore')
87 |
88 | if self.average == 'macro':
89 | score = np.nansum(
90 | 2*self.tp / (2*self.tp + self.fp + self.fn)) / self.num_classes
91 | elif self.average == 'micro':
92 | score = np.nan_to_num(2*np.sum(self.tp) /
93 | np.sum(2*self.tp + self.fp + self.fn))
94 | elif self.average == 'another-macro':
95 | macro_prec = np.nansum(
96 | self.tp / (self.tp + self.fp)) / self.num_classes
97 | macro_recall = np.nansum(
98 | self.tp / (self.tp + self.fn)) / self.num_classes
99 | score = np.nan_to_num(
100 | 2 * macro_prec * macro_recall / (macro_prec + macro_recall))
101 |
102 | np.seterr(**prev_settings)
103 | return score
104 |
105 |
106 | class MetricCollection(dict):
107 | def __init__(self, metrics) -> None:
108 | self.metrics = metrics
109 |
110 | def update(self, preds: np.ndarray, target: np.ndarray) -> None:
111 | assert preds.shape == target.shape # (batch_size, num_classes)
112 | for metric in self.metrics.values():
113 | metric.update(preds, target)
114 |
115 | def compute(self) -> "dict[str, float]":
116 | ret = {}
117 | for name, metric in self.metrics.items():
118 | ret[name] = metric.compute()
119 | return ret
120 |
121 |
122 | def get_metrics(metric_threshold: float, monitor_metrics: list, num_classes: int,
123 | zero: bool, multi_class: bool):
124 | """Get a collection of metrics by their names.
125 | Args:
126 | metric_threshold (float): The decision value threshold over which a label
127 | is predicted as positive.
128 | monitor_metrics (list): A list of strings naming the metrics.
129 | num_classes (int): The number of classes.
130 |         zero (bool): Whether to add a virtual class for instances without labels.
131 |         multi_class (bool): Whether to make a single-label (argmax) prediction per instance.
132 | Returns:
133 | MetricCollection: A metric collection of the list of metrics.
134 | """
135 | if monitor_metrics is None:
136 | monitor_metrics = []
137 | metrics = {}
138 | for metric in monitor_metrics:
139 |         if re.match(r'P@\d+', metric):
140 | metrics[metric] = Precision(
141 | num_classes, average='samples', top_k=int(metric[2:]))
142 |         elif re.match(r'RP@\d+', metric):
143 | metrics[metric] = RPrecision(top_k=int(metric[3:]))
144 | elif metric in {'Another-Macro-F1', 'Macro-F1', 'Micro-F1'}:
145 | metrics[metric] = F1(
146 | num_classes, metric_threshold, average=metric[:-3].lower(),
147 | zero=zero, multi_class=multi_class)
148 | else:
149 | raise ValueError(f'Invalid metric: {metric}')
150 |
151 | return MetricCollection(metrics)
152 |
153 |
154 | def tabulate_metrics(metric_dict, split):
155 | msg = f'====== {split} dataset evaluation result =======\n'
156 | header = '|'.join([f'{k:^18}' for k in metric_dict.keys()])
157 | values = '|'.join([f'{x * 100:^18.4f}' if isinstance(x, (np.floating,
158 | float)) else f'{x:^18}' for x in metric_dict.values()])
159 | msg += f"|{header}|\n|{'-----------------:|' * len(metric_dict)}\n|{values}|\n"
160 | return msg
161 |
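A hedged evaluation sketch with illustrative arrays (update expects dense arrays, so a sparse target would need .toarray() first):

import numpy as np

preds = np.array([[0.8, -0.3], [0.1, 0.5]])    # decision values
target = np.array([[1, 0], [0, 1]])            # 0/1 ground truth

metric_collection = get_metrics(metric_threshold=0.0,
                                monitor_metrics=['Macro-F1', 'Micro-F1', 'P@1'],
                                num_classes=2, zero=False, multi_class=False)
metric_collection.update(preds, target)
print(tabulate_metrics(metric_collection.compute(), 'test'))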
--------------------------------------------------------------------------------
/libmultilabel/linear/preprocessor.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import logging
4 | import os
5 | import re
6 | from array import array
7 | from collections import defaultdict
8 |
9 | import numpy as np
10 | import pandas as pd
11 | import scipy
12 | import scipy.sparse as sparse
13 | from sklearn.feature_extraction.text import TfidfVectorizer
14 | from sklearn.preprocessing import MultiLabelBinarizer
15 |
16 | __all__ = ['Preprocessor']
17 |
18 |
19 | class Preprocessor:
20 | """Preprocessor is used to load and preprocess input data in LibSVM and LibMultiLabel formats.
21 | The same Preprocessor has to be used for both training and testing data;
22 | see save_pipeline and load_pipeline.
23 | """
24 |
25 | def __init__(self, data_format: str) -> None:
26 | """Initializes the preprocessor.
27 |
28 | Args:
29 | data_format (str): The data format used. 'svm' for LibSVM format and 'txt' for LibMultiLabel format.
30 | """
31 | if not data_format in {'txt', 'svm'}:
32 | raise ValueError(f'unsupported data format {data_format}')
33 |
34 | self.data_format = data_format
35 |
36 | def load_data(self, train_path: str = '',
37 | test_path: str = '',
38 | eval: bool = False,
39 | label_file: str = None,
40 | include_test_labels: bool = False,
41 | remove_no_label_data: bool = False) -> 'dict[str, dict]':
42 | """Loads and preprocesses data.
43 |
44 | Args:
45 | train_path (str): Training data path. Ignored if eval is True. Defaults to ''.
46 | test_path (str): Test data path. Ignored if test_path doesn't exist. Defaults to ''.
47 | eval (bool): If True, ignores training data and uses previously loaded state to preprocess test data.
48 | label_file (str, optional): Path to a file holding all labels.
49 | include_test_labels (bool, optional): Whether to include labels in the test dataset. Defaults to False.
50 | remove_no_label_data (bool, optional): Whether to remove training instances that have no labels.
51 |
52 | Returns:
53 | dict[str, dict]: The training and test data, with keys 'train' and 'test' respectively. The data
54 | has keys 'x' for input features and 'y' for labels.
55 | """
56 | if label_file is not None:
57 | logging.info(f'Load labels from {label_file}.')
58 | with open(label_file, 'r') as fp:
59 | self.classes = sorted([s.strip() for s in fp.readlines()])
60 | else:
61 | if not os.path.exists(test_path) and include_test_labels:
62 | raise ValueError(
63 |                     'Specified the inclusion of test labels but the test file does not exist')
64 | self.classes = None
65 | self.include_test_labels = include_test_labels
66 |
67 | if self.data_format == 'txt':
68 | data = self._load_txt(train_path, test_path, eval)
69 | elif self.data_format == 'svm':
70 | data = self._load_svm(train_path, test_path, eval)
71 |
72 | if 'train' in data:
73 | num_labels = data['train']['y'].getnnz(axis=1)
74 | num_no_label_data = np.count_nonzero(num_labels == 0)
75 | if num_no_label_data > 0:
76 | if remove_no_label_data:
77 | logging.info(f'Remove {num_no_label_data} instances that have no labels from {train_path}.')
78 | data['train']['x'] = data['train']['x'][num_labels > 0]
79 | data['train']['y'] = data['train']['y'][num_labels > 0]
80 | else:
81 | logging.info(f'Keep {num_no_label_data} instances that have no labels from {train_path}.')
82 |
83 | return data
84 |
85 | def _load_txt(self, train_path, test_path, eval) -> 'dict[str, dict]':
86 | datasets = defaultdict(dict)
87 | if os.path.exists(test_path):
88 | test = read_libmultilabel_format(test_path)
89 |
90 | if not eval:
91 | train = read_libmultilabel_format(train_path)
92 | self._generate_tfidf(train['text'])
93 |
94 | if self.classes or not self.include_test_labels:
95 | self._generate_label_mapping(train['label'], self.classes)
96 | else:
97 | self._generate_label_mapping(train['label'] + test['label'])
98 | datasets['train']['x'] = self.vectorizer.transform(train['text'])
99 | datasets['train']['y'] = self.binarizer.transform(
100 | train['label']).astype('d')
101 |
102 | if os.path.exists(test_path):
103 | datasets['test']['x'] = self.vectorizer.transform(test['text'])
104 | datasets['test']['y'] = self.binarizer.transform(
105 | test['label']).astype('d')
106 |
107 | return dict(datasets)
108 |
109 | def _load_svm(self, train_path, test_path, eval) -> 'dict[str, dict]':
110 | datasets = defaultdict(dict)
111 | if os.path.exists(test_path):
112 | ty, tx = read_libsvm_format(test_path)
113 |
114 | if not eval:
115 | y, x = read_libsvm_format(train_path)
116 | if self.classes or not self.include_test_labels:
117 | self._generate_label_mapping(y, self.classes)
118 | else:
119 | self._generate_label_mapping(y + ty)
120 | datasets['train']['x'] = x
121 | datasets['train']['y'] = self.binarizer.transform(y).astype('d')
122 |
123 | if os.path.exists(test_path):
124 | datasets['test']['x'] = tx
125 | datasets['test']['y'] = self.binarizer.transform(ty).astype('d')
126 | return dict(datasets)
127 |
128 | def _generate_tfidf(self, texts):
129 | self.vectorizer = TfidfVectorizer()
130 | self.vectorizer.fit(texts)
131 |
132 | def _generate_label_mapping(self, labels, classes=None):
133 | self.binarizer = MultiLabelBinarizer(
134 | sparse_output=True, classes=classes)
135 | self.binarizer.fit(labels)
136 |
137 |
138 | def read_libmultilabel_format(path: str) -> 'dict[str,list[str]]':
139 | data = pd.read_csv(path, sep='\t', header=None,
140 | dtype=str,
141 | on_bad_lines='skip').fillna('')
142 | if data.shape[1] == 2:
143 | data.columns = ['label', 'text']
144 | data = data.reset_index()
145 | elif data.shape[1] == 3:
146 | data.columns = ['index', 'label', 'text']
147 | else:
148 | raise ValueError(f'Expected 2 or 3 columns, got {data.shape[1]}.')
149 | data['label'] = data['label'].map(lambda s: s.split())
150 | return data.to_dict('list')
151 |
152 |
153 | def read_libsvm_format(file_path: str) -> 'tuple[list[list[int]], sparse.csr_matrix]':
154 | """Read multi-label LIBSVM-format data.
155 |
156 | Args:
157 | file_path (str): Path to file.
158 |
159 | Returns:
160 | tuple[list[list[int]], sparse.csr_matrix]: A tuple of labels and features.
161 | """
162 |     def as_ints(string):
163 |         return [int(s) for s in string.split(',')]
164 |
165 | prob_y = []
166 | prob_x = array('d')
167 | row_ptr = array('l', [0])
168 | col_idx = array('l')
169 |
170 | pattern = re.compile(r'(?!^$)([+\-0-9,]+\s+)?(.*\n?)')
171 | for i, line in enumerate(open(file_path)):
172 | m = pattern.fullmatch(line)
173 | try:
174 | labels = m[1]
175 | prob_y.append(as_ints(labels) if labels else [])
176 | features = m[2] or ''
177 | nz = 0
178 | for e in features.split():
179 | ind, val = e.split(':')
180 | ind, val = int(ind), float(val)
181 | if ind < 1:
182 | raise IndexError(f'invalid svm format at line {i+1} of the file \'{file_path}\' --> Indices should start from one.')
183 | if val != 0:
184 | col_idx.append(ind - 1)
185 | prob_x.append(val)
186 | nz += 1
187 | row_ptr.append(row_ptr[-1]+nz)
188 | except IndexError:
189 | raise
190 |         except Exception:
191 | raise ValueError(f'invalid svm format at line {i+1} of the file \'{file_path}\'')
192 |
193 |     prob_x = np.frombuffer(prob_x, dtype='d')
194 |     col_idx = np.frombuffer(col_idx, dtype='l')
195 |     row_ptr = np.frombuffer(row_ptr, dtype='l')
196 | prob_x = sparse.csr_matrix((prob_x, col_idx, row_ptr))
197 |
198 | return (prob_y, prob_x)
199 |
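For reference, hedged sketches of the two input formats this module parses (file contents and paths below are hypothetical):

# 'txt' (LibMultiLabel) format: tab-separated [index,] labels, text,
# with space-separated labels inside the label field, e.g.
#     2<TAB>law finance<TAB>The parties agree that ...
# 'svm' (LibSVM) format: comma-separated labels, then 1-based index:value pairs, e.g.
#     1,3 4:0.5 18:1.2

preprocessor = Preprocessor(data_format='svm')
datasets = preprocessor.load_data('data/train.svm', 'data/test.svm')
y_train, x_train = datasets['train']['y'], datasets['train']['x']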
--------------------------------------------------------------------------------
/libmultilabel/linear/utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | from pathlib import Path
4 |
5 | from .preprocessor import Preprocessor
6 |
7 | __all__ = ['save_pipeline', 'load_pipeline']
8 |
9 |
10 | def save_pipeline(checkpoint_dir: str, preprocessor: Preprocessor, model):
11 |     """Saves preprocessor and model to checkpoint_dir/linear_pipeline.pickle.
12 |
13 | Args:
14 | checkpoint_dir (str): The directory to save to.
15 | preprocessor (Preprocessor): A Preprocessor.
16 | model: A model returned from one of the training functions.
17 | """
18 | Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
19 | checkpoint_path = os.path.join(checkpoint_dir, 'linear_pipeline.pickle')
20 |
21 | with open(checkpoint_path, 'wb') as f:
22 | pickle.dump({
23 | 'preprocessor': preprocessor,
24 | 'model': model,
25 | }, f)
26 |
27 |
28 | def load_pipeline(checkpoint_path: str) -> tuple:
29 | """Loads preprocessor and model from checkpoint_path.
30 |
31 | Args:
32 | checkpoint_path (str): The path to a previously saved pipeline.
33 |
34 | Returns:
35 | tuple: A tuple of the preprocessor and model.
36 | """
37 | with open(checkpoint_path, 'rb') as f:
38 | pipeline = pickle.load(f)
39 | return (pipeline['preprocessor'], pipeline['model'])
40 |
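A round-trip sketch (the checkpoint directory is hypothetical). Note the asymmetry: save_pipeline takes a directory, while load_pipeline takes the full path to the pickle it wrote:

save_pipeline('runs/example', preprocessor, model)
preprocessor, model = load_pipeline('runs/example/linear_pipeline.pickle')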
--------------------------------------------------------------------------------
/libmultilabel/nn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JamesLYC88/text_classification_baseline_code/6436f567372f3c36547cd52da79516900e5a6148/libmultilabel/nn/__init__.py
--------------------------------------------------------------------------------
/libmultilabel/nn/metrics.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | import numpy as np
4 | import torch
5 | import torchmetrics.classification
6 | from torchmetrics import Metric, MetricCollection, Precision, Recall, RetrievalNormalizedDCG
7 | from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
8 | from torchmetrics.utilities.data import select_topk
9 |
10 |
11 | class NDCG(Metric):
12 | """NDCG (Normalized Discounted Cumulative Gain) sums the true scores
13 | ranked in the order induced by the predicted scores after applying a logarithmic discount,
14 | and then divides by the best possible score (Ideal DCG, obtained for a perfect ranking)
15 | to obtain a score between 0 and 1.
16 | The definition is quoted from:
17 | https://scikit-learn.org/stable/modules/generated/sklearn.metrics.ndcg_score.html
18 | Please find the formal definition here:
19 | https://nlp.stanford.edu/IR-book/html/htmledition/evaluation-of-ranked-retrieval-results-1.html
20 |
21 | Args:
22 | top_k (int): the top k relevant labels to evaluate.
23 | """
24 | # If the metric state of one batch is independent of the state of other batches,
25 | # full_state_update can be set to False,
26 | # which leads to more efficient computation with calling update() only once.
27 | # Please find the detailed explanation here:
28 | # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
29 | full_state_update = False
30 |
31 | def __init__(
32 | self,
33 | top_k
34 | ):
35 | super().__init__()
36 | self.top_k = top_k
37 | self.add_state("ndcg", default=[], dist_reduce_fx="cat")
38 |
39 | def update(self, preds, target):
40 | assert preds.shape == target.shape
41 | # implement batch-wise calculations instead of storing results of all batches
42 | self.ndcg += [self._metric(p, t) for p, t in zip(preds, target)]
43 |
44 | def compute(self):
45 | return torch.stack(self.ndcg).mean()
46 |
47 | def _metric(self, preds, target):
48 | return retrieval_normalized_dcg(preds, target, k=self.top_k)
49 |
50 |
51 | class RPrecision(Metric):
52 | """R-precision calculates precision at k by adjusting k to the minimum value of the number of
53 | relevant labels and k. The definition is given at Appendix C equation (3) of
54 | https://aclanthology.org/P19-1636.pdf
55 |
56 | Args:
57 | top_k (int): the top k relevant labels to evaluate.
58 | """
59 | # If the metric state of one batch is independent of the state of other batches,
60 | # full_state_update can be set to False,
61 | # which leads to more efficient computation with calling update() only once.
62 | # Please find the detailed explanation here:
63 | # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
64 | full_state_update = False
65 |
66 | def __init__(
67 | self,
68 | top_k
69 | ):
70 | super().__init__()
71 | self.top_k = top_k
72 | self.add_state("score", default=torch.tensor(0., dtype=torch.double), dist_reduce_fx="sum")
73 | self.add_state("num_sample", default=torch.tensor(0), dist_reduce_fx="sum")
74 |
75 | def update(self, preds, target):
76 | assert preds.shape == target.shape
77 | binary_topk_preds = select_topk(preds, self.top_k)
78 | target = target.to(dtype=torch.int)
79 | num_relevant = torch.sum(binary_topk_preds & target, dim=-1)
80 | top_ks = torch.tensor([self.top_k]*preds.shape[0]).to(preds.device)
81 | self.score += torch.nan_to_num(
82 | num_relevant / torch.min(top_ks, target.sum(dim=-1)),
83 | posinf=0.
84 | ).sum()
85 | self.num_sample += len(preds)
86 |
87 | def compute(self):
88 | return self.score / self.num_sample
89 |
90 |
91 | class MacroF1(Metric):
92 |     """The macro-f1 score is the average of the f1 scores of all labels in the dataset.
93 |
94 | Args:
95 | num_classes (int): Total number of classes.
96 | metric_threshold (float): Threshold to monitor for metrics.
97 | another_macro_f1 (bool, optional): Whether to compute the 'Another-Macro-F1' score.
98 |             The 'Another-Macro-F1' is the F1 value computed from macro-precision and macro-recall.
99 | This variant of macro-f1 is less preferred but is used in some works.
100 | Please refer to Opitz et al. 2019 [https://arxiv.org/pdf/1911.03347.pdf].
101 | Defaults to False.
102 | """
103 | # If the metric state of one batch is independent of the state of other batches,
104 | # full_state_update can be set to False,
105 | # which leads to more efficient computation with calling update() only once.
106 | # Please find the detailed explanation here:
107 | # https://torchmetrics.readthedocs.io/en/stable/pages/implement.html
108 | full_state_update = False
109 |
110 | def __init__(
111 | self,
112 | num_classes,
113 | metric_threshold,
114 | another_macro_f1=False
115 | ):
116 | super().__init__()
117 | self.metric_threshold = metric_threshold
118 | self.another_macro_f1 = another_macro_f1
119 | self.add_state("preds_sum", default=torch.zeros(num_classes, dtype=torch.double))
120 | self.add_state("target_sum", default=torch.zeros(num_classes, dtype=torch.double))
121 | self.add_state("tp_sum", default=torch.zeros(num_classes, dtype=torch.double))
122 |
123 | def update(self, preds, target):
124 | assert preds.shape == target.shape
125 | preds = torch.where(preds > self.metric_threshold, 1, 0)
126 | self.preds_sum = torch.add(self.preds_sum, preds.sum(dim=0))
127 | self.target_sum = torch.add(self.target_sum, target.sum(dim=0))
128 | self.tp_sum = torch.add(self.tp_sum, (preds & target).sum(dim=0))
129 |
130 | def compute(self):
131 | if self.another_macro_f1:
132 | macro_prec = torch.mean(torch.nan_to_num(self.tp_sum / self.preds_sum, posinf=0.))
133 | macro_recall = torch.mean(torch.nan_to_num(self.tp_sum / self.target_sum, posinf=0.))
134 | return 2 * (macro_prec * macro_recall) / (macro_prec + macro_recall + 1e-10)
135 | else:
136 | label_f1 = 2 * self.tp_sum / (self.preds_sum + self.target_sum + 1e-10)
137 | return torch.mean(label_f1)
138 |
139 |
140 | def add_zero_class(labels):
141 | augmented_labels = torch.zeros((len(labels), len(labels[0]) + 1), dtype=torch.int32).to(labels.device)
142 | augmented_labels[:, :-1] = labels
143 | augmented_labels[:, -1] = (torch.sum(labels, axis=1) == 0).type(torch.int32)
144 | return augmented_labels
145 |
146 |
147 | class F1(Metric):
148 | full_state_update = False
149 |
150 | def __init__(
151 | self,
152 | num_classes,
153 | metric_threshold,
154 | average,
155 | zero,
156 | multi_class
157 | ):
158 | super().__init__()
159 | self.metric_threshold = metric_threshold
160 | if average not in {'macro', 'micro', 'another-macro'}:
161 | raise ValueError('unsupported average')
162 | self.average = average
163 | self.zero = zero
164 | if self.zero:
165 | num_classes += 1
166 | self.multi_class = multi_class
167 | self.add_state("preds_sum", default=torch.zeros(num_classes, dtype=torch.double))
168 | self.add_state("target_sum", default=torch.zeros(num_classes, dtype=torch.double))
169 | self.add_state("tp_sum", default=torch.zeros(num_classes, dtype=torch.double))
170 |
171 | def update(self, preds, target):
172 | if self.multi_class:
173 | preds = torch.eye(preds.shape[1])[preds.argmax(1)].type(torch.int32).to(preds.device)
174 | else:
175 | preds = torch.where(preds > self.metric_threshold, 1, 0)
176 | if self.zero:
177 | preds = add_zero_class(preds)
178 | target = add_zero_class(target)
179 | assert preds.shape == target.shape
180 | self.preds_sum = torch.add(self.preds_sum, preds.sum(dim=0))
181 | self.target_sum = torch.add(self.target_sum, target.sum(dim=0))
182 | self.tp_sum = torch.add(self.tp_sum, (preds & target).sum(dim=0))
183 |
184 | def compute(self):
185 | if self.average == 'another-macro':
186 | macro_prec = torch.mean(torch.nan_to_num(self.tp_sum / self.preds_sum, posinf=0.))
187 | macro_recall = torch.mean(torch.nan_to_num(self.tp_sum / self.target_sum, posinf=0.))
188 | return 2 * (macro_prec * macro_recall) / (macro_prec + macro_recall + 1e-10)
189 | elif self.average == 'macro':
190 | label_f1 = 2 * self.tp_sum / (self.preds_sum + self.target_sum + 1e-10)
191 | return torch.mean(label_f1)
192 | elif self.average == 'micro':
193 | return 2 * torch.sum(self.tp_sum) / (torch.sum(self.preds_sum + self.target_sum) + 1e-10)
194 |
195 |
196 | def get_metrics(metric_threshold, monitor_metrics, num_classes, zero, multi_class):
197 | """Map monitor metrics to the corresponding classes defined in `torchmetrics.Metric`
198 | (https://torchmetrics.readthedocs.io/en/latest/references/modules.html).
199 |
200 | Args:
201 | metric_threshold (float): Threshold to monitor for metrics.
202 | monitor_metrics (list): Metrics to monitor while validating.
203 | num_classes (int): Total number of classes.
204 | zero (bool)
205 | multi_class (bool)
206 |
207 | Raises:
208 | ValueError: The metric is invalid if:
209 | (1) It is not one of 'P@k', 'R@k', 'RP@k', 'nDCG@k', 'Micro-Precision',
210 | 'Micro-Recall', 'Micro-F1', 'Macro-F1', 'Another-Macro-F1', or a
211 | `torchmetrics.Metric`.
212 | (2) Metric@k: k is greater than `num_classes`.
213 |
214 | Returns:
215 | torchmetrics.MetricCollection: A collections of `torchmetrics.Metric` for evaluation.
216 | """
217 | if monitor_metrics is None:
218 | monitor_metrics = []
219 |
220 | metrics = dict()
221 | for metric in monitor_metrics:
222 | if isinstance(metric, Metric): # customized metric
223 | metrics[type(metric).__name__] = metric
224 | continue
225 |
226 | match_top_k = re.match(r'\b(P|R|RP|nDCG)\b@(\d+)', metric)
227 | match_metric = re.match(r'\b(Micro|Macro)\b-\b(Precision|Recall|F1)\b', metric)
228 |
229 | if match_top_k:
230 |             metric_abbr = match_top_k.group(1)  # P, R, RP, or nDCG
231 | top_k = int(match_top_k.group(2))
232 | if top_k >= num_classes:
233 | raise ValueError(
234 |                     f'Invalid metric: {metric}. k={top_k} must be less than the number of classes ({num_classes}).')
235 | if metric_abbr == 'P':
236 | metrics[metric] = Precision(num_classes, average='samples', top_k=top_k)
237 | elif metric_abbr == 'R':
238 | metrics[metric] = Recall(num_classes, average='samples', top_k=top_k)
239 | elif metric_abbr == 'RP':
240 | metrics[metric] = RPrecision(top_k=top_k)
241 | elif metric_abbr == 'nDCG':
242 | metrics[metric] = NDCG(top_k=top_k)
243 | # The implementation in torchmetrics stores the prediction/target of all batches,
244 | # which can lead to CUDA out of memory.
245 | # metrics[metric] = RetrievalNormalizedDCG(k=top_k)
246 | elif metric == 'Another-Macro-F1':
247 | metrics[metric] = F1(num_classes, metric_threshold, average='another-macro',
248 | zero=zero, multi_class=multi_class)
249 | elif metric == 'Macro-F1':
250 | metrics[metric] = F1(num_classes, metric_threshold, average='macro',
251 | zero=zero, multi_class=multi_class)
252 | elif metric == 'Micro-F1':
253 | metrics[metric] = F1(num_classes, metric_threshold, average='micro',
254 | zero=zero, multi_class=multi_class)
255 | elif match_metric:
256 | average_type = match_metric.group(1).lower() # Micro
257 | metric_type = match_metric.group(2) # Precision, Recall, or F1
258 |             metric_type = metric_type.replace('F1', 'F1Score')  # torchmetrics names F1 as F1Score
259 | metrics[metric] = getattr(torchmetrics.classification, metric_type)(
260 | num_classes, metric_threshold, average=average_type)
261 | else:
262 | raise ValueError(
263 | f'Invalid metric: {metric}. Make sure the metric is in the right format: Macro/Micro-Precision/Recall/F1 (ex. Micro-F1)')
264 |
265 | # If compute_groups is set to True (default), incorrect results may be calculated.
266 | # Please refer to https://github.com/Lightning-AI/metrics/issues/746 for more details.
267 | return MetricCollection(metrics, compute_groups=False)
268 |
269 |
270 | def tabulate_metrics(metric_dict, split):
271 | msg = f'====== {split} dataset evaluation result =======\n'
272 | header = '|'.join([f'{k:^18}' for k in metric_dict.keys()])
273 | values = '|'.join([f'{x * 100:^18.4f}' if isinstance(x, (np.floating, float)) else f'{x:^18}' for x in metric_dict.values()])
274 | msg += f"|{header}|\n|{'-----------------:|' * len(metric_dict)}\n|{values}|\n"
275 | return msg
276 |
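A hedged smoke test of the collection built by get_metrics, using random tensors of illustrative sizes:

import torch

metric_collection = get_metrics(metric_threshold=0.5,
                                monitor_metrics=['Micro-F1', 'Macro-F1', 'nDCG@2'],
                                num_classes=4, zero=False, multi_class=False)
preds = torch.rand(8, 4)                  # prediction scores in [0, 1]
target = torch.randint(0, 2, (8, 4))      # 0/1 ground truth
metric_collection.update(preds, target)
scores = {k: v.item() for k, v in metric_collection.compute().items()}
print(tabulate_metrics(scores, 'val'))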
--------------------------------------------------------------------------------
/libmultilabel/nn/model.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 | import numpy as np
4 | import pytorch_lightning as pl
5 | import torch
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 |
9 | from ..common_utils import dump_log, argsort_top_k
10 | from ..nn.metrics import get_metrics, tabulate_metrics
11 |
12 |
13 | class MultiLabelModel(pl.LightningModule):
14 | """Abstract class handling Pytorch Lightning training flow
15 |
16 | Args:
17 | num_classes (int): Total number of classes.
18 | learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001.
19 |         optimizer (str, optional): Optimizer name (i.e., sgd, adam, adamw, or adamax). Defaults to 'adam'.
20 |         momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
21 |         weight_decay (float, optional): Weight decay factor. Defaults to 0.
22 | metric_threshold (float, optional): Threshold to monitor for metrics. Defaults to 0.5.
23 | monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None.
24 | log_path (str): Path to a directory holding the log files and models.
25 | silent (bool, optional): Enable silent mode. Defaults to False.
26 | save_k_predictions (int, optional): Save top k predictions on test set. Defaults to 0.
27 |         zero (bool, optional): Whether to add a virtual class for instances without labels. Defaults to False.
28 |         multi_class (bool, optional): Whether to make a single-label (argmax) prediction per instance. Defaults to False.
29 | """
30 |
31 | def __init__(
32 | self,
33 | num_classes,
34 | learning_rate=0.0001,
35 | optimizer='adam',
36 | momentum=0.9,
37 | weight_decay=0,
38 | metric_threshold=0.5,
39 | monitor_metrics=None,
40 | log_path=None,
41 | silent=False,
42 | save_k_predictions=0,
43 | zero=False,
44 | multi_class=False,
45 | **kwargs
46 | ):
47 | super().__init__()
48 |
49 | # optimizer
50 | self.learning_rate = learning_rate
51 | self.optimizer = optimizer
52 | self.momentum = momentum
53 | self.weight_decay = weight_decay
54 |
55 | # dump log
56 | self.log_path = log_path
57 | self.silent = silent
58 | self.save_k_predictions = save_k_predictions
59 |
60 | # metrics for evaluation
61 | self.eval_metric = get_metrics(metric_threshold, monitor_metrics, num_classes,
62 | zero, multi_class)
63 |
64 | @abstractmethod
65 | def shared_step(self, batch):
66 | """Return loss and predicted logits"""
67 | return NotImplemented
68 |
69 | def configure_optimizers(self):
70 | """Initialize an optimizer for the free parameters of the network.
71 | """
72 | parameters = [p for p in self.parameters() if p.requires_grad]
73 | optimizer_name = self.optimizer
74 | if optimizer_name == 'sgd':
75 | optimizer = optim.SGD(parameters, self.learning_rate,
76 | momentum=self.momentum,
77 | weight_decay=self.weight_decay)
78 | elif optimizer_name == 'adam':
79 | optimizer = optim.Adam(parameters,
80 | weight_decay=self.weight_decay,
81 | lr=self.learning_rate)
82 | elif optimizer_name == 'adamw':
83 | optimizer = optim.AdamW(parameters,
84 | weight_decay=self.weight_decay,
85 | lr=self.learning_rate)
86 | elif optimizer_name == 'adamax':
87 | optimizer = optim.Adamax(parameters,
88 | weight_decay=self.weight_decay,
89 | lr=self.learning_rate)
90 | else:
91 | raise RuntimeError(
92 |                 f'Unsupported optimizer: {self.optimizer}')
93 |
94 | torch.nn.utils.clip_grad_value_(parameters, 0.5)
95 |
96 | return optimizer
97 |
98 | def training_step(self, batch, batch_idx):
99 | loss, _ = self.shared_step(batch)
100 | return loss
101 |
102 | def validation_step(self, batch, batch_idx):
103 | return self._shared_eval_step(batch, batch_idx)
104 |
105 | def validation_step_end(self, batch_parts):
106 | return self._shared_eval_step_end(batch_parts)
107 |
108 | def validation_epoch_end(self, step_outputs):
109 | return self._shared_eval_epoch_end(step_outputs, 'val')
110 |
111 | def test_step(self, batch, batch_idx):
112 | return self._shared_eval_step(batch, batch_idx)
113 |
114 | def test_step_end(self, batch_parts):
115 | return self._shared_eval_step_end(batch_parts)
116 |
117 | def test_epoch_end(self, step_outputs):
118 | return self._shared_eval_epoch_end(step_outputs, 'test')
119 |
120 | def _shared_eval_step(self, batch, batch_idx):
121 | loss, pred_logits = self.shared_step(batch)
122 | return {'batch_idx': batch_idx,
123 | 'loss': loss,
124 | 'pred_scores': torch.sigmoid(pred_logits),
125 | 'target': batch['labels']}
126 |
127 | def _shared_eval_step_end(self, batch_parts):
128 | batch_size, num_classes = batch_parts['target'].shape
129 |         # `indexes` indicates which query each prediction belongs to. `RetrievalNormalizedDCG`
130 |         # computes the mean of the nDCG scores over the queries.
131 | indexes = torch.arange(
132 | batch_size*batch_parts['batch_idx'], batch_size*(batch_parts['batch_idx']+1))
133 | indexes = indexes.unsqueeze(1).repeat(1, num_classes)
134 | return self.eval_metric.update(
135 | preds=batch_parts['pred_scores'],
136 | target=batch_parts['target'],
137 | indexes=indexes
138 | )
139 |
140 | def _shared_eval_epoch_end(self, step_outputs, split):
141 | """Get scores such as `Micro-F1`, `Macro-F1`, and monitor metrics defined
142 | in the configuration file in the end of an epoch.
143 |
144 | Args:
145 | step_outputs (list): List of the return values from the val or test step end.
146 | split (str): One of the `val` or `test`.
147 |
148 | Returns:
149 | metric_dict (dict): Scores for all metrics in the dictionary format.
150 | """
151 | metric_dict = self.eval_metric.compute()
152 | self.log_dict(metric_dict)
153 | for k, v in metric_dict.items():
154 | metric_dict[k] = v.item()
155 | if self.log_path:
156 | dump_log(metrics=metric_dict, split=split, log_path=self.log_path)
157 | self.print(tabulate_metrics(metric_dict, split))
158 | self.eval_metric.reset()
159 | return metric_dict
160 |
161 | def predict_step(self, batch, batch_idx, dataloader_idx):
162 | """`predict_step` is triggered when calling `trainer.predict()`.
163 | This function is used to get the top-k labels and their prediction scores.
164 |
165 | Args:
166 | batch (dict): A batch of text and label.
167 | batch_idx (int): Index of current batch.
168 | dataloader_idx (int): Index of current dataloader.
169 |
170 | Returns:
171 | dict: Top k label indexes and the prediction scores.
172 | """
173 | _, pred_logits = self.shared_step(batch)
174 | pred_scores = pred_logits.detach().cpu().numpy()
175 | k = self.save_k_predictions
176 | top_k_idx = argsort_top_k(pred_scores, k, axis=1)
177 | top_k_scores = np.take_along_axis(pred_scores, top_k_idx, axis=1)
178 |
179 | return {'top_k_pred': top_k_idx,
180 | 'top_k_pred_scores': top_k_scores}
181 |
182 | def print(self, *args, **kwargs):
183 | """Prints only from process 0 and not in silent mode. Use this in any
184 | distributed mode to log only once."""
185 |
186 | if not self.silent:
187 | # print() in LightningModule to print only from process 0
188 | super().print(*args, **kwargs)
189 |
190 |
191 | class Model(MultiLabelModel):
192 | """A class that implements `MultiLabelModel` for initializing and training a neural network.
193 |
194 | Args:
195 | classes (list): List of class names.
196 | word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
197 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
198 | network (nn.Module): Network (i.e., CAML, KimCNN, or XMLCNN).
199 | log_path (str): Path to a directory holding the log files and models.
200 |         enable_ce_loss (bool): Whether to use cross-entropy loss instead of binary cross-entropy. Defaults to False.
201 | """
202 | def __init__(
203 | self,
204 | classes,
205 | word_dict,
206 | embed_vecs,
207 | network,
208 | log_path=None,
209 | enable_ce_loss=False,
210 | **kwargs
211 | ):
212 | super().__init__(num_classes=len(classes), log_path=log_path, **kwargs)
213 | self.save_hyperparameters()
214 | self.word_dict = word_dict
215 | self.embed_vecs = embed_vecs
216 | self.classes = classes
217 | self.network = network
218 | if not enable_ce_loss:
219 | self.loss_fn = F.binary_cross_entropy_with_logits
220 | else:
221 | self.loss_fn = F.cross_entropy
222 |
223 | def shared_step(self, batch):
224 | """Return loss and predicted logits of the network.
225 |
226 | Args:
227 | batch (dict): A batch of text and label.
228 |
229 | Returns:
230 |             loss (torch.Tensor): Binary cross-entropy (or cross-entropy if enable_ce_loss is set) between target and predicted logits.
231 |             pred_logits (torch.Tensor): The predicted logits with shape (batch_size, num_classes).
232 | """
233 | target_labels = batch['labels']
234 | outputs = self.network(batch)
235 | pred_logits = outputs['logits']
236 | loss = self.loss_fn(pred_logits, target_labels.float())
237 | return loss, pred_logits
238 |
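A hedged wiring sketch; word_dict, embed_vecs, network, and train_dataloader are placeholders built elsewhere (this is not the repo's training entry point):

import pytorch_lightning as pl

model = Model(
    classes=['law', 'finance', 'health'],     # illustrative label set
    word_dict=word_dict,
    embed_vecs=embed_vecs,
    network=network,                          # e.g., a KimCNN instance
    learning_rate=1e-4,
    monitor_metrics=['Micro-F1', 'Macro-F1'],
)
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_dataloader)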
--------------------------------------------------------------------------------
/libmultilabel/nn/networks/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .bert import BERT
4 | from .bert_attention import BERTAttention
5 | from .caml import CAML
6 | from .kim_cnn import KimCNN
7 | from .xml_cnn import XMLCNN
8 | from .labelwise_attention_networks import BiGRULWAN
9 | from .labelwise_attention_networks import BiLSTMLWAN
10 | from .labelwise_attention_networks import BiLSTMLWMHAN
11 | from .labelwise_attention_networks import CNNLWAN
12 |
13 |
14 | def get_init_weight_func(init_weight):
15 | def init_weight_func(m):
16 | if isinstance(m, nn.Linear) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv2d):
17 |             getattr(nn.init, init_weight + '_')(m.weight)
18 | return init_weight_func
19 |
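A sketch of applying the returned hook (network is a placeholder nn.Module; the string must name an nn.init initializer without its trailing underscore):

init_fn = get_init_weight_func('kaiming_uniform')  # resolves to nn.init.kaiming_uniform_
network.apply(init_fn)                             # initializes Linear/Conv1d/Conv2d weights in place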
--------------------------------------------------------------------------------
/libmultilabel/nn/networks/bert.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from transformers import AutoModelForSequenceClassification
3 |
4 | from .hierbert import HierarchicalBert
5 |
6 |
7 | class BERT(nn.Module):
8 | """BERT
9 |
10 | Args:
11 | num_classes (int): Total number of classes.
12 | dropout (float): The dropout rate of the word embedding. Defaults to 0.1.
13 | lm_weight (str): Pretrained model name or path. Defaults to 'bert-base-cased'.
14 | """
15 | def __init__(
16 | self,
17 | num_classes,
18 | dropout=0.1,
19 | lm_weight='bert-base-cased',
20 | hierarchical=False,
21 | max_segments=64,
22 | max_seg_length=128,
23 | **kwargs
24 | ):
25 | super().__init__()
26 | self.lm = AutoModelForSequenceClassification.from_pretrained(lm_weight,
27 | num_labels=num_classes,
28 | hidden_dropout_prob=dropout,
29 | torchscript=True)
30 | self.hierarchical = hierarchical
31 | if self.hierarchical:
32 | segment_encoder = self.lm.bert
33 | model_encoder = HierarchicalBert(encoder=segment_encoder,
34 | max_segments=max_segments,
35 | max_segment_length=max_seg_length)
36 | self.lm.bert = model_encoder
37 |
38 | def forward(self, input):
39 | input_ids = input['input_ids']
40 | attention_mask = input['attention_mask']
41 | x = self.lm(input_ids, attention_mask=attention_mask)[0] # (batch_size, num_classes)
42 | return {'logits': x}
43 |
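A hedged forward-pass sketch (the tokenizer choice mirrors the default lm_weight; weights are downloaded on first use):

import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
batch = tokenizer(['a short legal clause'], return_tensors='pt', padding=True)
net = BERT(num_classes=10)
with torch.no_grad():
    out = net({'input_ids': batch['input_ids'],
               'attention_mask': batch['attention_mask']})
print(out['logits'].shape)  # torch.Size([1, 10])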
--------------------------------------------------------------------------------
/libmultilabel/nn/networks/bert_attention.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn.utils.rnn import pad_sequence
3 | from transformers import AutoModel
4 |
5 | from .modules import LabelwiseAttention, LabelwiseLinearOutput, LabelwiseMultiHeadAttention
6 |
7 |
8 | class BERTAttention(nn.Module):
9 | """BERT + Label-wise Document Attention or Multi-Head Attention
10 |
11 | Args:
12 | num_classes (int): Total number of classes.
13 | dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
14 | lm_weight (str): Pretrained model name or path. Defaults to 'bert-base-cased'.
15 | lm_window (int): Length of the subsequences to be split before feeding them to
16 | the language model. Defaults to 512.
17 | num_heads (int): The number of parallel attention heads. Defaults to 8.
18 |         attention_type (str): Type of attention to use (singlehead or multihead). Defaults to 'multihead'.
19 | attention_dropout (float): The dropout rate for the attention. Defaults to 0.0.
20 | """
21 | def __init__(
22 | self,
23 | num_classes,
24 | dropout=0.2,
25 | lm_weight='bert-base-cased',
26 | lm_window=512,
27 | num_heads=8,
28 | attention_type='multihead',
29 | attention_dropout=0.0,
30 | **kwargs
31 | ):
32 | super().__init__()
33 | self.lm_window = lm_window
34 | self.attention_type = attention_type
35 |
36 | self.lm = AutoModel.from_pretrained(lm_weight, torchscript=True)
37 | self.embed_drop = nn.Dropout(p=dropout)
38 |
39 | self.attention_type = attention_type
40 | assert attention_type in ['singlehead', 'multihead'], "attention_type must be 'singlehead' or 'multihead'"
41 | if attention_type == 'singlehead':
42 | self.attention = LabelwiseAttention(self.lm.config.hidden_size, num_classes)
43 | else:
44 | self.attention = LabelwiseMultiHeadAttention(
45 | self.lm.config.hidden_size, num_classes, num_heads, attention_dropout)
46 |
47 | # Final layer: create a matrix to use for the #labels binary classifiers
48 | self.output = LabelwiseLinearOutput(self.lm.config.hidden_size, num_classes)
49 |
50 | def lm_feature(self, input_ids):
51 | """BERT takes an input of a sequence of no more than 512 tokens.
52 |     Therefore, long sequences are split into subsequences of size `lm_window`, which is a number no greater than 512.
53 | If the split subsequence is shorter than `lm_window`, pad it with the pad token.
54 |
55 | Args:
56 | input_ids (torch.Tensor): Input ids of the sequence with shape (batch_size, sequence_length).
57 |
58 | Returns:
59 | torch.Tensor: The representation of the sequence.
60 | """
61 | if input_ids.size(-1) <= self.lm_window:
62 | return self.lm(input_ids, attention_mask=input_ids != self.lm.config.pad_token_id)[0]
63 | else:
64 | inputs = []
65 | batch_indexes = []
66 | seq_lengths = []
67 | for token_id in input_ids:
68 | indexes = []
69 | seq_length = (token_id != self.lm.config.pad_token_id).sum()
70 | seq_lengths.append(seq_length)
71 | for i in range(0, seq_length, self.lm_window):
72 | indexes.append(len(inputs))
73 | inputs.append(token_id[i: i + self.lm_window])
74 | batch_indexes.append(indexes)
75 |
76 | padded_inputs = pad_sequence(inputs, batch_first=True)
77 | last_hidden_states = self.lm(
78 | padded_inputs, attention_mask=padded_inputs != self.lm.config.pad_token_id)[0]
79 |
80 | x = []
81 | for seq_l, mapping in zip(seq_lengths, batch_indexes):
82 | last_hidden_state = last_hidden_states[mapping].view(
83 | -1, last_hidden_states.size(-1))[:seq_l, :]
84 | x.append(last_hidden_state)
85 | return pad_sequence(x, batch_first=True)
86 |
87 | def forward(self, input):
88 | input_ids = input['text'] # (batch_size, sequence_length)
89 | attention_mask = input_ids == self.lm.config.pad_token_id  # True at padding positions; used as key_padding_mask
90 | x = self.lm_feature(input_ids) # (batch_size, sequence_length, lm_hidden_size)
91 | x = self.embed_drop(x)
92 |
93 | # Apply per-label attention.
94 | if self.attention_type == 'singlehead':
95 | logits, attention = self.attention(x)
96 | else:
97 | logits, attention = self.attention(x, attention_mask)
98 |
99 | # Compute a probability for each label
100 | x = self.output(logits)
101 | return {'logits': x, 'attention': attention}
102 |
--------------------------------------------------------------------------------
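A forward-pass sketch for BERTAttention above, exercising the `lm_feature` split of sequences longer than `lm_window`. The shapes and token ids are illustrative assumptions; pad id 0 matches bert-base-cased:

import torch
from libmultilabel.nn.networks.bert_attention import BERTAttention

net = BERTAttention(num_classes=20, lm_window=512)
input_ids = torch.randint(1, 30000, (2, 1024))  # longer than lm_window, no padding
out = net({'text': input_ids})
print(out['logits'].shape)  # expected: torch.Size([2, 20])
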
/libmultilabel/nn/networks/caml.py:
--------------------------------------------------------------------------------
1 | from math import floor
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.init import xavier_uniform_
7 |
8 |
9 | class CAML(nn.Module):
10 | """CAML (Convolutional Attention for Multi-Label classification)
11 | Follows the work of Mullenbach et al. [https://aclanthology.org/N18-1100.pdf]
12 | This class is for reproducing the results in the paper.
13 | Use CNNLWAN instead for better modularization.
14 |
15 | Args:
16 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
17 | num_classes (int): Total number of classes.
18 | filter_sizes (list): Size of convolutional filters.
19 | num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 50.
20 | dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
21 | """
22 | def __init__(
23 | self,
24 | embed_vecs,
25 | num_classes,
26 | filter_sizes=None,
27 | num_filter_per_size=50,
28 | dropout=0.2,
29 | ):
30 | super(CAML, self).__init__()
31 | if not filter_sizes or len(filter_sizes) != 1:
32 | raise ValueError(f'CAML expects exactly one filter size. Got filter_sizes={filter_sizes}')
33 | filter_size = filter_sizes[0]
34 |
35 | self.embedding = nn.Embedding(len(embed_vecs), embed_vecs.shape[1], padding_idx=0)
36 | self.embedding.weight.data = embed_vecs.clone()
37 | self.embed_drop = nn.Dropout(p=dropout)
38 |
39 | # Initialize conv layer
40 | self.conv = nn.Conv1d(embed_vecs.shape[1], num_filter_per_size, kernel_size=filter_size, padding=int(floor(filter_size/2)))
41 | xavier_uniform_(self.conv.weight)
42 |
43 | """Context vectors for computing attention with
44 | (in_features, out_features) = (num_filter_per_size, num_classes)
45 | """
46 | self.Q = nn.Linear(num_filter_per_size, num_classes)
47 | xavier_uniform_(self.Q.weight)
48 |
49 | # Final layer: create a matrix to use for the #labels binary classifiers
50 | self.output = nn.Linear(num_filter_per_size, num_classes)
51 | xavier_uniform_(self.output.weight)
52 |
53 | def forward(self, input):
54 | # Get embeddings and apply dropout
55 | x = self.embedding(input['text']) # (batch_size, length, embed_dim)
56 | x = self.embed_drop(x)
57 | x = x.transpose(1, 2) # (batch_size, embed_dim, length)
58 |
59 | """ Apply convolution and nonlinearity (tanh). The shapes are:
60 | - self.conv(x): (batch_size, num_filter_per_size, length)
61 | - x after transposing the first and the second dimension and applying
62 | the activation function: (batch_size, length, num_filter_per_size)
63 | """
64 | Z = torch.tanh(self.conv(x).transpose(1, 2))
65 |
66 | """Apply per-label attention. The shapes are:
67 | - Q.weight: (num_classes, num_filter_per_size)
68 | - matrix product of Q.weight and Z: (batch_size, num_classes, length)
69 | - alpha: (batch_size, num_classes, length)
70 | """
71 | alpha = torch.softmax(self.Q.weight.matmul(Z.transpose(1, 2)), dim=2)
72 |
73 | # Document representations are weighted sums using the attention
74 | E = alpha.matmul(Z) # (batch_size, num_classes, num_filter_per_size)
75 |
76 | # Compute a probability for each label
77 | logits = self.output.weight.mul(E).sum(dim=2).add(self.output.bias) # (batch_size, num_classes)
78 |
79 | return {'logits': logits, 'attention': alpha}
80 |
--------------------------------------------------------------------------------
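A shape walk-through for CAML above with toy sizes; the vocabulary, embedding dimension, and batch are illustrative assumptions:

import torch
from libmultilabel.nn.networks.caml import CAML

embed_vecs = torch.randn(100, 50)                    # (vocab_size, embed_dim)
net = CAML(embed_vecs, num_classes=5, filter_sizes=[3])
out = net({'text': torch.randint(1, 100, (4, 30))})  # (batch_size, length)
print(out['logits'].shape)     # torch.Size([4, 5])
print(out['attention'].shape)  # torch.Size([4, 5, 30]): per-label weights over tokens
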
/libmultilabel/nn/networks/hierbert.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | import numpy as np
6 | from torch import nn
7 | from transformers.file_utils import ModelOutput
8 |
9 |
10 | @dataclass
11 | class SimpleOutput(ModelOutput):
12 | last_hidden_state: torch.FloatTensor = None
13 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
14 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
15 | attentions: Optional[Tuple[torch.FloatTensor]] = None
16 | cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
17 |
18 |
19 | def sinusoidal_init(num_embeddings: int, embedding_dim: int):
20 | # keep dim 0 for padding token position encoding zero vector
21 | position_enc = np.array([
22 | [pos / np.power(10000, 2 * i / embedding_dim) for i in range(embedding_dim)]
23 | if pos != 0 else np.zeros(embedding_dim) for pos in range(num_embeddings)])
24 |
25 | position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
26 | position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
27 | return torch.from_numpy(position_enc).type(torch.FloatTensor)
28 |
29 |
30 | class HierarchicalBert(nn.Module):
31 |
32 | def __init__(self, encoder, max_segments=64, max_segment_length=128):
33 | super(HierarchicalBert, self).__init__()
34 | supported_models = ['bert', 'roberta', 'deberta']
35 | assert encoder.config.model_type in supported_models, f'Unsupported model type: {encoder.config.model_type}'
36 | # Pre-trained segment (token-wise) encoder, e.g., BERT
37 | self.encoder = encoder
38 | # Specs for the segment-wise encoder
39 | self.hidden_size = encoder.config.hidden_size
40 | self.max_segments = max_segments
41 | self.max_segment_length = max_segment_length
42 | # Init sinusoidal positional embeddings
43 | self.seg_pos_embeddings = nn.Embedding(max_segments + 1, encoder.config.hidden_size,
44 | padding_idx=0,
45 | _weight=sinusoidal_init(max_segments + 1, encoder.config.hidden_size))
46 | # Init segment-wise transformer-based encoder
47 | self.seg_encoder = nn.Transformer(d_model=encoder.config.hidden_size,
48 | nhead=encoder.config.num_attention_heads,
49 | batch_first=True, dim_feedforward=encoder.config.intermediate_size,
50 | activation=encoder.config.hidden_act,
51 | dropout=encoder.config.hidden_dropout_prob,
52 | layer_norm_eps=encoder.config.layer_norm_eps,
53 | num_encoder_layers=2, num_decoder_layers=0).encoder
54 |
55 | def forward(self,
56 | input_ids=None,
57 | attention_mask=None,
58 | token_type_ids=None,
59 | position_ids=None,
60 | head_mask=None,
61 | inputs_embeds=None,
62 | labels=None,
63 | output_attentions=None,
64 | output_hidden_states=None,
65 | return_dict=None,
66 | ):
67 | # Hypothetical Example
68 | # Batch of 4 documents: (batch_size, n_segments, max_segment_length) --> (4, 64, 128)
69 | # BERT-BASE encoder: 768 hidden units
70 |
71 | # Squash samples and segments into a single axis (batch_size * n_segments, max_segment_length) --> (256, 128)
72 | input_ids_reshape = input_ids.contiguous().view(-1, input_ids.size(-1))
73 | attention_mask_reshape = attention_mask.contiguous().view(-1, attention_mask.size(-1))
74 | if token_type_ids is not None:
75 | token_type_ids_reshape = token_type_ids.contiguous().view(-1, token_type_ids.size(-1))
76 | else:
77 | token_type_ids_reshape = None
78 |
79 | # Encode segments with BERT --> (256, 128, 768)
80 | encoder_outputs = self.encoder(input_ids=input_ids_reshape,
81 | attention_mask=attention_mask_reshape,
82 | token_type_ids=token_type_ids_reshape)[0]
83 |
84 | # Reshape back to (batch_size, n_segments, max_segment_length, output_size) --> (4, 64, 128, 768)
85 | encoder_outputs = encoder_outputs.contiguous().view(input_ids.size(0), self.max_segments,
86 | self.max_segment_length,
87 | self.hidden_size)
88 |
89 | # Gather CLS outputs per segment --> (4, 64, 768)
90 | encoder_outputs = encoder_outputs[:, :, 0]
91 |
92 | # Infer real segments, i.e., mask paddings
93 | seg_mask = (torch.sum(input_ids, 2) != 0).to(input_ids.dtype)
94 | # Infer and collect segment positional embeddings
95 | seg_positions = torch.arange(1, self.max_segments + 1).to(input_ids.device) * seg_mask
96 | # Add segment positional embeddings to segment inputs
97 | encoder_outputs += self.seg_pos_embeddings(seg_positions)
98 |
99 | # Encode segments with segment-wise transformer
100 | seg_encoder_outputs = self.seg_encoder(encoder_outputs)
101 |
102 | # Collect document representation
103 | outputs, _ = torch.max(seg_encoder_outputs, 1)
104 |
105 | return SimpleOutput(last_hidden_state=outputs, hidden_states=outputs)
106 |
--------------------------------------------------------------------------------
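HierarchicalBert expects each document pre-chunked into fixed-size segments; the chunking itself happens in the data pipeline, not here. A toy sketch of the expected input layout, with illustrative sizes:

import torch
from transformers import AutoModel
from libmultilabel.nn.networks.hierbert import HierarchicalBert

encoder = AutoModel.from_pretrained('bert-base-cased')
hier = HierarchicalBert(encoder, max_segments=4, max_segment_length=16)
input_ids = torch.randint(1, 30000, (2, 4, 16))  # (batch, segments, tokens per segment)
out = hier(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
print(out.last_hidden_state.shape)  # torch.Size([2, 768]): one max-pooled vector per document
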
/libmultilabel/nn/networks/kim_cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .modules import Embedding, CNNEncoder
5 |
6 |
7 | class KimCNN(nn.Module):
8 | """KimCNN
9 |
10 | Args:
11 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
12 | num_classes (int): Total number of classes.
13 | filter_sizes (list): The size of convolutional filters.
14 | num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 128.
15 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
16 | encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0.
17 | activation (str): Activation function to be used. Defaults to 'relu'.
18 | """
19 | def __init__(
20 | self,
21 | embed_vecs,
22 | num_classes,
23 | filter_sizes=None,
24 | num_filter_per_size=128,
25 | embed_dropout=0.2,
26 | encoder_dropout=0,
27 | activation='relu'
28 | ):
29 | super(KimCNN, self).__init__()
30 | self.embedding = Embedding(embed_vecs, embed_dropout)
31 | self.encoder = CNNEncoder(embed_vecs.shape[1], filter_sizes,
32 | num_filter_per_size, activation,
33 | encoder_dropout, num_pool=1)
34 | conv_output_size = num_filter_per_size * len(filter_sizes)
35 | self.linear = nn.Linear(conv_output_size, num_classes)
36 |
37 | def forward(self, input):
38 | x = self.embedding(input['text']) # (batch_size, length, embed_dim)
39 | x = self.encoder(x) # (batch_size, num_filter, 1)
40 | x = torch.squeeze(x, 2) # (batch_size, num_filter)
41 | x = self.linear(x) # (batch_size, num_classes)
42 | return {'logits': x}
43 |
--------------------------------------------------------------------------------
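A toy forward pass for KimCNN above; the random embedding matrix and batch are placeholders chosen only to make the shapes concrete:

import torch
from libmultilabel.nn.networks.kim_cnn import KimCNN

embed_vecs = torch.randn(100, 50)
net = KimCNN(embed_vecs, num_classes=8, filter_sizes=[2, 4, 8])
out = net({'text': torch.randint(1, 100, (4, 60))})
print(out['logits'].shape)  # torch.Size([4, 8]): 3 filter sizes * 128 filters feed the linear layer
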
/libmultilabel/nn/networks/labelwise_attention_networks.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch.nn as nn
4 |
5 | from .modules import Embedding, GRUEncoder, LSTMEncoder, CNNEncoder, LabelwiseAttention, LabelwiseMultiHeadAttention, LabelwiseLinearOutput
6 |
7 |
8 | class LabelwiseAttentionNetwork(ABC, nn.Module):
9 | """Base class for Labelwise Attention Network
10 |
11 | Args:
12 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
13 | num_classes (int): Total number of classes.
14 | embed_dropout (float): The dropout rate of the word embedding.
15 | encoder_dropout (float): The dropout rate of the encoder output.
16 | hidden_dim (int): The output dimension of the encoder.
17 | """
18 |
19 | def __init__(self, embed_vecs, num_classes, embed_dropout, encoder_dropout, hidden_dim):
20 | super(LabelwiseAttentionNetwork, self).__init__()
21 | self.embedding = Embedding(embed_vecs, embed_dropout)
22 | self.encoder = self._get_encoder(embed_vecs.shape[1], encoder_dropout)
23 | self.attention = self._get_attention()
24 | self.output = LabelwiseLinearOutput(hidden_dim, num_classes)
25 |
26 | @abstractmethod
27 | def forward(self, input):
28 | raise NotImplementedError
29 |
30 | @abstractmethod
31 | def _get_encoder(self, input_size, dropout):
32 | raise NotImplementedError
33 |
34 | @abstractmethod
35 | def _get_attention(self):
36 | raise NotImplementedError
37 |
38 |
39 | class RNNLWAN(LabelwiseAttentionNetwork):
40 | """Base class for RNN Labelwise Attention Network
41 | """
42 |
43 | def forward(self, input):
44 | x = self.embedding(input['text']) # (batch_size, sequence_length, embed_dim)
45 | x = self.encoder(x, input['length']) # (batch_size, sequence_length, hidden_dim)
46 | x, _ = self.attention(x) # (batch_size, num_classes, hidden_dim)
47 | x = self.output(x) # (batch_size, num_classes)
48 | return {'logits': x}
49 |
50 |
51 | class BiGRULWAN(RNNLWAN):
52 | """BiGRU Labelwise Attention Network
53 |
54 | Args:
55 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
56 | num_classes (int): Total number of classes.
57 | rnn_dim (int): The size of bidirectional hidden layers. The hidden size of the GRU network
58 | is set to rnn_dim//2. Defaults to 512.
59 | rnn_layers (int): The number of recurrent layers. Defaults to 1.
60 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
61 | encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0.
62 | """
63 |
64 | def __init__(
65 | self,
66 | embed_vecs,
67 | num_classes,
68 | rnn_dim=512,
69 | rnn_layers=1,
70 | embed_dropout=0.2,
71 | encoder_dropout=0
72 | ):
73 | self.num_classes = num_classes
74 | self.rnn_dim = rnn_dim
75 | self.rnn_layers = rnn_layers
76 | super(BiGRULWAN, self).__init__(embed_vecs, num_classes, embed_dropout,
77 | encoder_dropout, rnn_dim)
78 |
79 | def _get_encoder(self, input_size, dropout):
80 | assert self.rnn_dim % 2 == 0, """`rnn_dim` should be even."""
81 | return GRUEncoder(input_size, self.rnn_dim // 2, self.rnn_layers, dropout)
82 |
83 | def _get_attention(self):
84 | return LabelwiseAttention(self.rnn_dim, self.num_classes)
85 |
86 |
87 | class BiLSTMLWAN(RNNLWAN):
88 | """BiLSTM Labelwise Attention Network
89 |
90 | Args:
91 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
92 | num_classes (int): Total number of classes.
93 | rnn_dim (int): The size of bidirectional hidden layers. The hidden size of the LSTM network
94 | is set to rnn_dim//2. Defaults to 512.
95 | rnn_layers (int): The number of recurrent layers. Defaults to 1.
96 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
97 | encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0.
98 | """
99 |
100 | def __init__(
101 | self,
102 | embed_vecs,
103 | num_classes,
104 | rnn_dim=512,
105 | rnn_layers=1,
106 | embed_dropout=0.2,
107 | encoder_dropout=0
108 | ):
109 | self.num_classes = num_classes
110 | self.rnn_dim = rnn_dim
111 | self.rnn_layers = rnn_layers
112 | super(BiLSTMLWAN, self).__init__(embed_vecs, num_classes, embed_dropout,
113 | encoder_dropout, rnn_dim)
114 |
115 | def _get_encoder(self, input_size, dropout):
116 | assert self.rnn_dim % 2 == 0, """`rnn_dim` should be even."""
117 | return LSTMEncoder(input_size, self.rnn_dim // 2, self.rnn_layers, dropout)
118 |
119 | def _get_attention(self):
120 | return LabelwiseAttention(self.rnn_dim, self.num_classes)
121 |
122 |
123 | class BiLSTMLWMHAN(LabelwiseAttentionNetwork):
124 | """BiLSTM Labelwise Multihead Attention Network
125 |
126 | Args:
127 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
128 | num_classes (int): Total number of classes.
129 | rnn_dim (int): The size of bidirectional hidden layers. The hidden size of the LSTM network
130 | is set to rnn_dim//2. Defaults to 512.
131 | rnn_layers (int): The number of recurrent layers. Defaults to 1.
132 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
133 | encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0.
134 | num_heads (int): The number of parallel attention heads. Defaults to 8.
135 | attention_dropout (float): The dropout rate for the attention. Defaults to 0.0.
136 | """
137 |
138 | def __init__(
139 | self,
140 | embed_vecs,
141 | num_classes,
142 | rnn_dim=512,
143 | rnn_layers=1,
144 | embed_dropout=0.2,
145 | encoder_dropout=0,
146 | num_heads=8,
147 | attention_dropout=0.0
148 | ):
149 | self.num_classes = num_classes
150 | self.rnn_dim = rnn_dim
151 | self.rnn_layers = rnn_layers
152 | self.num_heads = num_heads
153 | self.attention_dropout = attention_dropout
154 | super(BiLSTMLWMHAN, self).__init__(embed_vecs, num_classes, embed_dropout,
155 | encoder_dropout, rnn_dim)
156 |
157 | def _get_encoder(self, input_size, dropout):
158 | assert self.rnn_dim % 2 == 0, """`rnn_dim` should be even."""
159 | return LSTMEncoder(input_size, self.rnn_dim // 2,
160 | self.rnn_layers, dropout)
161 |
162 | def _get_attention(self):
163 | return LabelwiseMultiHeadAttention(self.rnn_dim, self.num_classes, self.num_heads, self.attention_dropout)
164 |
165 | def forward(self, input):
166 | x = self.embedding(input['text']) # (batch_size, sequence_length, embed_dim)
167 | x = self.encoder(x, input['length']) # (batch_size, sequence_length, hidden_dim)
168 | x, _ = self.attention(x, attention_mask=input['text'] == 0) # (batch_size, num_classes, hidden_dim)
169 | x = self.output(x) # (batch_size, num_classes)
170 | return {'logits': x}
171 |
172 |
173 | class CNNLWAN(LabelwiseAttentionNetwork):
174 | """CNN Labelwise Attention Network
175 |
176 | Args:
177 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
178 | num_classes (int): Total number of classes.
179 | filter_sizes (list): Size of convolutional filters.
180 | num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 50.
181 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
182 | encoder_dropout (float): The dropout rate of the encoder output. Defaults to 0.
183 | activation (str): Activation function to be used. Defaults to 'tanh'.
184 | """
185 |
186 | def __init__(
187 | self,
188 | embed_vecs,
189 | num_classes,
190 | filter_sizes=None,
191 | num_filter_per_size=50,
192 | embed_dropout=0.2,
193 | encoder_dropout=0,
194 | activation='tanh'
195 | ):
196 | self.num_classes = num_classes
197 | self.filter_sizes = filter_sizes
198 | self.num_filter_per_size = num_filter_per_size
199 | self.activation = activation
200 | self.hidden_dim = num_filter_per_size * len(filter_sizes)
201 | super(CNNLWAN, self).__init__(embed_vecs, num_classes, embed_dropout,
202 | encoder_dropout, self.hidden_dim)
203 |
204 | def _get_encoder(self, input_size, dropout):
205 | return CNNEncoder(input_size, self.filter_sizes,
206 | self.num_filter_per_size, self.activation, dropout,
207 | channel_last=True)
208 |
209 | def _get_attention(self):
210 | return LabelwiseAttention(self.hidden_dim, self.num_classes)
211 |
212 | def forward(self, input):
213 | x = self.embedding(input['text']) # (batch_size, sequence_length, embed_dim)
214 | x = self.encoder(x) # (batch_size, sequence_length, hidden_dim)
215 | x, _ = self.attention(x) # (batch_size, num_classes, hidden_dim)
216 | x = self.output(x) # (batch_size, num_classes)
217 | return {'logits': x}
218 |
--------------------------------------------------------------------------------
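A toy run of BiGRULWAN above. Note the extra 'length' key: the RNN encoder packs sequences by their true lengths. All sizes are illustrative:

import torch
from libmultilabel.nn.networks.labelwise_attention_networks import BiGRULWAN

embed_vecs = torch.randn(100, 50)
net = BiGRULWAN(embed_vecs, num_classes=6, rnn_dim=64)
batch = {
    'text': torch.randint(1, 100, (4, 30)),
    'length': torch.full((4,), 30),  # true (unpadded) lengths per sample
}
print(net(batch)['logits'].shape)  # torch.Size([4, 6])
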
/libmultilabel/nn/networks/modules.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
7 |
8 |
9 | class Embedding(nn.Module):
10 | """Embedding layer with dropout
11 |
12 | Args:
13 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
14 | dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
15 | """
16 |
17 | def __init__(self, embed_vecs, dropout=0.2):
18 | super(Embedding, self).__init__()
19 | self.embedding = nn.Embedding.from_pretrained(
20 | embed_vecs, freeze=False, padding_idx=0)
21 | self.dropout = nn.Dropout(dropout)
22 |
23 | def forward(self, input):
24 | return self.dropout(self.embedding(input))
25 |
26 |
27 | class RNNEncoder(ABC, nn.Module):
28 | """Base class of RNN encoder with dropout
29 |
30 | Args:
31 | input_size (int): The number of expected features in the input.
32 | hidden_size (int): The number of features in the hidden state.
33 | num_layers (int): The number of recurrent layers.
34 | dropout (float): The dropout rate of the encoder. Defaults to 0.
35 | """
36 |
37 | def __init__(self, input_size, hidden_size, num_layers, dropout=0):
38 | super(RNNEncoder, self).__init__()
39 | self.rnn = self._get_rnn(input_size, hidden_size, num_layers)
40 | self.dropout = nn.Dropout(dropout)
41 |
42 | def forward(self, input, length, **kwargs):
43 | self.rnn.flatten_parameters()
44 | idx = torch.argsort(length, descending=True)
45 | length_clamped = length[idx].cpu().clamp(min=1) # avoid the empty text with length 0
46 | packed_input = pack_padded_sequence(
47 | input[idx], length_clamped, batch_first=True)
48 | outputs, _ = pad_packed_sequence(
49 | self.rnn(packed_input)[0], batch_first=True)
50 | return self.dropout(outputs[torch.argsort(idx)])
51 |
52 | @abstractmethod
53 | def _get_rnn(self, input_size, hidden_size, num_layers):
54 | raise NotImplementedError
55 |
56 |
57 | class GRUEncoder(RNNEncoder):
58 | """Bi-directional GRU encoder with dropout
59 |
60 | Args:
61 | input_size (int): The number of expected features in the input.
62 | hidden_size (int): The number of features in the hidden state.
63 | num_layers (int): The number of recurrent layers.
64 | dropout (float): The dropout rate of the encoder. Defaults to 0.
65 | """
66 |
67 | def __init__(self, input_size, hidden_size, num_layers, dropout=0):
68 | super(GRUEncoder, self).__init__(input_size, hidden_size, num_layers,
69 | dropout)
70 |
71 | def _get_rnn(self, input_size, hidden_size, num_layers):
72 | return nn.GRU(input_size, hidden_size, num_layers,
73 | batch_first=True, bidirectional=True)
74 |
75 |
76 | class LSTMEncoder(RNNEncoder):
77 | """Bi-directional LSTM encoder with dropout
78 |
79 | Args:
80 | input_size (int): The number of expected features in the input.
81 | hidden_size (int): The number of features in the hidden state.
82 | num_layers (int): The number of recurrent layers.
83 | dropout (float): The dropout rate of the encoder. Defaults to 0.
84 | """
85 |
86 | def __init__(self, input_size, hidden_size, num_layers, dropout=0):
87 | super(LSTMEncoder, self).__init__(input_size, hidden_size, num_layers,
88 | dropout)
89 |
90 | def _get_rnn(self, input_size, hidden_size, num_layers):
91 | return nn.LSTM(input_size, hidden_size, num_layers,
92 | batch_first=True, bidirectional=True)
93 |
94 |
95 | class CNNEncoder(nn.Module):
96 | """Multi-filter-size CNN encoder for text classification with max-pooling
97 |
98 | Args:
99 | input_size (int): The number of expected features in the input.
100 | filter_sizes (list): Size of convolutional filters.
101 | num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 128.
102 | activation (str): Activation function to be used. Defaults to 'relu'.
103 | dropout (float): The dropout rate of the encoder. Defaults to 0.
104 | num_pool (int): The number of pools for max-pooling.
105 | If num_pool = 0, do nothing.
106 | If num_pool = 1, do typical max-pooling.
107 | If num_pool > 1, do adaptive max-pooling.
108 | channel_last (bool): Whether to transpose the output from (batch_size, num_channel, length) to (batch_size, length, num_channel). Defaults to False.
109 | """
110 |
111 | def __init__(self, input_size, filter_sizes, num_filter_per_size,
112 | activation, dropout=0, num_pool=0, channel_last=False):
113 | super(CNNEncoder, self).__init__()
114 | if not filter_sizes:
115 | raise ValueError(f'CNNEncoder expects non-empty filter_sizes. '
116 | f'Got: {filter_sizes}')
117 | self.channel_last = channel_last
118 | self.convs = nn.ModuleList()
119 | for filter_size in filter_sizes:
120 | conv = nn.Conv1d(
121 | in_channels=input_size,
122 | out_channels=num_filter_per_size,
123 | kernel_size=filter_size)
124 | self.convs.append(conv)
125 | self.num_pool = num_pool
126 | if num_pool > 1:
127 | self.pool = nn.AdaptiveMaxPool1d(num_pool)
128 | self.activation = getattr(torch, activation, None) or getattr(F, activation)  # look in torch first, then torch.nn.functional; avoid eagerly evaluating the fallback
129 | self.dropout = nn.Dropout(dropout)
130 |
131 | def forward(self, input):
132 | h = input.transpose(1, 2) # (batch_size, input_size, length)
133 | h_list = []
134 | for conv in self.convs:
135 | h_sub = conv(h) # (batch_size, num_filter, length)
136 | if self.num_pool == 1:
137 | h_sub = F.max_pool1d(h_sub, h_sub.shape[2]) # (batch_size, num_filter, 1)
138 | elif self.num_pool > 1:
139 | h_sub = self.pool(h_sub) # (batch_size, num_filter, num_pool)
140 | h_list.append(h_sub)
141 | h = torch.cat(h_list, 1) # (batch_size, total_num_filter, *)
142 | if self.channel_last:
143 | h = h.transpose(1, 2) # (batch_size, *, total_num_filter)
144 | h = self.activation(h)
145 | return self.dropout(h)
146 |
147 |
148 | class LabelwiseAttention(nn.Module):
149 | """Applies attention technique to summarize the sequence for each label
150 | See `Explainable Prediction of Medical Codes from Clinical Text <https://aclanthology.org/N18-1100.pdf>`_
151 |
152 | Args:
153 | input_size (int): The number of expected features in the input.
154 | num_classes (int): Total number of classes.
155 | """
156 | def __init__(self, input_size, num_classes):
157 | super(LabelwiseAttention, self).__init__()
158 | self.attention = nn.Linear(input_size, num_classes, bias=False)
159 |
160 | def forward(self, input):
161 | attention = self.attention(input).transpose(1, 2) # (batch_size, num_classes, sequence_length)
162 | attention = F.softmax(attention, -1)
163 | logits = torch.bmm(attention, input) # (batch_size, num_classes, hidden_dim)
164 | return logits, attention
165 |
166 |
167 | class LabelwiseMultiHeadAttention(nn.Module):
168 | """Labelwise multi-head attention
169 |
170 | Args:
171 | input_size (int): The number of expected features in the input.
172 | num_classes (int): Total number of classes.
173 | num_heads (int): The number of parallel attention heads.
174 | attention_dropout (float): The dropout rate for the attention. Defaults to 0.0.
175 | """
176 | def __init__(self, input_size, num_classes, num_heads, attention_dropout=0.0):
177 | super(LabelwiseMultiHeadAttention, self).__init__()
178 | self.attention = nn.MultiheadAttention(embed_dim=input_size, num_heads=num_heads, dropout=attention_dropout)
179 | self.Q = nn.Linear(input_size, num_classes)
180 |
181 | def forward(self, input, attention_mask=None):
182 | key = value = input.permute(1, 0, 2) # (sequence_length, batch_size, hidden_dim)
183 | query = self.Q.weight.repeat(input.size(0), 1, 1).transpose(
184 | 0, 1) # (num_classes, batch_size, hidden_dim)
185 |
186 | logits, attention = self.attention(query, key, value, key_padding_mask=attention_mask)
187 | logits = logits.permute(1, 0, 2) # (batch_size, num_classes, hidden_dim)
188 | return logits, attention
189 |
190 |
191 | class LabelwiseLinearOutput(nn.Module):
192 | """Applies a linear transformation to the incoming data for each label
193 |
194 | Args:
195 | input_size (int): The number of expected features in the input.
196 | num_classes (int): Total number of classes.
197 | """
198 |
199 | def __init__(self, input_size, num_classes):
200 | super(LabelwiseLinearOutput, self).__init__()
201 | self.output = nn.Linear(input_size, num_classes)
202 |
203 | def forward(self, input):
204 | return (self.output.weight * input).sum(dim=-1) + self.output.bias
205 |
--------------------------------------------------------------------------------
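A standalone check of LabelwiseAttention above: each label gets its own softmax over the sequence, so the weights sum to 1 per label. The hidden states are random placeholders:

import torch
from libmultilabel.nn.networks.modules import LabelwiseAttention

attn = LabelwiseAttention(input_size=32, num_classes=5)
h = torch.randn(2, 10, 32)   # (batch, seq_len, hidden)
summary, weights = attn(h)
print(summary.shape)         # torch.Size([2, 5, 32])
print(weights.sum(-1))       # all ones: softmax over the sequence dimension
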
/libmultilabel/nn/networks/xml_cnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .modules import Embedding, CNNEncoder
6 |
7 |
8 | class XMLCNN(nn.Module):
9 | """XML-CNN
10 |
11 | Args:
12 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
13 | num_classes (int): Total number of classes.
14 | embed_dropout (float): The dropout rate of the word embedding. Defaults to 0.2.
15 | hidden_dropout (float): The dropout rate of the hidden layer output. Defaults to 0.
16 | filter_sizes (list): Size of convolutional filters.
17 | hidden_dim (int): Dimension of the hidden layer. Defaults to 512.
18 | num_filter_per_size (int): The number of filters in convolutional layers in each size. Defaults to 256.
19 | num_pool (int): The number of pools for dynamic max-pooling. Defaults to 2.
20 | activation (str): Activation function to be used. Defaults to 'relu'.
21 | """
22 | def __init__(
23 | self,
24 | embed_vecs,
25 | num_classes,
26 | embed_dropout=0.2,
27 | hidden_dropout=0,
28 | filter_sizes=None,
29 | hidden_dim=512,
30 | num_filter_per_size=256,
31 | num_pool=2,
32 | activation='relu'
33 | ):
34 | super(XMLCNN, self).__init__()
35 | self.embedding = Embedding(embed_vecs, embed_dropout)
36 | self.encoder = CNNEncoder(embed_vecs.shape[1], filter_sizes,
37 | num_filter_per_size, activation,
38 | num_pool=num_pool)
39 | total_output_size = len(filter_sizes) * num_filter_per_size * num_pool
40 | self.dropout = nn.Dropout(hidden_dropout)
41 | self.linear1 = nn.Linear(total_output_size, hidden_dim)
42 | self.linear2 = nn.Linear(hidden_dim, num_classes)
43 | self.activation = getattr(torch, activation, None) or getattr(F, activation)  # look in torch first, then torch.nn.functional; avoid eagerly evaluating the fallback
44 |
45 | def forward(self, input):
46 | x = self.embedding(input['text']) # (batch_size, length, embed_dim)
47 | x = self.encoder(x) # (batch_size, total_num_filter, num_pool)
48 | x = x.view(x.shape[0], -1) # (batch_size, total_num_filter * num_pool)
49 | x = self.activation(self.linear1(x))
50 | x = self.dropout(x)
51 | x = self.linear2(x)
52 | return {'logits': x}
53 |
--------------------------------------------------------------------------------
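A sanity check of XMLCNN's flattened feature size above: with 2 filter sizes, 256 filters each, and num_pool=2, the first linear layer sees 2 * 256 * 2 = 1024 features. Sizes are illustrative:

import torch
from libmultilabel.nn.networks.xml_cnn import XMLCNN

embed_vecs = torch.randn(100, 50)
net = XMLCNN(embed_vecs, num_classes=8, filter_sizes=[2, 4], num_pool=2)
print(net.linear1.in_features)  # 1024
print(net({'text': torch.randint(1, 100, (4, 60))})['logits'].shape)  # torch.Size([4, 8])
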
/libmultilabel/nn/nn_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import pytorch_lightning as pl
5 | import torch
6 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
7 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
8 | from pytorch_lightning.utilities.seed import seed_everything
9 |
10 | from ..nn import networks
11 | from ..nn.model import Model
12 |
13 |
14 | def init_device(use_cpu=False):
15 | """Initialize the device: CPU if `use_cpu` is True, otherwise GPU.
16 |
17 | Args:
18 | use_cpu (bool, optional): Whether to use CPU or not. Defaults to False.
19 |
20 | Returns:
21 | torch.device: One of cuda or cpu.
22 | """
23 |
24 | if not use_cpu and torch.cuda.is_available():
25 | # Set a debug environment variable CUBLAS_WORKSPACE_CONFIG to ":16:8" (may limit overall performance) or ":4096:8" (will increase library footprint in GPU memory by approximately 24MiB).
26 | # https://docs.nvidia.com/cuda/cublas/index.html
27 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
28 | device = torch.device('cuda')
29 | else:
30 | device = torch.device('cpu')
31 | # https://github.com/pytorch/pytorch/issues/11201
32 | torch.multiprocessing.set_sharing_strategy('file_system')
33 | logging.info(f'Using device: {device}')
34 | return device
35 |
36 |
37 | def init_model(model_name,
38 | network_config,
39 | classes,
40 | word_dict,
41 | embed_vecs,
42 | init_weight=None,
43 | log_path=None,
44 | learning_rate=0.0001,
45 | optimizer='adam',
46 | momentum=0.9,
47 | weight_decay=0,
48 | metric_threshold=0.5,
49 | monitor_metrics=None,
50 | silent=False,
51 | save_k_predictions=0,
52 | zero=False,
53 | multi_class=False,
54 | enable_ce_loss=False,
55 | hierarchical=False):
56 | """Initialize a `Model` class for initializing and training a neural network.
57 |
58 | Args:
59 | model_name (str): Model to be used such as KimCNN.
60 | network_config (dict): Configuration for defining the network.
61 | classes (list): List of class names.
62 | word_dict (torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
63 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
64 | init_weight (str): Weight initialization method from `torch.nn.init`.
65 | For example, the `init_weight` of `torch.nn.init.kaiming_uniform_`
66 | is `kaiming_uniform`. Defaults to None.
67 | log_path (str): Path to a directory holding the log files and models.
68 | learning_rate (float, optional): Learning rate for optimizer. Defaults to 0.0001.
69 | optimizer (str, optional): Optimizer name (i.e., sgd, adam, or adamw). Defaults to 'adam'.
70 | momentum (float, optional): Momentum factor for SGD only. Defaults to 0.9.
71 | weight_decay (int, optional): Weight decay factor. Defaults to 0.
72 | metric_threshold (float, optional): Threshold to monitor for metrics. Defaults to 0.5.
73 | monitor_metrics (list, optional): Metrics to monitor while validating. Defaults to None.
74 | silent (bool, optional): Enable silent mode. Defaults to False.
75 | save_k_predictions (int, optional): Save top k predictions on test set. Defaults to 0.
76 | zero (bool, optional): Whether the dataset contains instances with no labels (LexGLUE). Defaults to False.
77 | multi_class (bool, optional): Whether to treat the task as multi-class (single-label) classification. Defaults to False.
78 | enable_ce_loss (bool, optional): Whether to use cross-entropy loss instead of binary cross-entropy. Defaults to False.
79 | hierarchical (bool, optional): Whether to use the hierarchical BERT encoder for long documents. Defaults to False.
80 |
81 | Returns:
82 | Model: A class that implements `MultiLabelModel` for initializing and training a neural network.
83 | """
84 |
85 | network = getattr(networks, model_name)(
86 | embed_vecs=embed_vecs,
87 | num_classes=len(classes),
88 | hierarchical=hierarchical,
89 | **dict(network_config)
90 | )
91 |
92 | if init_weight is not None:
93 | init_weight = networks.get_init_weight_func(
94 | init_weight=init_weight)
95 | network.apply(init_weight)
96 |
97 | model = Model(
98 | classes=classes,
99 | word_dict=word_dict,
100 | embed_vecs=embed_vecs,
101 | network=network,
102 | log_path=log_path,
103 | learning_rate=learning_rate,
104 | optimizer=optimizer,
105 | momentum=momentum,
106 | weight_decay=weight_decay,
107 | metric_threshold=metric_threshold,
108 | monitor_metrics=monitor_metrics,
109 | silent=silent,
110 | save_k_predictions=save_k_predictions,
111 | zero=zero,
112 | multi_class=multi_class,
113 | enable_ce_loss=enable_ce_loss
114 | )
115 | return model
116 |
117 |
118 | def init_trainer(checkpoint_dir,
119 | epochs=10000,
120 | patience=5,
121 | mode='max',
122 | val_metric='P@1',
123 | silent=False,
124 | use_cpu=False,
125 | limit_train_batches=1.0,
126 | limit_val_batches=1.0,
127 | limit_test_batches=1.0,
128 | search_params=False,
129 | save_checkpoints=True,
130 | accumulate_grad_batches=1):
131 | """Initialize a torch lightning trainer.
132 |
133 | Args:
134 | checkpoint_dir (str): Directory for saving models and log.
135 | epochs (int): Number of epochs to train. Defaults to 10000.
136 | patience (int): Number of epochs to wait for improvement before early stopping. Defaults to 5.
137 | mode (str): One of [min, max]. Whether val_metric should be minimized or maximized. Defaults to 'max'.
138 | val_metric (str): The metric to monitor for early stopping. Defaults to 'P@1'.
139 | silent (bool): Enable silent mode. Defaults to False.
140 | use_cpu (bool): Disable CUDA. Defaults to False.
141 | limit_train_batches (Union[int, float]): Percentage of training dataset to use. Defaults to 1.0.
142 | limit_val_batches (Union[int, float]): Percentage of validation dataset to use. Defaults to 1.0.
143 | limit_test_batches (Union[int, float]): Percentage of test dataset to use. Defaults to 1.0.
144 | search_params (bool): Enable pytorch-lightning trainer to report the results to ray tune
145 | on validation end during hyperparameter search. Defaults to False.
146 | save_checkpoints (bool): Whether to save the last and the best checkpoint or not. Defaults to True.
147 | accumulate_grad_batches (int): Number of batches whose gradients are accumulated before each optimizer step. Defaults to 1.
148 |
149 | Returns:
150 | pl.Trainer: A torch lightning trainer.
151 | """
152 |
153 | callbacks = []
154 | if save_checkpoints:
155 | callbacks += [ModelCheckpoint(dirpath=checkpoint_dir, filename='best_model',
156 | save_last=True, save_top_k=1,
157 | monitor=val_metric, mode=mode)]
158 | if search_params:
159 | from ray.tune.integration.pytorch_lightning import TuneReportCallback
160 | callbacks += [TuneReportCallback({f'val_{val_metric}': val_metric}, on="validation_end")]
161 |
162 | trainer = pl.Trainer(logger=False, num_sanity_val_steps=0,
163 | gpus=0 if use_cpu else 1,
164 | enable_progress_bar=False if silent else True,
165 | max_epochs=epochs,
166 | callbacks=callbacks,
167 | limit_train_batches=limit_train_batches,
168 | limit_val_batches=limit_val_batches,
169 | limit_test_batches=limit_test_batches,
170 | accumulate_grad_batches=accumulate_grad_batches,
171 | deterministic=True)
172 | return trainer
173 |
174 |
175 | def set_seed(seed):
176 | """Set seeds for numpy and pytorch.
177 |
178 | Args:
179 | seed (int): Random seed.
180 | """
181 |
182 | if seed is not None:
183 | if seed >= 0:
184 | seed_everything(seed=seed, workers=True)
185 | else:
186 | logging.warning('the random seed should be a non-negative integer')
187 |
188 |
189 | from transformers import TrainingArguments
190 | def init_training_args(config):
191 | """Initialize huggingface training arguments.
192 |
193 | Args:
194 | config (AttributeDict): Config of the experiment.
195 |
196 | Returns:
197 | TrainingArguments: Hugging Face training arguments.
198 | """
199 |
200 | training_args = TrainingArguments(
201 | output_dir = \
202 | f"{config.result_dir}/{config.data_name}/{config.network_config['lm_weight']}/seed_{config.seed}",
203 | )
204 | training_args.model_name_or_path = config.network_config['lm_weight']
205 | training_args.do_lower_case = True
206 | training_args.do_train = True
207 | training_args.do_eval = True
208 | training_args.do_predict = True
209 | training_args.overwrite_output_dir = True
210 | training_args.load_best_model_at_end = True
211 | training_args.metric_for_best_model = config.val_metric
212 | training_args.greater_is_better = True
213 | training_args.evaluation_strategy = 'epoch'
214 | training_args.save_strategy = 'epoch'
215 | training_args.save_total_limit = 5
216 | training_args.num_train_epochs = config.epochs
217 | training_args.learning_rate = config.learning_rate
218 | training_args.per_device_train_batch_size = config.batch_size
219 | training_args.per_device_eval_batch_size = config.batch_size
220 | training_args.seed = config.seed
221 | if not config.hierarchical: # fp16 leads to errors with the hierarchical encoder
222 | training_args.fp16 = True
223 | training_args.fp16_full_eval = True
224 | training_args.gradient_accumulation_steps = config.accumulate_grad_batches
225 | training_args.eval_accumulation_steps = config.accumulate_grad_batches
226 | return training_args
227 |
228 |
229 | import numpy as np
230 | from sklearn.metrics import f1_score
231 | from transformers import EvalPrediction
232 | def compute_metrics(p: EvalPrediction):
233 | logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
234 | preds = np.argmax(logits, axis=1)
235 | macro_f1 = f1_score(y_true=p.label_ids, y_pred=preds, average='macro', zero_division=0)
236 | micro_f1 = f1_score(y_true=p.label_ids, y_pred=preds, average='micro', zero_division=0)
237 | return {'Macro-F1': macro_f1, 'Micro-F1': micro_f1}
238 |
--------------------------------------------------------------------------------
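A quick check of compute_metrics above with a hand-made EvalPrediction; the logits and labels are toy values (the second example is misclassified, so Micro-F1, i.e. accuracy here, is 0.5):

import numpy as np
from transformers import EvalPrediction
from libmultilabel.nn.nn_utils import compute_metrics

logits = np.array([[2.0, 0.1, -1.0],
                   [0.0, 1.5, 0.2]])
labels = np.array([0, 2])
print(compute_metrics(EvalPrediction(predictions=logits, label_ids=labels)))
# {'Macro-F1': 0.333..., 'Micro-F1': 0.5}
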
/linear_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from math import ceil
4 |
5 | import numpy as np
6 |
7 | import libmultilabel.linear as linear
8 | from libmultilabel.common_utils import dump_log, argsort_top_k
9 |
10 |
11 | def linear_test(config, model, datasets):
12 | metrics = linear.get_metrics(
13 | config.metric_threshold,
14 | config.monitor_metrics,
15 | datasets['test']['y'].shape[1],
16 | config.zero,
17 | config.multi_class
18 | )
19 | num_instance = datasets['test']['x'].shape[0]
20 |
21 | k = config.save_k_predictions
22 | top_k_idx = np.zeros((num_instance, k), dtype='i')
23 | top_k_scores = np.zeros((num_instance, k), dtype='d')
24 |
25 | for i in range(ceil(num_instance / config.eval_batch_size)):
26 | batch_slice = np.s_[i*config.eval_batch_size:(i+1)*config.eval_batch_size] # renamed to avoid shadowing the builtin `slice`
27 | preds = linear.predict_values(model, datasets['test']['x'][batch_slice])
28 | target = datasets['test']['y'][batch_slice].toarray()
29 | metrics.update(preds, target)
30 |
31 | if k > 0:
32 | top_k_idx[batch_slice] = argsort_top_k(preds, k, axis=1)
33 | top_k_scores[batch_slice] = np.take_along_axis(
34 | preds, top_k_idx[batch_slice], axis=1)
35 |
36 | metric_dict = metrics.compute()
37 | return (metric_dict, top_k_idx, top_k_scores)
38 |
39 |
40 | def linear_train(datasets, config):
41 | techniques = {'1vsrest': linear.train_1vsrest,
42 | 'thresholding': linear.train_thresholding,
43 | 'cost_sensitive': linear.train_cost_sensitive,
44 | 'cost_sensitive_micro': linear.train_cost_sensitive_micro}
45 | model = techniques[config.linear_technique](
46 | datasets['train']['y'],
47 | datasets['train']['x'],
48 | config.liblinear_options,
49 | )
50 | return model
51 |
52 |
53 | def linear_run(config):
54 | if config.seed is not None:
55 | np.random.seed(config.seed)
56 |
57 | if config.eval:
58 | preprocessor, model = linear.load_pipeline(config.checkpoint_path)
59 | datasets = preprocessor.load_data(
60 | config.train_path, config.test_path, config.eval)
61 | else:
62 | preprocessor = linear.Preprocessor(data_format=config.data_format)
63 | datasets = preprocessor.load_data(
64 | config.train_path,
65 | config.test_path,
66 | config.eval,
67 | config.label_file,
68 | config.include_test_labels,
69 | config.remove_no_label_data)
70 | model = linear_train(datasets, config)
71 | linear.save_pipeline(config.checkpoint_dir, preprocessor, model)
72 |
73 | if os.path.exists(config.test_path):
74 | metric_dict, top_k_idx, top_k_scores = linear_test(
75 | config, model, datasets)
76 |
77 | dump_log(config=config, metrics=metric_dict,
78 | split='test', log_path=config.log_path)
79 | print(linear.tabulate_metrics(metric_dict, 'test'))
80 |
81 | if config.save_k_predictions > 0:
82 | classes = preprocessor.binarizer.classes_
83 | with open(config.predict_out_path, 'w') as fp:
84 | for idx, score in zip(top_k_idx, top_k_scores):
85 | out_str = ' '.join([f'{classes[i]}:{s:.4}' for i, s in zip(
86 | idx, score)])
87 | fp.write(out_str+'\n')
88 | logging.info(f'Saved predictions to: {config.predict_out_path}')
89 |
--------------------------------------------------------------------------------
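argsort_top_k comes from libmultilabel.common_utils, which is not shown in this dump; a plausible numpy equivalent (top-k column indices per row, ordered by descending score) is sketched below under that assumption:

import numpy as np

def argsort_top_k(x, k, axis=1):
    # unsorted top-k via argpartition, then order those k entries by descending value
    top_k = np.argpartition(x, -k, axis=axis)[:, -k:]
    order = np.argsort(np.take_along_axis(x, top_k, axis=axis), axis=axis)[:, ::-1]
    return np.take_along_axis(top_k, order, axis=axis)

scores = np.array([[0.1, 0.9, 0.4], [0.8, 0.2, 0.7]])
print(argsort_top_k(scores, k=2))  # [[1 2] [0 2]]
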
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | from datetime import datetime
5 | from pathlib import Path
6 |
7 | import yaml
8 |
9 | from libmultilabel.common_utils import Timer, AttributeDict
10 |
11 |
12 | def add_all_arguments(parser):
13 | # path / directory
14 | parser.add_argument('--data_dir', default='./data/rcv1',
15 | help='The directory to load data (default: %(default)s)')
16 | parser.add_argument('--result_dir', default='./runs',
17 | help='The directory to save checkpoints and logs (default: %(default)s)')
18 |
19 | # data
20 | parser.add_argument('--data_name', default='rcv1',
21 | help='Dataset name (default: %(default)s)')
22 | parser.add_argument('--train_path',
23 | help='Path to training data (default: [data_dir]/train.txt)')
24 | parser.add_argument('--val_path',
25 | help='Path to validation data (default: [data_dir]/valid.txt)')
26 | parser.add_argument('--test_path',
27 | help='Path to test data (default: [data_dir]/test.txt)')
28 | parser.add_argument('--val_size', type=float, default=0.2,
29 | help='Training-validation split: a ratio in [0, 1] or an integer for the size of the validation set (default: %(default)s).')
30 | parser.add_argument('--min_vocab_freq', type=int, default=1,
31 | help='The minimum frequency needed to include a token in the vocabulary (default: %(default)s)')
32 | parser.add_argument('--max_seq_length', type=int, default=500,
33 | help='The maximum number of tokens of a sample (default: %(default)s)')
34 | parser.add_argument('--lm_weight', type=str,
35 | help='Pretrained model name or path (default: %(default)s)')
36 | parser.add_argument('--shuffle', type=lambda s: str(s).lower() in ('true', 't', '1', 'yes'), default=True,  # type=bool would treat any non-empty string as True
37 | help='Whether to shuffle training data before each epoch (default: %(default)s)')
38 | parser.add_argument('--merge_train_val', action='store_true',
39 | help='Whether to merge the training and validation data. (default: %(default)s)')
40 | parser.add_argument('--include_test_labels', action='store_true',
41 | help='Whether to include labels in the test dataset. (default: %(default)s)')
42 | parser.add_argument('--remove_no_label_data', action='store_true',
43 | help='Whether to remove training and validation instances that have no labels.')
44 |
45 | # train
46 | parser.add_argument('--seed', type=int,
47 | help='Random seed (default: %(default)s)')
48 | parser.add_argument('--epochs', type=int, default=10000,
49 | help='The number of epochs to train (default: %(default)s)')
50 | parser.add_argument('--batch_size', type=int, default=16,
51 | help='Size of training batches (default: %(default)s)')
52 | parser.add_argument('--optimizer', default='adam', choices=['adam', 'adamw', 'adamax', 'sgd'],
53 | help='Optimizer (default: %(default)s)')
54 | parser.add_argument('--learning_rate', type=float, default=0.0001,
55 | help='Learning rate for optimizer (default: %(default)s)')
56 | parser.add_argument('--weight_decay', type=float, default=0,
57 | help='Weight decay factor (default: %(default)s)')
58 | parser.add_argument('--momentum', type=float, default=0.9,
59 | help='Momentum factor for SGD only (default: %(default)s)')
60 | parser.add_argument('--patience', type=int, default=5,
61 | help='The number of epochs to wait for improvement before early stopping (default: %(default)s)')
62 | parser.add_argument('--normalize_embed', action='store_true',
63 | help='Whether the embedding of each word is normalized to a unit vector (default: %(default)s)')
64 |
65 | # model
66 | parser.add_argument('--model_name', default='KimCNN',
67 | help='Model to be used (default: %(default)s)')
68 | parser.add_argument('--init_weight', default='kaiming_uniform',
69 | help='Weight initialization to be used (default: %(default)s)')
70 |
71 | # eval
72 | parser.add_argument('--eval_batch_size', type=int, default=256,
73 | help='Size of evaluating batches (default: %(default)s)')
74 | parser.add_argument('--metric_threshold', type=float, default=0.5,
75 | help='Threshold to monitor for metrics (default: %(default)s)')
76 | parser.add_argument('--monitor_metrics', nargs='+', default=['P@1', 'P@3', 'P@5'],
77 | help='Metrics to monitor while validating (default: %(default)s)')
78 | parser.add_argument('--val_metric', default='P@1',
79 | help='The metric to monitor for early stopping (default: %(default)s)')
80 |
81 | # pretrained vocab / embeddings
82 | parser.add_argument('--vocab_file', type=str,
83 | help='Path to a file holding vocabularies (default: %(default)s)')
84 | parser.add_argument('--embed_file', type=str,
85 | help='Path to a file holding pre-trained embeddings (default: %(default)s)')
86 | parser.add_argument('--label_file', type=str,
87 | help='Path to a file holding all labels (default: %(default)s)')
88 |
89 | # log
90 | parser.add_argument('--save_k_predictions', type=int, nargs='?', const=100, default=0,
91 | help='Save top k predictions on test set. k=%(const)s if not specified. (default: %(default)s)')
92 | parser.add_argument('--predict_out_path',
93 | help='Path to an output file holding the top k label results (default: %(default)s)')
94 |
95 | # auto-test
96 | parser.add_argument('--limit_train_batches', type=float, default=1.0,
97 | help='Percentage of train dataset to use for auto-testing (default: %(default)s)')
98 | parser.add_argument('--limit_val_batches', type=float, default=1.0,
99 | help='Percentage of validation dataset to use for auto-testing (default: %(default)s)')
100 | parser.add_argument('--limit_test_batches', type=float, default=1.0,
101 | help='Percentage of test dataset to use for auto-testing (default: %(default)s)')
102 |
103 | # others
104 | parser.add_argument('--cpu', action='store_true',
105 | help='Disable CUDA')
106 | parser.add_argument('--silent', action='store_true',
107 | help='Enable silent mode')
108 | parser.add_argument('--data_workers', type=int, default=4,
109 | help='Number of CPU workers for data pre-processing (default: %(default)s)')
110 | parser.add_argument('--embed_cache_dir', type=str,
111 | help='For parameter search only: path to a directory for storing embeddings for multiple runs. (default: %(default)s)')
112 | parser.add_argument('--eval', action='store_true',
113 | help='Only run evaluation on the test set (default: %(default)s)')
114 | parser.add_argument('--checkpoint_path',
115 | help='The checkpoint to warm-up with (default: %(default)s)')
116 |
117 | # linear options
118 | parser.add_argument('--linear', action='store_true',
119 | help='Train linear model')
120 | parser.add_argument('--data_format', type=str, default='txt',
121 | help='\'svm\' for SVM format or \'txt\' for LibMultiLabel format (default: %(default)s)')
122 | parser.add_argument('--liblinear_options', type=str,
123 | help='Options passed to liblinear (default: %(default)s)')
124 | parser.add_argument('--linear_technique', type=str, default='1vsrest',
125 | choices=['1vsrest', 'thresholding', 'cost_sensitive', 'cost_sensitive_micro'],
126 | help='Technique for linear classification (default: %(default)s)')
127 |
128 | parser.add_argument('-h', '--help', action='help',
129 | help="If you are trying to specify network config such as dropout or activation, use a yaml file instead. "
130 | "See example configs in example_config")
131 |
132 | # LexGLUE
133 | parser.add_argument('--zero', action='store_true')
134 | parser.add_argument('--multi_class', action='store_true')
135 | parser.add_argument('--add_special_tokens', action='store_true')
136 | parser.add_argument('--enable_ce_loss', action='store_true')
137 | parser.add_argument('--hierarchical', action='store_true')
138 | parser.add_argument('--accumulate_grad_batches', type=int, default=1)
139 | parser.add_argument('--enable_transformer_trainer', action='store_true')
140 |
141 |
142 | def get_config():
143 | parser = argparse.ArgumentParser(
144 | add_help=False,
145 | description='multi-label learning for text classification')
146 |
147 | # load params from config file
148 | parser.add_argument('-c', '--config', help='Path to configuration file')
149 | args, _ = parser.parse_known_args()
150 | config = {}
151 | if args.config:
152 | with open(args.config) as fp:
153 | config = yaml.load(fp, Loader=yaml.SafeLoader)
154 |
155 | add_all_arguments(parser)
156 |
157 | parser.set_defaults(**config)
158 | args = parser.parse_args()
159 | config = AttributeDict(vars(args))
160 |
161 | config.run_name = '{}_{}_{}'.format(
162 | config.data_name,
163 | Path(config.config).stem if config.config else config.model_name,
164 | datetime.now().strftime('%Y%m%d%H%M%S'),
165 | )
166 | config.checkpoint_dir = os.path.join(config.result_dir, config.run_name)
167 | config.log_path = os.path.join(config.checkpoint_dir, 'logs.json')
168 | config.predict_out_path = config.predict_out_path or os.path.join(
169 | config.checkpoint_dir, 'predictions.txt')
170 |
171 | config.train_path = config.train_path or os.path.join(
172 | config.data_dir, 'train.txt')
173 | config.val_path = config.val_path or os.path.join(
174 | config.data_dir, 'valid.txt')
175 | config.test_path = config.test_path or os.path.join(
176 | config.data_dir, 'test.txt')
177 |
178 | return config
179 |
180 |
181 | def check_config(config):
182 | """Check if the configuration has invalid arguments.
183 |
184 | Args:
185 | config (AttributeDict): Config of the experiment from `get_args`.
186 | """
187 | if config.model_name == 'XMLCNN' and config.seed is not None:
188 | raise ValueError("nn.AdaptiveMaxPool1d doesn't have a deterministic implementation but seed is "
189 | "specified. Please do not specify seed.")
190 |
191 | if config.eval and not os.path.exists(config.test_path):
192 | raise ValueError('--eval is specified but there is no test data set')
193 |
194 |
195 | def main():
196 | # Get config
197 | config = get_config()
198 | check_config(config)
199 |
200 | # Set up logger
201 | log_level = logging.WARNING if config.silent else logging.INFO
202 | logging.basicConfig(
203 | level=log_level, format='%(asctime)s %(levelname)s:%(message)s')
204 |
205 | logging.info(f'Run name: {config.run_name}')
206 |
207 | if config.linear:
208 | from linear_trainer import linear_run
209 | linear_run(config)
210 | else:
211 | from torch_trainer import TorchTrainer
212 | trainer = TorchTrainer(config) # initialize trainer
213 | # train
214 | if not config.eval:
215 | trainer.train()
216 | # test
217 | if 'test' in trainer.datasets:
218 | trainer.test()
219 |
220 | return config
221 |
222 |
223 | def dump_time_to_log(log_path, time):
224 | """Write elapsed wall time to the log file.
225 | Args:
226 | log_path (str): Path to the log file.
227 | time (int): Elapsed time in seconds.
228 | """
229 | assert os.path.isfile(log_path)
230 | with open(log_path) as fp:
231 | import json
232 | result = json.load(fp)
233 |
234 | if time >= 60 * 60:
235 | h, m = divmod(time // 60, 60) # integer minutes, so 7170s gives '1h 59m', not '1h 60m'
236 | formatted_time = f'{h}h {m}m'
237 | elif time >= 60:
238 | m, s = divmod(time, 60)
239 | formatted_time = f'{m}m {s}s'
240 | else:
241 | formatted_time = f'{time}s'
242 | result['time'] = formatted_time
243 |
244 | with open(log_path, 'w') as fp:
245 | json.dump(result, fp)
246 |
247 | print(f'Wall time: {formatted_time}')
248 |
249 | if __name__ == '__main__':
250 | wall_time = Timer()
251 | config = main()
252 | dump_time_to_log(config.log_path, round(wall_time.time()))
253 | # print(f'Wall time: {wall_time.time():.2f} (s)')
254 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets==2.18.0 # or any version > 2.15.0
2 | nltk==3.7
3 | pandas==1.5.0
4 | PyYAML==6.0
5 | scikit-learn==1.1.2
6 | torch==1.12.0
7 | torchmetrics==0.9.2
8 | torchtext==0.13.0
9 | pytorch-lightning==1.6.5
10 | tqdm==4.64.1
11 | liblinear-multicore==2.45.1
12 | numba==0.56.3
13 | scipy==1.9.2
14 | transformers==4.23.0
15 |
--------------------------------------------------------------------------------
/requirements_parameter_search.txt:
--------------------------------------------------------------------------------
1 | bayesian-optimization==1.2.0 # ray[tune]
2 | optuna==2.10.1 # ray[tune]
3 | ray==1.13.0
4 | ray[tune]==1.13.0
5 | grpcio==1.43.0 # Fix issue: https://github.com/ray-project/ray/issues/22518
6 |
--------------------------------------------------------------------------------
/run_experiments.sh:
--------------------------------------------------------------------------------
1 | data_list=(ecthr_a ecthr_b scotus eurlex ledgar unfair_tos)
2 | algo_list=(1vsrest thresholding cost_sensitive bert_default bert_tuned bert_reproduced)
3 |
4 | data=$1
5 | algo=$2
6 |
7 | if [[ ! " ${data_list[*]} " =~ " ${data} " ]]; then
8 | echo "Invalid argument! Data ${data} is not in (${data_list[*]})."
9 | exit 1
10 | fi
11 |
12 | if [[ ! " ${algo_list[*]} " =~ " ${algo} " ]]; then
13 | echo "Invalid argument! Algorithm ${algo} is not in (${algo_list[*]})."
14 | exit 1
15 | fi
16 |
17 | linear_algo_list=(1vsrest thresholding cost_sensitive)
18 | bert_algo_list=(bert_default bert_tuned bert_reproduced)
19 |
20 | multilabel_unlabeled_data_list=(ecthr_a ecthr_b unfair_tos)
21 | multilabel_labeled_data_list=(eurlex)
22 | multiclass_labeled_data_list=(scotus ledgar)
23 |
24 | if [[ " ${linear_algo_list[*]} " =~ " ${algo} " ]]; then
25 | if [[ " ${multilabel_unlabeled_data_list[*]} " =~ " ${data} " ]]; then
26 | python3 main.py --config config/${data}/l2svm.yml --linear_technique ${algo} --zero
27 | elif [[ " ${multilabel_labeled_data_list[*]} " =~ " ${data} " ]]; then
28 | python3 main.py --config config/${data}/l2svm.yml --linear_technique ${algo}
29 | elif [[ " ${multiclass_labeled_data_list[*]} " =~ " ${data} " ]]; then
30 | python3 main.py --config config/${data}/l2svm.yml --linear_technique ${algo} --multi_class
31 | else
32 | echo "Should never reach here..."
33 | exit 1
34 | fi
35 | elif [[ " ${bert_algo_list[*]} " =~ " ${algo} " ]]; then
36 | if [[ " ${multilabel_unlabeled_data_list[*]} " =~ " ${data} " ]]; then
37 | python3 main.py --config config/${data}/${algo}.yml --zero --seed 1
38 | elif [[ " ${multilabel_labeled_data_list[*]} " =~ " ${data} " ]]; then
39 | python3 main.py --config config/${data}/${algo}.yml --seed 1
40 | elif [[ " ${multiclass_labeled_data_list[*]} " =~ " ${data} " ]]; then
41 | huggingface_trainer_algo_list=(bert_reproduced)
42 | if [[ ! " ${huggingface_trainer_algo_list[*]} " =~ " ${algo} " ]]; then
43 | python3 main.py --config config/${data}/${algo}.yml --multi_class --enable_ce_loss --seed 1
44 | else
45 | python3 main.py --config config/${data}/${algo}.yml --multi_class --enable_ce_loss --seed 1 --enable_transformer_trainer
46 | fi
47 | else
48 | echo "Should never reach here..."
49 | exit 1
50 | fi
51 | fi
52 |
--------------------------------------------------------------------------------
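
Usage sketch for the dispatch script above (both arguments come straight from its whitelists): `bash run_experiments.sh eurlex thresholding` takes the linear branch and runs `config/eurlex/l2svm.yml` with `--linear_technique thresholding`, while `bash run_experiments.sh scotus bert_reproduced` takes the multi-class BERT branch that adds `--multi_class --enable_ce_loss --enable_transformer_trainer`.
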
/search_params.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import logging
4 | import os
5 | import time
6 | from datetime import datetime
7 | from pathlib import Path
8 |
9 | import yaml
10 | from ray import tune
11 | from ray.tune.schedulers import ASHAScheduler
12 |
13 | import numpy as np
14 |
15 | from libmultilabel.nn import data_utils
16 | from libmultilabel.nn.nn_utils import set_seed
17 | from libmultilabel.common_utils import AttributeDict, Timer
18 | from torch_trainer import TorchTrainer
19 |
20 | logging.basicConfig(level=logging.INFO,
21 | format='%(asctime)s %(levelname)s:%(message)s')
22 |
23 |
24 | def train_libmultilabel_tune(config, datasets, classes, word_dict):
25 | """The training function for ray tune.
26 |
27 | Args:
28 | config (AttributeDict): Config of the experiment.
29 | datasets (dict): A dictionary of datasets.
30 | classes(list): List of class names.
31 | word_dict(torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
32 | """
33 | set_seed(seed=config.seed)
34 | config.run_name = tune.get_trial_dir()
35 | logging.info(f'Run name: {config.run_name}')
36 | config.checkpoint_dir = os.path.join(config.result_dir, config.run_name)
37 | config.log_path = os.path.join(config.checkpoint_dir, 'logs.json')
38 |
39 | trainer = TorchTrainer(config=config,
40 | datasets=datasets,
41 | classes=classes,
42 | word_dict=word_dict,
43 | search_params=True,
44 | save_checkpoints=False)
45 | trainer.train()
46 |
47 |
48 | def load_config_from_file(config_path):
49 | """Initialize the model config.
50 |
51 | Args:
52 | config_path (str): Path to the config file.
53 |
54 | Returns:
55 | AttributeDict: Config of the experiment.
56 | """
57 | with open(config_path) as fp:
58 | config = yaml.safe_load(fp)
59 |
60 | # create directories that hold the shared data
61 | os.makedirs(config['result_dir'], exist_ok=True)
62 | if config['embed_cache_dir']:
63 | os.makedirs(config['embed_cache_dir'], exist_ok=True)
64 |
65 | # set relative path to absolute path (_path, _file, _dir)
66 | for k, v in config.items():
67 | if isinstance(v, str) and os.path.exists(v):
68 | config[k] = os.path.abspath(v)
69 |
70 | # find `train.txt`, `val.txt`, and `test.txt` from the data directory if not specified.
71 | config['train_path'] = config['train_path'] or os.path.join(config['data_dir'], 'train.txt')
72 | config['val_path'] = config['val_path'] or os.path.join(config['data_dir'], 'valid.txt')
73 | config['test_path'] = config['test_path'] or os.path.join(config['data_dir'], 'test.txt')
74 |
75 | return config
76 |
77 |
78 | def init_search_params_spaces(config, parameter_columns, prefix):
79 | """Initialize the sample space defined in ray tune.
80 | See the random distributions API listed here: https://docs.ray.io/en/master/tune/api_docs/search_space.html#random-distributions-api
81 |
82 | Args:
83 | config (dict): Config of the experiment.
84 | parameter_columns (dict): Names of parameters to include in the CLIReporter.
85 | The keys are parameter names and the values are displayed names.
86 | prefix(str): The prefix of a nested parameter such as network_config/dropout.
87 |
88 | Returns:
89 | dict: Config with parsed sample spaces.
90 | """
91 | search_spaces = ['choice', 'grid_search', 'uniform', 'quniform', 'loguniform',
92 | 'qloguniform', 'randn', 'qrandn', 'randint', 'qrandint']
93 | for key, value in config.items():
94 | if isinstance(value, list) and len(value) >= 2 and value[0] in search_spaces:
95 | search_space, search_args = value[0], value[1:]
96 | if isinstance(search_args[0], list) and any(isinstance(x, list) for x in search_args[0]) and search_space != 'grid_search':
97 | raise ValueError(
98 | """If the search values are lists, the search space must be `grid_search`.
99 | Take `filter_sizes: ['grid_search', [[2,4,8], [4,6]]]` for example, the program will grid search over
100 | [2,4,8] and [4,6]. This is the same as assigning `filter_sizes` to either [2,4,8] or [4,6] in two runs.
101 | """)
102 | else:
103 | config[key] = getattr(tune, search_space)(*search_args)
104 | parameter_columns[prefix+key] = key
105 | elif isinstance(value, dict):
106 | config[key] = init_search_params_spaces(value, parameter_columns, f'{prefix}{key}/')
107 |
108 | return config
109 |
110 |
111 | def init_search_algorithm(search_alg, metric=None, mode=None):
112 | """Initialize a search algorithm; the chosen algorithm's package must be installed first (e.g., via pip).
113 | If no search algorithm is specified, the default search algorithm is BasicVariantGenerator.
114 | See more details here: https://docs.ray.io/en/master/tune/api_docs/suggestion.html
115 |
116 | Args:
117 | search_alg (str): One of 'basic_variant', 'bayesopt', or 'optuna'.
118 | metric (str): The metric to monitor for early stopping.
119 | mode (str): One of 'min' or 'max' to determine whether to minimize or maximize the metric.
120 | """
121 | if search_alg == 'optuna':
122 | assert metric and mode, "Metric and mode cannot be None for optuna."
123 | from ray.tune.suggest.optuna import OptunaSearch
124 | return OptunaSearch(metric=metric, mode=mode)
125 | elif search_alg == 'bayesopt':
126 | assert metric and mode, "Metric and mode cannot be None for bayesian optimization."
127 | from ray.tune.suggest.bayesopt import BayesOptSearch
128 | return BayesOptSearch(metric=metric, mode=mode)
129 | logging.info(f'Search algorithm {search_alg} is unspecified or unsupported; using the default BasicVariantGenerator().')
130 |
131 |
132 | def prepare_retrain_config(best_config, best_log_dir, merge_train_val):
133 | """Prepare the configuration for re-training.
134 |
135 | Args:
136 | best_config (AttributeDict): The best hyper-parameter configuration.
137 | best_log_dir (str): The directory of the best trial of the experiment.
138 | merge_train_val (bool): Whether to merge the training and validation data.
139 | """
140 | if merge_train_val:
141 | best_config.merge_train_val = True
142 |
143 | log_path = os.path.join(best_log_dir, 'logs.json')
144 | if os.path.isfile(log_path):
145 | with open(log_path) as fp:
146 | log = json.load(fp)
147 | else:
148 | raise FileNotFoundError(f"The log directory {best_log_dir} does not contain `logs.json`.")
149 |
150 | # For re-training with validation data,
151 | # we use the number of epochs at the point of optimal validation performance.
152 | log_metric = np.array([l[best_config.val_metric] for l in log['val']])
153 | optimal_idx = log_metric.argmax() if best_config.mode == 'max' else log_metric.argmin()
154 | best_config.epochs = optimal_idx.item() + 1 # +1 because epoch numbers are 1-indexed
155 | else:
156 | best_config.merge_train_val = False
157 |
158 |
159 | def load_static_data(config, merge_train_val=False):
160 | """Preload static data once for multiple trials.
161 |
162 | Args:
163 | config (AttributeDict): Config of the experiment.
164 | merge_train_val (bool, optional): Whether to merge the training and validation data.
165 | Defaults to False.
166 |
167 | Returns:
168 | dict: A dict of static data containing datasets, classes, and word_dict.
169 | """
170 | datasets = data_utils.load_datasets(train_path=config.train_path,
171 | test_path=config.test_path,
172 | val_path=config.val_path,
173 | val_size=config.val_size,
174 | merge_train_val=merge_train_val,
175 | tokenize_text='lm_weight' not in config['network_config'],
176 | remove_no_label_data=config.remove_no_label_data
177 | )
178 | return {
179 | "datasets": datasets,
180 | "word_dict": None if config.embed_file is None else data_utils.load_or_build_text_dict(
181 | dataset=datasets['train'],
182 | vocab_file=config.vocab_file,
183 | min_vocab_freq=config.min_vocab_freq,
184 | embed_file=config.embed_file,
185 | embed_cache_dir=config.embed_cache_dir,
186 | silent=config.silent,
187 | normalize_embed=config.normalize_embed
188 | ),
189 | "classes": data_utils.load_or_build_label(datasets, config.label_file, config.include_test_labels)
190 | }
191 |
192 |
193 | def retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val):
194 | """Re-train the model with the best hyper-parameters (note: the `quit()` below currently skips the re-training).
195 | A new model is trained on the combined training and validation data if `merge_train_val` is True.
196 | If a test set is provided, it will be evaluated by the obtained model.
197 |
198 | Args:
199 | exp_name (str): The directory to save trials generated by ray tune.
200 | best_config (AttributeDict): The best hyper-parameter configuration.
201 | best_log_dir (str): The directory of the best trial of the experiment.
202 | merge_train_val (bool): Whether to merge the training and validation data.
203 | """
204 | best_config.silent = False
205 | checkpoint_dir = os.path.join(best_config.result_dir, exp_name, 'trial_best_params')
206 | os.makedirs(checkpoint_dir, exist_ok=True)
207 | with open(os.path.join(checkpoint_dir, 'params.yml'), 'w') as fp:
208 | yaml.dump(dict(best_config), fp)
209 | quit() # stop after saving the best parameters; re-training is not needed in these experiments
210 | best_config.run_name = '_'.join(exp_name.split('_')[:-1]) + '_best'
211 | best_config.checkpoint_dir = checkpoint_dir
212 | best_config.log_path = os.path.join(best_config.checkpoint_dir, 'logs.json')
213 | prepare_retrain_config(best_config, best_log_dir, merge_train_val)
214 | set_seed(seed=best_config.seed)
215 |
216 | data = load_static_data(best_config, merge_train_val=best_config.merge_train_val)
217 | logging.info(f'Re-training with best config: \n{best_config}')
218 | trainer = TorchTrainer(config=best_config, **data)
219 | trainer.train()
220 |
221 | if 'test' in data['datasets']:
222 | test_results = trainer.test()
223 | logging.info(f'Test results after re-training: {test_results}')
224 | logging.info(f'Best model saved to {trainer.checkpoint_callback.best_model_path or trainer.checkpoint_callback.last_model_path}.')
225 |
226 |
227 | def main():
228 | parser = argparse.ArgumentParser()
229 | parser.add_argument(
230 | '--config', help='Path to configuration file (default: %(default)s). Please specify a config with all arguments in LibMultiLabel/main.py::get_config.')
231 | parser.add_argument('--cpu_count', type=int, default=4,
232 | help='Number of CPU per trial (default: %(default)s)')
233 | parser.add_argument('--gpu_count', type=int, default=1,
234 | help='Number of GPU per trial (default: %(default)s)')
235 | parser.add_argument('--num_samples', type=int, default=50,
236 | help='Number of running trials. If the search space is `grid_search`, the same grid will be repeated `num_samples` times. (default: %(default)s)')
237 | parser.add_argument('--mode', default='max', choices=['min', 'max'],
238 | help='Determines whether objective is minimizing or maximizing the metric attribute. (default: %(default)s)')
239 | parser.add_argument('--search_alg', default=None, choices=['basic_variant', 'bayesopt', 'optuna'],
240 | help='Search algorithms (default: %(default)s)')
241 | parser.add_argument('--no_merge_train_val', action='store_true',
242 | help='Do not add the validation set in re-training the final model after hyper-parameter search.')
243 | args, _ = parser.parse_known_args()
244 |
245 | # Load config from the config file and overwrite values specified in CLI.
246 | parameter_columns = dict() # parameters to include in progress table of CLIReporter
247 | config = load_config_from_file(args.config)
248 | config = init_search_params_spaces(config, parameter_columns, prefix='')
249 | parser.set_defaults(**config)
250 | config = AttributeDict(vars(parser.parse_args()))
251 | config.merge_train_val = False # no need to include validation during parameter search
252 |
253 | # Check if the validation set is provided.
254 | val_path = config.val_path or os.path.join(config.data_dir, 'valid.txt')
255 | assert config.val_size > 0 or os.path.exists(val_path), \
256 | "You should specify either a positive `val_size` or a `val_path` defaults to `data_dir/valid.txt` for parameter search."
257 |
258 | """Run tune analysis.
259 | - If no search algorithm is specified, the default search algorithm is BasicVariantGenerator.
260 | https://docs.ray.io/en/master/tune/api_docs/suggestion.html#tune-basicvariant
261 | - Arguments without search spaces will be ignored by `tune.run`
262 | (https://github.com/ray-project/ray/blob/34d3d9294c50aea4005b7367404f6a5d9e0c2698/python/ray/tune/suggest/variant_generator.py#L333),
263 | so we parse the whole config to `tune.run` here for simplicity.
264 | """
265 | data = load_static_data(config)
266 | reporter = tune.CLIReporter(metric_columns=[f'val_{metric}' for metric in config.monitor_metrics],
267 | parameter_columns=parameter_columns,
268 | metric=f'val_{config.val_metric}',
269 | mode=args.mode,
270 | sort_by_metric=True)
271 | if config.scheduler is not None:
272 | scheduler = ASHAScheduler(metric=f'val_{config.val_metric}',
273 | mode=args.mode,
274 | **config.scheduler)
275 | else:
276 | scheduler = None
277 |
278 | exp_name = '{}_{}_{}'.format(
279 | config.data_name,
280 | Path(config.config).stem if config.config else config.model_name,
281 | datetime.now().strftime('%Y%m%d%H%M%S'),
282 | )
283 | analysis = tune.run(
284 | tune.with_parameters(
285 | train_libmultilabel_tune,
286 | **data),
287 | search_alg=init_search_algorithm(
288 | config.search_alg, metric=config.val_metric, mode=args.mode),
289 | scheduler=scheduler,
290 | local_dir=config.result_dir,
291 | num_samples=config.num_samples,
292 | resources_per_trial={
293 | 'cpu': args.cpu_count, 'gpu': args.gpu_count},
294 | progress_reporter=reporter,
295 | config=config,
296 | name=exp_name,
297 | )
298 |
299 | # Save best model after parameter search.
300 | best_config = analysis.get_best_config(f'val_{config.val_metric}', args.mode, scope='all')
301 | best_log_dir = analysis.get_best_logdir(f'val_{config.val_metric}', args.mode, scope='all')
302 | retrain_best_model(exp_name, best_config, best_log_dir, merge_train_val=not args.no_merge_train_val)
303 |
304 |
305 | if __name__ == '__main__':
306 | # calculate wall time.
307 | wall_time = Timer()
308 | main()
309 | print(f'Wall time: {wall_time.time():.2f} (s)')
310 |
--------------------------------------------------------------------------------
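
To make the list-to-sampler convention concrete, here is a minimal sketch (with made-up hyper-parameter values) of what `init_search_params_spaces` above does to a loaded config:

    # Sketch: list-encoded entries from a *_tune.yml config become ray tune samplers.
    from search_params import init_search_params_spaces

    config = {
        'learning_rate': ['loguniform', 1e-5, 1e-3],  # -> tune.loguniform(1e-5, 1e-3)
        'dropout': ['choice', [0.2, 0.4, 0.6]],       # -> tune.choice([0.2, 0.4, 0.6])
        'network_config': {                           # nested dicts are walked recursively
            'filter_sizes': ['grid_search', [[2, 4, 8], [4, 6]]],
        },
    }
    parameter_columns = {}
    config = init_search_params_spaces(config, parameter_columns, prefix='')
    # parameter_columns is now {'learning_rate': 'learning_rate', 'dropout': 'dropout',
    #                           'network_config/filter_sizes': 'filter_sizes'}
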
/search_params.sh:
--------------------------------------------------------------------------------
1 | data_list=(ecthr_a ecthr_b scotus eurlex ledgar unfair_tos)
2 |
3 | data=$1
4 |
5 | if [[ ! " ${data_list[*]} " =~ " ${data} " ]]; then
6 | echo "Invalid argument! Data ${data} is not in (${data_list[*]})."
7 | exit
8 | fi
9 |
10 | python3 search_params.py --config config/${data}/bert_tune.yml
11 |
--------------------------------------------------------------------------------
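
As with `run_experiments.sh`, the dataset name is the only argument here, so for example `bash search_params.sh ledgar` launches the search with `config/ledgar/bert_tune.yml`.
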
/torch_trainer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import numpy as np
5 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
6 | from transformers import AutoTokenizer
7 |
8 | from libmultilabel.nn import data_utils
9 | from libmultilabel.nn.model import Model
10 | from libmultilabel.nn.nn_utils import init_device, init_model, init_trainer
11 | from libmultilabel.common_utils import dump_log
12 |
13 |
14 | class TorchTrainer:
15 | """A wrapper for training neural network models with pytorch lightning trainer.
16 |
17 | Args:
18 | config (AttributeDict): Config of the experiment.
19 | datasets (dict, optional): Datasets for training, validation, and test. Defaults to None.
20 | classes(list, optional): List of class names.
21 | word_dict(torchtext.vocab.Vocab, optional): A vocab object which maps tokens to indices.
22 | embed_vecs (torch.Tensor, optional): The pre-trained word vectors of shape (vocab_size, embed_dim).
23 | search_params (bool, optional): Enable pytorch-lightning trainer to report the results to ray tune
24 | on validation end during hyperparameter search. Defaults to False.
25 | save_checkpoints (bool, optional): Whether to save the last and the best checkpoint or not.
26 | Defaults to True.
27 | """
28 | def __init__(
29 | self,
30 | config: dict,
31 | datasets: dict = None,
32 | classes: list = None,
33 | word_dict: dict = None,
34 | embed_vecs = None,
35 | search_params: bool = False,
36 | save_checkpoints: bool = True
37 | ):
38 | self.run_name = config.run_name
39 | self.checkpoint_dir = config.checkpoint_dir
40 | self.log_path = config.log_path
41 | os.makedirs(self.checkpoint_dir, exist_ok=True)
42 |
43 | # Set up seed & device
44 | if not config.enable_transformer_trainer:
45 | from libmultilabel.nn.nn_utils import set_seed
46 | set_seed(seed=config.seed)
47 | self.device = init_device(use_cpu=config.cpu)
48 | self.config = config
49 |
50 | # Load pretrained tokenizer for dataset loader
51 | self.tokenizer = None
52 | tokenize_text = 'lm_weight' not in config.network_config
53 | if not tokenize_text:
54 | self.tokenizer = AutoTokenizer.from_pretrained(config.network_config['lm_weight'], use_fast=True)
55 | # Load dataset
56 | if datasets is None:
57 | self.datasets = data_utils.load_datasets(
58 | train_path=config.train_path,
59 | test_path=config.test_path,
60 | val_path=config.val_path,
61 | val_size=config.val_size,
62 | merge_train_val=config.merge_train_val,
63 | tokenize_text=tokenize_text,
64 | remove_no_label_data=config.remove_no_label_data
65 | )
66 | else:
67 | self.datasets = datasets
68 |
69 | self._setup_model(classes=classes,
70 | word_dict=word_dict,
71 | embed_vecs=embed_vecs,
72 | log_path=self.log_path,
73 | checkpoint_path=config.checkpoint_path)
74 | if config.enable_transformer_trainer:
75 | from transformers import EarlyStoppingCallback, set_seed, Trainer
76 |
77 | from libmultilabel.nn.data_utils import generate_transformer_batch
78 | from libmultilabel.nn.nn_utils import init_training_args, compute_metrics
79 |
80 | training_args = init_training_args(config)
81 |
82 | set_seed(training_args.seed)
83 |
84 | self.train_dataset = self._get_dataset_loader(split='train', shuffle=config.shuffle).dataset
85 | self.val_dataset = self._get_dataset_loader(split='val').dataset
86 | self.test_dataset = self._get_dataset_loader(split='test').dataset
87 | self.trainer = Trainer(
88 | model=self.model.network.lm,
89 | args=training_args,
90 | train_dataset=self.train_dataset,
91 | eval_dataset=self.val_dataset,
92 | compute_metrics=compute_metrics,
93 | tokenizer=self.tokenizer,
94 | data_collator=generate_transformer_batch,
95 | callbacks=[EarlyStoppingCallback(early_stopping_patience=config.patience)]
96 | )
97 | else:
98 | self.trainer = init_trainer(checkpoint_dir=self.checkpoint_dir,
99 | epochs=config.epochs,
100 | patience=config.patience,
101 | val_metric=config.val_metric,
102 | silent=config.silent,
103 | use_cpu=config.cpu,
104 | limit_train_batches=config.limit_train_batches,
105 | limit_val_batches=config.limit_val_batches,
106 | limit_test_batches=config.limit_test_batches,
107 | search_params=search_params,
108 | save_checkpoints=save_checkpoints,
109 | accumulate_grad_batches=config.accumulate_grad_batches)
110 | callbacks = [callback for callback in self.trainer.callbacks if isinstance(callback, ModelCheckpoint)]
111 | self.checkpoint_callback = callbacks[0] if callbacks else None
112 |
113 | # Dump config to log
114 | dump_log(self.log_path, config=config)
115 |
116 | def _setup_model(
117 | self,
118 | classes: list = None,
119 | word_dict: dict = None,
120 | embed_vecs = None,
121 | log_path: str = None,
122 | checkpoint_path: str = None
123 | ):
124 | """Setup model from checkpoint if a checkpoint path is passed in or specified in the config.
125 | Otherwise, initialize model from scratch.
126 |
127 | Args:
128 | classes(list): List of class names.
129 | word_dict(torchtext.vocab.Vocab): A vocab object which maps tokens to indices.
130 | embed_vecs (torch.Tensor): The pre-trained word vectors of shape (vocab_size, embed_dim).
131 | log_path (str): Path to the log file. The log file contains the validation
132 | results for each epoch and the test results. If the `log_path` is None, no performance
133 | results will be logged.
134 | checkpoint_path (str): The checkpoint to warm-up with.
135 | """
136 | if 'checkpoint_path' in self.config and self.config.checkpoint_path is not None:
137 | checkpoint_path = self.config.checkpoint_path
138 |
139 | if checkpoint_path is not None:
140 | logging.info(f'Loading model from `{checkpoint_path}`...')
141 | self.model = Model.load_from_checkpoint(checkpoint_path)
142 | else:
143 | logging.info('Initializing model from scratch.')
144 | if self.config.embed_file is not None:
145 | logging.info('Loading word dictionary...')
146 | word_dict, embed_vecs = data_utils.load_or_build_text_dict(
147 | dataset=self.datasets['train'],
148 | vocab_file=self.config.vocab_file,
149 | min_vocab_freq=self.config.min_vocab_freq,
150 | embed_file=self.config.embed_file,
151 | silent=self.config.silent,
152 | normalize_embed=self.config.normalize_embed,
153 | embed_cache_dir=self.config.embed_cache_dir
154 | )
155 | if not classes:
156 | classes = data_utils.load_or_build_label(
157 | self.datasets, self.config.label_file, self.config.include_test_labels)
158 |
159 | if self.config.val_metric not in self.config.monitor_metrics:
160 | logging.warning(
161 | f'{self.config.val_metric} is not in `monitor_metrics`. Add {self.config.val_metric} to `monitor_metrics`.')
162 | self.config.monitor_metrics += [self.config.val_metric]
163 |
164 | self.model = init_model(model_name=self.config.model_name,
165 | network_config=dict(self.config.network_config),
166 | classes=classes,
167 | word_dict=word_dict,
168 | embed_vecs=embed_vecs,
169 | init_weight=self.config.init_weight,
170 | log_path=log_path,
171 | learning_rate=self.config.learning_rate,
172 | optimizer=self.config.optimizer,
173 | momentum=self.config.momentum,
174 | weight_decay=self.config.weight_decay,
175 | metric_threshold=self.config.metric_threshold,
176 | monitor_metrics=self.config.monitor_metrics,
177 | silent=self.config.silent,
178 | save_k_predictions=self.config.save_k_predictions,
179 | zero=self.config.zero,
180 | multi_class=self.config.multi_class,
181 | enable_ce_loss=self.config.enable_ce_loss,
182 | hierarchical=self.config.hierarchical
183 | )
184 |
185 | def _get_dataset_loader(self, split, shuffle=False):
186 | """Get dataset loader.
187 |
188 | Args:
189 | split (str): One of 'train', 'test', or 'val'.
190 | shuffle (bool): Whether to shuffle training data before each epoch. Defaults to False.
191 |
192 | Returns:
193 | torch.utils.data.DataLoader: Dataloader for the train, test, or valid dataset.
194 | """
195 | return data_utils.get_dataset_loader(
196 | data=self.datasets[split],
197 | word_dict=self.model.word_dict,
198 | classes=self.model.classes,
199 | device=self.device,
200 | max_seq_length=self.config.max_seq_length,
201 | batch_size=self.config.batch_size if split == 'train' else self.config.eval_batch_size,
202 | shuffle=shuffle,
203 | data_workers=self.config.data_workers,
204 | tokenizer=self.tokenizer,
205 | add_special_tokens=self.config.add_special_tokens,
206 | hierarchical=self.config.hierarchical,
207 | enable_transformer_trainer=self.config.enable_transformer_trainer,
208 | multi_class=self.config.multi_class
209 | )
210 |
211 | def train(self):
212 | """Train model with pytorch lightning trainer. Set model to the best model after the training
213 | process is finished.
214 | """
215 | assert self.trainer is not None, "Please make sure the trainer is successfully initialized in `TorchTrainer.__init__`."
216 | if self.config.enable_transformer_trainer:
217 | # Training
218 | train_result = self.trainer.train()
219 | metrics = train_result.metrics
220 | metrics['train_samples'] = len(self.train_dataset)
221 | self.trainer.save_model()
222 | self.trainer.log_metrics('train', metrics)
223 | self.trainer.save_metrics('train', metrics)
224 | self.trainer.save_state()
225 | return
226 |
227 | train_loader = self._get_dataset_loader(split='train', shuffle=self.config.shuffle)
228 |
229 | if 'val' not in self.datasets:
230 | logging.info('No validation dataset is provided. Train without validation.')
231 | self.trainer.fit(self.model, train_loader)
232 | else:
233 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping
234 | self.trainer.callbacks += [EarlyStopping(patience=self.config.patience,
235 | monitor=self.config.val_metric,
236 | mode='max')] # tentatively hard-coded to 'max'
237 | val_loader = self._get_dataset_loader(split='val')
238 | self.trainer.fit(self.model, train_loader, val_loader)
239 |
240 | # Set model to the best model. If the validation process is skipped during
241 | # training (i.e., val_size=0), the model is set to the last model.
242 | model_path = self.checkpoint_callback.best_model_path or self.checkpoint_callback.last_model_path
243 | if model_path:
244 | logging.info(f'Finished training. Load best model from {model_path}.')
245 | self._setup_model(checkpoint_path=model_path)
246 | else:
247 | logging.info('No model is saved during training. '
248 | 'If you want to save the best and the last model, please set `save_checkpoints` to True.')
249 |
250 | def test(self, split='test'):
251 | """Test model with pytorch lightning trainer. Top-k predictions are saved
252 | if `save_k_predictions` > 0.
253 |
254 | Args:
255 | split (str, optional): One of 'train', 'test', or 'val'. Defaults to 'test'.
256 |
257 | Returns:
258 | dict: Scores for all metrics in the dictionary format.
259 | """
260 | assert 'test' in self.datasets and self.trainer is not None
261 |
262 | if self.config.enable_transformer_trainer:
263 | # Validation
264 | metrics = self.trainer.evaluate(eval_dataset=self.val_dataset)
265 | metrics['val_samples'] = len(self.val_dataset)
266 | self.trainer.log_metrics('val', metrics)
267 | self.trainer.save_metrics('val', metrics)
268 | # Testing
269 | predictions, labels, metrics = self.trainer.predict(self.test_dataset, metric_key_prefix='test')
270 | metrics['test_samples'] = len(self.test_dataset)
271 | self.trainer.log_metrics('test', metrics)
272 | self.trainer.save_metrics('test', metrics)
273 | return
274 |
275 | logging.info(f'Testing on {split} set.')
276 | test_loader = self._get_dataset_loader(split=split)
277 | metric_dict = self.trainer.test(self.model, dataloaders=test_loader)[0]
278 |
279 | if self.config.save_k_predictions > 0:
280 | self._save_predictions(test_loader, self.config.predict_out_path)
281 |
282 | return metric_dict
283 |
284 | def _save_predictions(self, dataloader, predict_out_path):
285 | """Save top k label results.
286 |
287 | Args:
288 | dataloader (torch.utils.data.DataLoader): Dataloader for the test or valid dataset.
289 | predict_out_path (str): Path to an output file holding the top k label results.
290 | """
291 | batch_predictions = self.trainer.predict(self.model, dataloaders=dataloader)
292 | pred_labels = np.vstack([batch['top_k_pred']
293 | for batch in batch_predictions])
294 | pred_scores = np.vstack([batch['top_k_pred_scores']
295 | for batch in batch_predictions])
296 | with open(predict_out_path, 'w') as fp:
297 | for pred_label, pred_score in zip(pred_labels, pred_scores):
298 | out_str = ' '.join([f'{self.model.classes[label]}:{score:.4}' for label, score in zip(
299 | pred_label, pred_score)])
300 | fp.write(out_str+'\n')
301 | logging.info(f'Saved predictions to: {predict_out_path}')
302 |
--------------------------------------------------------------------------------
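
For reference, the way `main.py` above drives this class is essentially the following condensed sketch (assuming `config` is the parsed AttributeDict from `get_config`):

    from torch_trainer import TorchTrainer

    trainer = TorchTrainer(config)      # loads datasets, model, and the lightning/transformers trainer
    if not config.eval:
        trainer.train()                 # fits, then reloads the best (or last) checkpoint
    if 'test' in trainer.datasets:
        metric_dict = trainer.test()    # dict of scores; saves top-k predictions if configured
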