├── .gitignore ├── README.md ├── configure ├── datasets │ ├── Amazon-3M.yaml │ ├── Amazon-670K.yaml │ ├── AmazonCat-13K.yaml │ ├── EUR-Lex.yaml │ ├── Wiki-500K.yaml │ └── Wiki10-31K.yaml └── models │ ├── AttentionXML-AmazonCat-13K.yaml │ ├── AttentionXML-EUR-Lex.yaml │ ├── AttentionXML-Wiki10-31K.yaml │ ├── FastAttentionXML-Amazon-3M.yaml │ ├── FastAttentionXML-Amazon-670K.yaml │ └── FastAttentionXML-Wiki-500K.yaml ├── data └── README.md ├── deepxml ├── __init__.py ├── cluster.py ├── data_utils.py ├── dataset.py ├── evaluation.py ├── models.py ├── modules.py ├── networks.py ├── optimizers.py └── tree.py ├── ensemble.py ├── evaluation.py ├── main.py ├── preprocess.py ├── requirements.txt └── scripts ├── run_amazon.sh ├── run_amazon3m.sh ├── run_amazoncat.sh ├── run_eurlex.sh ├── run_preprocess.sh ├── run_wiki.sh ├── run_wiki10.sh └── run_xml.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### JetBrains template 3 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 4 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 5 | 6 | # User-specific stuff 7 | .idea/**/workspace.xml 8 | .idea/**/tasks.xml 9 | .idea/**/usage.statistics.xml 10 | .idea/**/dictionaries 11 | .idea/**/shelf 12 | 13 | # Sensitive or high-churn files 14 | .idea/**/dataSources/ 15 | .idea/**/dataSources.ids 16 | .idea/**/dataSources.local.xml 17 | .idea/**/sqlDataSources.xml 18 | .idea/**/dynamic.xml 19 | .idea/**/uiDesigner.xml 20 | .idea/**/dbnavigator.xml 21 | 22 | # Gradle 23 | .idea/**/gradle.xml 24 | .idea/**/libraries 25 | 26 | # Gradle and Maven with auto-import 27 | # When using Gradle or Maven with auto-import, you should exclude module files, 28 | # since they will be recreated, and may cause churn. Uncomment if using 29 | # auto-import. 
30 | # .idea/modules.xml 31 | # .idea/*.iml 32 | # .idea/modules 33 | 34 | # CMake 35 | cmake-build-*/ 36 | 37 | # Mongo Explorer plugin 38 | .idea/**/mongoSettings.xml 39 | 40 | # File-based project format 41 | *.iws 42 | 43 | # IntelliJ 44 | out/ 45 | 46 | # mpeltonen/sbt-idea plugin 47 | .idea_modules/ 48 | 49 | # JIRA plugin 50 | atlassian-ide-plugin.xml 51 | 52 | # Cursive Clojure plugin 53 | .idea/replstate.xml 54 | 55 | # Crashlytics plugin (for Android Studio and IntelliJ) 56 | com_crashlytics_export_strings.xml 57 | crashlytics.properties 58 | crashlytics-build.properties 59 | fabric.properties 60 | 61 | # Editor-based Rest Client 62 | .idea/httpRequests 63 | ### Windows template 64 | # Windows thumbnail cache files 65 | Thumbs.db 66 | ehthumbs.db 67 | ehthumbs_vista.db 68 | 69 | # Dump file 70 | *.stackdump 71 | 72 | # Folder config file 73 | [Dd]esktop.ini 74 | 75 | # Recycle Bin used on file shares 76 | $RECYCLE.BIN/ 77 | 78 | # Windows Installer files 79 | *.cab 80 | *.msi 81 | *.msix 82 | *.msm 83 | *.msp 84 | 85 | # Windows shortcuts 86 | *.lnk 87 | ### Python template 88 | # Byte-compiled / optimized / DLL files 89 | __pycache__/ 90 | *.py[cod] 91 | *$py.class 92 | 93 | # C extensions 94 | *.so 95 | 96 | # Distribution / packaging 97 | .Python 98 | build/ 99 | develop-eggs/ 100 | dist/ 101 | downloads/ 102 | eggs/ 103 | .eggs/ 104 | lib/ 105 | lib64/ 106 | parts/ 107 | sdist/ 108 | var/ 109 | wheels/ 110 | *.egg-info/ 111 | .installed.cfg 112 | *.egg 113 | MANIFEST 114 | 115 | # PyInstaller 116 | # Usually these files are written by a python script from a template 117 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 118 | *.manifest 119 | *.spec 120 | 121 | # Installer logs 122 | pip-log.txt 123 | pip-delete-this-directory.txt 124 | 125 | # Unit test / coverage reports 126 | htmlcov/ 127 | .tox/ 128 | .coverage 129 | .coverage.* 130 | .cache 131 | nosetests.xml 132 | coverage.xml 133 | *.cover 134 | .hypothesis/ 135 | .pytest_cache/ 136 | 137 | # Translations 138 | *.mo 139 | *.pot 140 | 141 | # Django stuff: 142 | *.log 143 | local_settings.py 144 | db.sqlite3 145 | 146 | # Flask stuff: 147 | instance/ 148 | .webassets-cache 149 | 150 | # Scrapy stuff: 151 | .scrapy 152 | 153 | # Sphinx documentation 154 | docs/_build/ 155 | 156 | # PyBuilder 157 | target/ 158 | 159 | # Jupyter Notebook 160 | .ipynb_checkpoints 161 | 162 | # pyenv 163 | .python-version 164 | 165 | # celery beat schedule file 166 | celerybeat-schedule 167 | 168 | # SageMath parsed files 169 | *.sage.py 170 | 171 | # Environments 172 | .env 173 | .venv 174 | env/ 175 | venv/ 176 | ENV/ 177 | env.bak/ 178 | venv.bak/ 179 | 180 | # Spyder project settings 181 | .spyderproject 182 | .spyproject 183 | 184 | # Rope project settings 185 | .ropeproject 186 | 187 | # mkdocs documentation 188 | /site 189 | 190 | # mypy 191 | .mypy_cache/ 192 | 193 | .idea/ 194 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AttentionXML 2 | [AttentionXML: Label Tree-based Attention-Aware Deep Model for High-Performance Extreme Multi-Label Text Classification](https://arxiv.org/abs/1811.01727) 3 | 4 | ## Requirements 5 | 6 | * python==3.7.4 7 | * click==7.0 8 | * ruamel.yaml==0.16.5 9 | * numpy==1.16.2 10 | * scipy==1.3.1 11 | * scikit-learn==0.21.2 12 | * gensim==3.4.0 13 | * torch==1.0.1 14 | * nltk==3.4 15 | * tqdm==4.31.1 16 | * joblib==0.13.2 17 | * logzero==1.5.0 18 
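These pinned versions can be installed in one step, assuming the requirements.txt shipped with the repository mirrors the list above:
```bash
pip install -r requirements.txt
```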
| 19 | ## Datasets 20 | 21 | * [EUR-Lex](https://drive.google.com/open?id=1iPGbr5-z2LogtMFG1rwwekV_aTubvAb2) 22 | * [Wiki10-31K](https://drive.google.com/open?id=1Tv4MHQzDWTUC9hRFihRhG8_jt1h0VhnR) 23 | * [AmazonCat-13K](https://drive.google.com/open?id=1VwHAbri6y6oh8lkpZ6sSY_b1FRNnCLFL) 24 | * [Amazon-670K](https://drive.google.com/open?id=1Xd4BPFy1RPmE7MEXMu77E2_xWOhR1pHW) 25 | * [Wiki-500K](https://drive.google.com/open?id=1bGEcCagh8zaDV0ZNGsgF0QtwjcAm0Afk) 26 | * [Amazon-3M](https://drive.google.com/open?id=187vt5vAkGI2mS2WOMZ2Qv48YKSjNbQv4) 27 | 28 | Download the GloVe embedding (840B, 300d) and convert it to gensim format (which can be loaded by **gensim.models.KeyedVectors.load**). 29 | 30 | We also provide a converted GloVe embedding [here](https://drive.google.com/file/d/10w_HuLklGc8GA_FtUSdnHT8Yo1mxYziP/view?usp=sharing). 31 | 32 | ## XML Experiments 33 | 34 | The XML experiments in the paper can be run directly, for example: 35 | ```bash 36 | ./scripts/run_eurlex.sh 37 | ``` 38 | ## Preprocess 39 | 40 | Run preprocess.py on the train and test datasets with already tokenized texts as follows: 41 | ```bash 42 | python preprocess.py \ 43 | --text-path data/EUR-Lex/train_texts.txt \ 44 | --label-path data/EUR-Lex/train_labels.txt \ 45 | --vocab-path data/EUR-Lex/vocab.npy \ 46 | --emb-path data/EUR-Lex/emb_init.npy \ 47 | --w2v-model data/glove.840B.300d.gensim 48 | 49 | python preprocess.py \ 50 | --text-path data/EUR-Lex/test_texts.txt \ 51 | --label-path data/EUR-Lex/test_labels.txt \ 52 | --vocab-path data/EUR-Lex/vocab.npy 53 | ``` 54 | 55 | Or run preprocess.py on raw texts, which are first tokenized by NLTK, as follows: 56 | ```bash 57 | python preprocess.py \ 58 | --text-path data/Wiki10-31K/train_raw_texts.txt \ 59 | --tokenized-path data/Wiki10-31K/train_texts.txt \ 60 | --label-path data/Wiki10-31K/train_labels.txt \ 61 | --vocab-path data/Wiki10-31K/vocab.npy \ 62 | --emb-path data/Wiki10-31K/emb_init.npy \ 63 | --w2v-model data/glove.840B.300d.gensim 64 | 65 | python preprocess.py \ 66 | --text-path data/Wiki10-31K/test_raw_texts.txt \ 67 | --tokenized-path data/Wiki10-31K/test_texts.txt \ 68 | --label-path data/Wiki10-31K/test_labels.txt \ 69 | --vocab-path data/Wiki10-31K/vocab.npy 70 | ``` 71 | 72 | 73 | ## Train and Predict 74 | 75 | Train and predict as follows: 76 | ```bash 77 | python main.py --data-cnf configure/datasets/EUR-Lex.yaml --model-cnf configure/models/AttentionXML-EUR-Lex.yaml 78 | ``` 79 | 80 | Or run prediction only with the option "--mode eval".
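For example, a prediction-only run could look like the following sketch, reusing the same configuration files as the training command above:
```bash
python main.py --data-cnf configure/datasets/EUR-Lex.yaml --model-cnf configure/models/AttentionXML-EUR-Lex.yaml --mode eval
```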
81 | 82 | ## Ensemble 83 | 84 | Train and predict with an ensemble: 85 | ```bash 86 | python main.py --data-cnf configure/datasets/Wiki-500K.yaml --model-cnf configure/models/FastAttentionXML-Wiki-500K.yaml -t 0 87 | python main.py --data-cnf configure/datasets/Wiki-500K.yaml --model-cnf configure/models/FastAttentionXML-Wiki-500K.yaml -t 1 88 | python main.py --data-cnf configure/datasets/Wiki-500K.yaml --model-cnf configure/models/FastAttentionXML-Wiki-500K.yaml -t 2 89 | python ensemble.py -p results/FastAttentionXML-Wiki-500K -t 3 90 | ``` 91 | 92 | ## Evaluation 93 | 94 | ```bash 95 | python evaluation.py --results results/AttentionXML-EUR-Lex-labels.npy --targets data/EUR-Lex/test_labels.npy 96 | ``` 97 | Or compute propensity-scored metrics as well: 98 | 99 | ```bash 100 | python evaluation.py \ 101 | --results results/FastAttentionXML-Amazon-670K-labels.npy \ 102 | --targets data/Amazon-670K/test_labels.npy \ 103 | --train-labels data/Amazon-670K/train_labels.npy \ 104 | -a 0.6 \ 105 | -b 2.6 106 | 107 | ``` 108 | 109 | ## Reference 110 | You et al., [AttentionXML: Label Tree-based Attention-Aware Deep Model for High-Performance Extreme Multi-Label Text Classification](https://arxiv.org/abs/1811.01727), NeurIPS 2019 111 | 112 | ## Declaration 113 | It is free for non-commercial use. For commercial use, please contact Mr. Ronghui You and Prof. Shanfeng Zhu (zhusf@fudan.edu.cn). -------------------------------------------------------------------------------- /configure/datasets/Amazon-3M.yaml: -------------------------------------------------------------------------------- 1 | name: Amazon-3M 2 | 3 | train: 4 | sparse: data/Amazon-3M/train_v1.txt 5 | texts: data/Amazon-3M/train_texts.npy 6 | labels: data/Amazon-3M/train_labels.npy 7 | 8 | valid: 9 | size: 4000 10 | 11 | test: 12 | texts: data/Amazon-3M/test_texts.npy 13 | 14 | embedding: 15 | emb_init: data/Amazon-3M/emb_init.npy 16 | 17 | output: 18 | res: results 19 | 20 | labels_binarizer: data/Amazon-3M/labels_binarizer 21 | 22 | model: 23 | emb_size: 300 24 | -------------------------------------------------------------------------------- /configure/datasets/Amazon-670K.yaml: -------------------------------------------------------------------------------- 1 | name: Amazon-670K 2 | 3 | train: 4 | sparse: data/Amazon-670K/train_v1.txt 5 | texts: data/Amazon-670K/train_texts.npy 6 | labels: data/Amazon-670K/train_labels.npy 7 | 8 | valid: 9 | size: 4000 10 | 11 | test: 12 | texts: data/Amazon-670K/test_texts.npy 13 | 14 | embedding: 15 | emb_init: data/Amazon-670K/emb_init.npy 16 | 17 | output: 18 | res: results 19 | 20 | labels_binarizer: data/Amazon-670K/labels_binarizer 21 | 22 | model: 23 | emb_size: 300 24 | -------------------------------------------------------------------------------- /configure/datasets/AmazonCat-13K.yaml: -------------------------------------------------------------------------------- 1 | name: AmazonCat-13K 2 | 3 | train: 4 | texts: data/AmazonCat-13K/train_texts.npy 5 | labels: data/AmazonCat-13K/train_labels.npy 6 | 7 | valid: 8 | size: 4000 9 | 10 | test: 11 | texts: data/AmazonCat-13K/test_texts.npy 12 | 13 | embedding: 14 | emb_init: data/AmazonCat-13K/emb_init.npy 15 | 16 | output: 17 | res: results 18 | 19 | labels_binarizer: data/AmazonCat-13K/labels_binarizer 20 | 21 | model: 22 | emb_size: 300 23 | -------------------------------------------------------------------------------- /configure/datasets/EUR-Lex.yaml: -------------------------------------------------------------------------------- 1 |
name: EUR-Lex 2 | 3 | train: 4 | texts: data/EUR-Lex/train_texts.npy 5 | labels: data/EUR-Lex/train_labels.npy 6 | 7 | valid: 8 | size: 200 9 | 10 | test: 11 | texts: data/EUR-Lex/test_texts.npy 12 | 13 | embedding: 14 | emb_init: data/EUR-Lex/emb_init.npy 15 | 16 | output: 17 | res: results 18 | 19 | labels_binarizer: data/EUR-Lex/labels_binarizer 20 | 21 | model: 22 | emb_size: 300 23 | -------------------------------------------------------------------------------- /configure/datasets/Wiki-500K.yaml: -------------------------------------------------------------------------------- 1 | name: Wiki-500K 2 | 3 | train: 4 | sparse: data/Wiki-500K/train.txt 5 | texts: data/Wiki-500K/train_texts.npy 6 | labels: data/Wiki-500K/train_labels.npy 7 | 8 | valid: 9 | size: 4000 10 | 11 | test: 12 | texts: data/Wiki-500K/test_texts.npy 13 | 14 | embedding: 15 | emb_init: data/Wiki-500K/emb_init.npy 16 | 17 | output: 18 | res: results 19 | 20 | labels_binarizer: data/Wiki-500K/labels_binarizer 21 | 22 | model: 23 | emb_size: 300 24 | -------------------------------------------------------------------------------- /configure/datasets/Wiki10-31K.yaml: -------------------------------------------------------------------------------- 1 | name: Wiki10-31K 2 | 3 | train: 4 | texts: data/Wiki10-31K/train_texts.npy 5 | labels: data/Wiki10-31K/train_labels.npy 6 | 7 | valid: 8 | size: 200 9 | 10 | test: 11 | texts: data/Wiki10-31K/test_texts.npy 12 | 13 | embedding: 14 | emb_init: data/Wiki10-31K/emb_init.npy 15 | 16 | output: 17 | res: results 18 | 19 | labels_binarizer: data/Wiki10-31K/labels_binarizer 20 | 21 | model: 22 | emb_size: 300 23 | -------------------------------------------------------------------------------- /configure/models/AttentionXML-AmazonCat-13K.yaml: -------------------------------------------------------------------------------- 1 | name: AttentionXML 2 | 3 | model: 4 | hidden_size: 512 5 | layers_num: 1 6 | linear_size: [512, 256] 7 | dropout: 0.5 8 | 9 | train: 10 | batch_size: 200 11 | nb_epoch: 10 12 | swa_warmup: 2 13 | 14 | valid: 15 | batch_size: 200 16 | 17 | predict: 18 | batch_size: 200 19 | 20 | path: models 21 | -------------------------------------------------------------------------------- /configure/models/AttentionXML-EUR-Lex.yaml: -------------------------------------------------------------------------------- 1 | name: AttentionXML 2 | 3 | model: 4 | hidden_size: 256 5 | layers_num: 1 6 | linear_size: [256] 7 | dropout: 0.5 8 | emb_trainable: False 9 | 10 | train: 11 | batch_size: 40 12 | nb_epoch: 30 13 | swa_warmup: 10 14 | 15 | valid: 16 | batch_size: 40 17 | 18 | predict: 19 | batch_size: 40 20 | 21 | path: models 22 | -------------------------------------------------------------------------------- /configure/models/AttentionXML-Wiki10-31K.yaml: -------------------------------------------------------------------------------- 1 | name: AttentionXML 2 | 3 | model: 4 | hidden_size: 256 5 | layers_num: 1 6 | linear_size: [256] 7 | dropout: 0.5 8 | emb_trainable: False 9 | 10 | train: 11 | batch_size: 40 12 | nb_epoch: 30 13 | swa_warmup: 4 14 | 15 | valid: 16 | batch_size: 40 17 | 18 | predict: 19 | batch_size: 40 20 | 21 | path: models 22 | -------------------------------------------------------------------------------- /configure/models/FastAttentionXML-Amazon-3M.yaml: -------------------------------------------------------------------------------- 1 | name: FastAttentionXML 2 | 3 | level: 4 4 | k: 8 5 | top: 160 6 | 7 | model: 8 | hidden_size: 512 9 | layers_num: 1 
10 | linear_size: [512, 256] 11 | dropout: 0.5 12 | 13 | cluster: 14 | max_leaf: 8 15 | eps: 1e-4 16 | levels: [13, 16, 19] 17 | 18 | 19 | train: 20 | [{batch_size: 200, nb_epoch: 5, swa_warmup: 2}, 21 | {batch_size: 200, nb_epoch: 5, swa_warmup: 1}, 22 | {batch_size: 200, nb_epoch: 5, swa_warmup: 1}, 23 | {batch_size: 200, nb_epoch: 5, swa_warmup: 1}] 24 | 25 | valid: 26 | batch_size: 200 27 | 28 | predict: 29 | batch_size: 200 30 | 31 | path: models 32 | -------------------------------------------------------------------------------- /configure/models/FastAttentionXML-Amazon-670K.yaml: -------------------------------------------------------------------------------- 1 | name: FastAttentionXML 2 | 3 | level: 4 4 | k: 8 5 | top: 160 6 | 7 | model: 8 | hidden_size: 512 9 | layers_num: 1 10 | linear_size: [512, 256] 11 | dropout: 0.5 12 | 13 | cluster: 14 | max_leaf: 8 15 | eps: 1e-4 16 | levels: [11, 14, 17] 17 | 18 | train: 19 | [{batch_size: 200, nb_epoch: 10, swa_warmup: 6}, 20 | {batch_size: 200, nb_epoch: 10, swa_warmup: 2}, 21 | {batch_size: 200, nb_epoch: 10, swa_warmup: 2}, 22 | {batch_size: 200, nb_epoch: 10, swa_warmup: 2}] 23 | 24 | valid: 25 | batch_size: 200 26 | 27 | predict: 28 | batch_size: 200 29 | 30 | path: models 31 | -------------------------------------------------------------------------------- /configure/models/FastAttentionXML-Wiki-500K.yaml: -------------------------------------------------------------------------------- 1 | name: FastAttentionXML 2 | 3 | level: 2 4 | k: 64 5 | top: 15 6 | 7 | model: 8 | hidden_size: 512 9 | layers_num: 1 10 | linear_size: [512, 256] 11 | dropout: 0.5 12 | 13 | cluster: 14 | max_leaf: 64 15 | eps: 1e-4 16 | levels: [13] 17 | 18 | train: 19 | [{batch_size: 200, nb_epoch: 5, swa_warmup: 2}, 20 | {batch_size: 200, nb_epoch: 5, swa_warmup: 1}] 21 | 22 | valid: 23 | batch_size: 200 24 | 25 | predict: 26 | batch_size: 200 27 | 28 | path: models 29 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ## Datasets 2 | 3 | * [EUR-Lex](https://drive.google.com/open?id=1iPGbr5-z2LogtMFG1rwwekV_aTubvAb2) 4 | * [Wiki10-31K](https://drive.google.com/open?id=1Tv4MHQzDWTUC9hRFihRhG8_jt1h0VhnR) 5 | * [AmazonCat-13K](https://drive.google.com/open?id=1VwHAbri6y6oh8lkpZ6sSY_b1FRNnCLFL) 6 | * [Amazon-670K](https://drive.google.com/open?id=1Xd4BPFy1RPmE7MEXMu77E2_xWOhR1pHW) 7 | * [Wiki-500K](https://drive.google.com/open?id=1bGEcCagh8zaDV0ZNGsgF0QtwjcAm0Afk) 8 | * [Amazon-3M](https://drive.google.com/open?id=187vt5vAkGI2mS2WOMZ2Qv48YKSjNbQv4) 9 | 10 | 11 | Download the GloVe embedding (840B,300d) and convert it to gensim format (which can be loaded by **gensim.models.KeyedVectors.load**). 12 | 13 | We also provide a converted GloVe embedding at [here](https://drive.google.com/file/d/10w_HuLklGc8GA_FtUSdnHT8Yo1mxYziP/view?usp=sharing). 
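A minimal conversion sketch (assuming gensim 3.x and the raw glove.840B.300d.txt file; paths are illustrative, and the large 840B file may need extra encoding options depending on your gensim version):
```python
# Sketch: convert raw GloVe vectors into a gensim KeyedVectors file
# that can later be read back with gensim.models.KeyedVectors.load.
from gensim.scripts.glove2word2vec import glove2word2vec
from gensim.models import KeyedVectors

glove2word2vec('glove.840B.300d.txt', 'glove.840B.300d.w2v.txt')  # prepend the word2vec-style header
kv = KeyedVectors.load_word2vec_format('glove.840B.300d.w2v.txt')
kv.save('data/glove.840B.300d.gensim')  # loadable later with KeyedVectors.load
```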
-------------------------------------------------------------------------------- /deepxml/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2018/10/17 5 | @author yrh 6 | 7 | """ -------------------------------------------------------------------------------- /deepxml/cluster.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2018/12/24 5 | @author yrh 6 | 7 | """ 8 | 9 | import os 10 | import numpy as np 11 | from scipy.sparse import csr_matrix, csc_matrix 12 | from sklearn.preprocessing import normalize 13 | from logzero import logger 14 | 15 | from deepxml.data_utils import get_sparse_feature 16 | 17 | 18 | __all__ = ['build_tree_by_level'] 19 | 20 | 21 | def build_tree_by_level(sparse_data_x, sparse_data_y, mlb, eps: float, max_leaf: int, levels: list, groups_path): 22 | os.makedirs(os.path.split(groups_path)[0], exist_ok=True) 23 | logger.info('Clustering') 24 | sparse_x, sparse_labels = get_sparse_feature(sparse_data_x, sparse_data_y) 25 | sparse_y = mlb.transform(sparse_labels) 26 | logger.info('Getting Labels Feature') 27 | labels_f = normalize(csr_matrix(sparse_y.T) @ csc_matrix(sparse_x)) 28 | logger.info(F'Start Clustering {levels}') 29 | levels, q = [2**x for x in levels], None 30 | for i in range(len(levels)-1, -1, -1): 31 | if os.path.exists(F'{groups_path}-Level-{i}.npy'): 32 | labels_list = np.load(F'{groups_path}-Level-{i}.npy') 33 | q = [(labels_i, labels_f[labels_i]) for labels_i in labels_list] 34 | break 35 | if q is None: 36 | q = [(np.arange(labels_f.shape[0]), labels_f)] 37 | while q: 38 | labels_list = np.asarray([x[0] for x in q]) 39 | assert sum(len(labels) for labels in labels_list) == labels_f.shape[0] 40 | if len(labels_list) in levels: 41 | level = levels.index(len(labels_list)) 42 | logger.info(F'Finish Clustering Level-{level}') 43 | np.save(F'{groups_path}-Level-{level}.npy', np.asarray(labels_list)) 44 | else: 45 | logger.info(F'Finish Clustering {len(labels_list)}') 46 | next_q = [] 47 | for node_i, node_f in q: 48 | if len(node_i) > max_leaf: 49 | next_q += list(split_node(node_i, node_f, eps)) 50 | q = next_q 51 | logger.info('Finish Clustering') 52 | 53 | 54 | def split_node(labels_i: np.ndarray, labels_f: csr_matrix, eps: float): 55 | n = len(labels_i) 56 | c1, c2 = np.random.choice(np.arange(n), 2, replace=False) 57 | centers, old_dis, new_dis = labels_f[[c1, c2]].toarray(), -10000.0, -1.0 58 | l_labels_i, r_labels_i = None, None 59 | while new_dis - old_dis >= eps: 60 | dis = labels_f @ centers.T # N, 2 61 | partition = np.argsort(dis[:, 1] - dis[:, 0]) 62 | l_labels_i, r_labels_i = partition[:n//2], partition[n//2:] 63 | old_dis, new_dis = new_dis, (dis[l_labels_i, 0].sum() + dis[r_labels_i, 1].sum()) / n 64 | centers = normalize(np.asarray([np.squeeze(np.asarray(labels_f[l_labels_i].sum(axis=0))), 65 | np.squeeze(np.asarray(labels_f[r_labels_i].sum(axis=0)))])) 66 | return (labels_i[l_labels_i], labels_f[l_labels_i]), (labels_i[r_labels_i], labels_f[r_labels_i]) 67 | -------------------------------------------------------------------------------- /deepxml/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2018/12/9 5 | @author yrh 6 | 7 | """ 8 | 9 | import os 10 | import numpy as np 11 | import joblib 12 | 
from collections import Counter 13 | from sklearn.preprocessing import MultiLabelBinarizer, normalize 14 | from sklearn.datasets import load_svmlight_file 15 | from gensim.models import KeyedVectors 16 | from tqdm import tqdm 17 | from typing import Union, Iterable 18 | 19 | 20 | __all__ = ['build_vocab', 'get_data', 'convert_to_binary', 'truncate_text', 'get_word_emb', 'get_mlb', 21 | 'get_sparse_feature', 'output_res'] 22 | 23 | 24 | def build_vocab(texts: Iterable, w2v_model: Union[KeyedVectors, str], vocab_size=500000, 25 | pad='<PAD>', unknown='<UNK>', sep='/SEP/', max_times=1, freq_times=1): 26 | if isinstance(w2v_model, str): 27 | w2v_model = KeyedVectors.load(w2v_model) 28 | emb_size = w2v_model.vector_size 29 | vocab, emb_init = [pad, unknown], [np.zeros(emb_size), np.random.uniform(-1.0, 1.0, emb_size)] 30 | counter = Counter(token for t in texts for token in set(t.split())) 31 | for word, freq in sorted(counter.items(), key=lambda x: (x[1], x[0] in w2v_model), reverse=True): 32 | if word in w2v_model or freq >= freq_times: 33 | vocab.append(word) 34 | # We use the embedding of '.' as the embedding of the '/SEP/' symbol. 35 | word = '.' if word == sep else word 36 | emb_init.append(w2v_model[word] if word in w2v_model else np.random.uniform(-1.0, 1.0, emb_size)) 37 | if freq < max_times or vocab_size == len(vocab): 38 | break 39 | return np.asarray(vocab), np.asarray(emb_init) 40 | 41 | 42 | def get_word_emb(vec_path, vocab_path=None): 43 | if vocab_path is not None: 44 | with open(vocab_path) as fp: 45 | vocab = {word: idx for idx, word in enumerate(fp)} 46 | return np.load(vec_path), vocab 47 | else: 48 | return np.load(vec_path) 49 | 50 | 51 | def get_data(text_file, label_file=None): 52 | return np.load(text_file), np.load(label_file) if label_file is not None else None 53 | 54 | 55 | def convert_to_binary(text_file, label_file=None, max_len=None, vocab=None, pad='<PAD>', unknown='<UNK>'): 56 | with open(text_file) as fp: 57 | texts = np.asarray([[vocab.get(word, vocab[unknown]) for word in line.split()] 58 | for line in tqdm(fp, desc='Converting token to id', leave=False)]) 59 | labels = None 60 | if label_file is not None: 61 | with open(label_file) as fp: 62 | labels = np.asarray([[label for label in line.split()] 63 | for line in tqdm(fp, desc='Converting labels', leave=False)]) 64 | return truncate_text(texts, max_len, vocab[pad], vocab[unknown]), labels 65 | 66 | 67 | def truncate_text(texts, max_len=500, padding_idx=0, unknown_idx=1): 68 | if max_len is None: 69 | return texts 70 | texts = np.asarray([list(x[:max_len]) + [padding_idx] * (max_len - len(x)) for x in texts]) 71 | texts[(texts == padding_idx).all(axis=1), 0] = unknown_idx 72 | return texts 73 | 74 | 75 | def get_mlb(mlb_path, labels=None) -> MultiLabelBinarizer: 76 | if os.path.exists(mlb_path): 77 | return joblib.load(mlb_path) 78 | mlb = MultiLabelBinarizer(sparse_output=True) 79 | mlb.fit(labels) 80 | joblib.dump(mlb, mlb_path) 81 | return mlb 82 | 83 | 84 | def get_sparse_feature(feature_file, label_file): 85 | sparse_x, _ = load_svmlight_file(feature_file, multilabel=True) 86 | return normalize(sparse_x), np.load(label_file) if label_file is not None else None 87 | 88 | 89 | def output_res(output_path, name, scores, labels): 90 | os.makedirs(output_path, exist_ok=True) 91 | np.save(os.path.join(output_path, F'{name}-scores'), scores) 92 | np.save(os.path.join(output_path, F'{name}-labels'), labels) 93 | -------------------------------------------------------------------------------- /deepxml/dataset.py:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2018/12/10 5 | @author yrh 6 | 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import Dataset 12 | from scipy.sparse import csr_matrix 13 | from tqdm import tqdm 14 | from typing import Sequence, Optional, Union 15 | 16 | 17 | __all__ = ['MultiLabelDataset', 'XMLDataset'] 18 | 19 | TDataX = Sequence[Sequence] 20 | TDataY = Optional[csr_matrix] 21 | TCandidate = TGroup = Optional[np.ndarray] 22 | TGroupLabel = TGroupScore = Optional[Union[np.ndarray, torch.Tensor]] 23 | 24 | 25 | class MultiLabelDataset(Dataset): 26 | """ 27 | 28 | """ 29 | def __init__(self, data_x: TDataX, data_y: TDataY = None, training=True): 30 | self.data_x, self.data_y, self.training = data_x, data_y, training 31 | 32 | def __getitem__(self, item): 33 | data_x = self.data_x[item] 34 | if self.training and self.data_y is not None: 35 | data_y = self.data_y[item].toarray().squeeze(0).astype(np.float32) 36 | return data_x, data_y 37 | else: 38 | return data_x 39 | 40 | def __len__(self): 41 | return len(self.data_x) 42 | 43 | 44 | class XMLDataset(MultiLabelDataset): 45 | """ 46 | 47 | """ 48 | def __init__(self, data_x: TDataX, data_y: TDataY = None, training=True, 49 | labels_num=None, candidates: TCandidate = None, candidates_num=None, 50 | groups: TGroup = None, group_labels: TGroupLabel = None, group_scores: TGroupScore = None): 51 | super(XMLDataset, self).__init__(data_x, data_y, training) 52 | self.labels_num, self.candidates, self.candidates_num = labels_num, candidates, candidates_num 53 | self.groups, self.group_labels, self.group_scores = groups, group_labels, group_scores 54 | if self.candidates is None: 55 | self.candidates = [np.concatenate([self.groups[g] for g in group_labels]) 56 | for group_labels in tqdm(self.group_labels, leave=False, desc='Candidates')] 57 | if self.group_scores is not None: 58 | self.candidates_scores = [np.concatenate([[s] * len(self.groups[g]) 59 | for g, s in zip(group_labels, group_scores)]) 60 | for group_labels, group_scores in zip(self.group_labels, self.group_scores)] 61 | else: 62 | self.candidates_scores = [np.ones_like(candidates) for candidates in self.candidates] 63 | if self.candidates_num is None: 64 | self.candidates_num = self.group_labels.shape[1] * max(len(g) for g in groups) 65 | 66 | def __getitem__(self, item): 67 | data_x, candidates = self.data_x[item], np.asarray(self.candidates[item], dtype=np.int) 68 | if self.training and self.data_y is not None: 69 | if len(candidates) < self.candidates_num: 70 | sample = np.random.randint(self.labels_num, size=self.candidates_num - len(candidates)) 71 | candidates = np.concatenate([candidates, sample]) 72 | elif len(candidates) > self.candidates_num: 73 | candidates = np.random.choice(candidates, self.candidates_num, replace=False) 74 | data_y = self.data_y[item, candidates].toarray().squeeze(0).astype(np.float32) 75 | return (data_x, candidates), data_y 76 | else: 77 | scores = self.candidates_scores[item] 78 | if len(candidates) < self.candidates_num: 79 | scores = np.concatenate([scores, [-np.inf] * (self.candidates_num - len(candidates))]) 80 | candidates = np.concatenate([candidates, [self.labels_num] * (self.candidates_num - len(candidates))]) 81 | scores = np.asarray(scores, dtype=np.float32) 82 | return data_x, candidates, scores 83 | 84 | def __len__(self): 85 | return len(self.data_x) 86 | 
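A brief usage sketch for these dataset classes, mirroring how deepxml/tree.py consumes them; train_x/valid_x (token-id arrays) and train_y/valid_y (sparse label matrices) are assumed to come from the preprocessing step:
```python
# Sketch: wrap the datasets in PyTorch DataLoaders as done during training.
from torch.utils.data import DataLoader
from deepxml.dataset import MultiLabelDataset

train_loader = DataLoader(MultiLabelDataset(train_x, train_y),
                          batch_size=40, shuffle=True, num_workers=4)
valid_loader = DataLoader(MultiLabelDataset(valid_x, valid_y, training=False),
                          batch_size=40, num_workers=4)
```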
-------------------------------------------------------------------------------- /deepxml/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2018/12/9 5 | @author yrh 6 | 7 | """ 8 | 9 | import numpy as np 10 | from functools import partial 11 | from scipy.sparse import csr_matrix 12 | from sklearn.preprocessing import MultiLabelBinarizer 13 | from typing import Union, Optional, List, Iterable, Hashable 14 | 15 | 16 | __all__ = ['get_precision', 'get_p_1', 'get_p_3', 'get_p_5', 'get_p_10', 17 | 'get_ndcg', 'get_n_1', 'get_n_3', 'get_n_5', 'get_n_10', 18 | 'get_inv_propensity', 'get_psp', 19 | 'get_psp_1', 'get_psp_3', 'get_psp_5', 'get_psp_10', 20 | 'get_psndcg_1', 'get_psndcg_3', 'get_psndcg_5', 'get_psndcg_10'] 21 | 22 | TPredict = np.ndarray 23 | TTarget = Union[Iterable[Iterable[Hashable]], csr_matrix] 24 | TMlb = Optional[MultiLabelBinarizer] 25 | TClass = Optional[List[Hashable]] 26 | 27 | 28 | def get_mlb(classes: TClass = None, mlb: TMlb = None, targets: TTarget = None): 29 | if classes is not None: 30 | mlb = MultiLabelBinarizer(classes, sparse_output=True) 31 | if mlb is None and targets is not None: 32 | if isinstance(targets, csr_matrix): 33 | mlb = MultiLabelBinarizer(range(targets.shape[1]), sparse_output=True) 34 | mlb.fit(None) 35 | else: 36 | mlb = MultiLabelBinarizer(sparse_output=True) 37 | mlb.fit(targets) 38 | return mlb 39 | 40 | 41 | def get_precision(prediction: TPredict, targets: TTarget, mlb: TMlb = None, classes: TClass = None, top=5): 42 | mlb = get_mlb(classes, mlb, targets) 43 | if not isinstance(targets, csr_matrix): 44 | targets = mlb.transform(targets) 45 | prediction = mlb.transform(prediction[:, :top]) 46 | return prediction.multiply(targets).sum() / (top * targets.shape[0]) 47 | 48 | 49 | get_p_1 = partial(get_precision, top=1) 50 | get_p_3 = partial(get_precision, top=3) 51 | get_p_5 = partial(get_precision, top=5) 52 | get_p_10 = partial(get_precision, top=10) 53 | 54 | 55 | def get_ndcg(prediction: TPredict, targets: TTarget, mlb: TMlb = None, classes: TClass = None, top=5): 56 | mlb = get_mlb(classes, mlb, targets) 57 | log = 1.0 / np.log2(np.arange(top) + 2) 58 | dcg = np.zeros((targets.shape[0], 1)) 59 | if not isinstance(targets, csr_matrix): 60 | targets = mlb.transform(targets) 61 | for i in range(top): 62 | p = mlb.transform(prediction[:, i: i+1]) 63 | dcg += p.multiply(targets).sum(axis=-1) * log[i] 64 | return np.average(dcg / log.cumsum()[np.minimum(targets.sum(axis=-1), top) - 1]) 65 | 66 | 67 | get_n_1 = partial(get_ndcg, top=1) 68 | get_n_3 = partial(get_ndcg, top=3) 69 | get_n_5 = partial(get_ndcg, top=5) 70 | get_n_10 = partial(get_ndcg, top=10) 71 | 72 | 73 | def get_inv_propensity(train_y: csr_matrix, a=0.55, b=1.5): 74 | n, number = train_y.shape[0], np.asarray(train_y.sum(axis=0)).squeeze() 75 | c = (np.log(n) - 1) * ((b + 1) ** a) 76 | return 1.0 + c * (number + b) ** (-a) 77 | 78 | 79 | def get_psp(prediction: TPredict, targets: TTarget, inv_w: np.ndarray, mlb: TMlb = None, 80 | classes: TClass = None, top=5): 81 | mlb = get_mlb(classes, mlb) 82 | if not isinstance(targets, csr_matrix): 83 | targets = mlb.transform(targets) 84 | prediction = mlb.transform(prediction[:, :top]).multiply(inv_w) 85 | num = prediction.multiply(targets).sum() 86 | t, den = csr_matrix(targets.multiply(inv_w)), 0 87 | for i in range(t.shape[0]): 88 | den += np.sum(np.sort(t.getrow(i).data)[-top:]) 89 | return num / den 90 | 91 | 92 | 
get_psp_1 = partial(get_psp, top=1) 93 | get_psp_3 = partial(get_psp, top=3) 94 | get_psp_5 = partial(get_psp, top=5) 95 | get_psp_10 = partial(get_psp, top=10) 96 | 97 | 98 | def get_psndcg(prediction: TPredict, targets: TTarget, inv_w: np.ndarray, mlb: TMlb = None, 99 | classes: TClass = None, top=5): 100 | mlb = get_mlb(classes, mlb) 101 | log = 1.0 / np.log2(np.arange(top) + 2) 102 | psdcg = 0.0 103 | if not isinstance(targets, csr_matrix): 104 | targets = mlb.transform(targets) 105 | for i in range(top): 106 | p = mlb.transform(prediction[:, i: i+1]).multiply(inv_w) 107 | psdcg += p.multiply(targets).sum() * log[i] 108 | t, den = csr_matrix(targets.multiply(inv_w)), 0.0 109 | for i in range(t.shape[0]): 110 | num = min(top, len(t.getrow(i).data)) 111 | den += -np.sum(np.sort(-t.getrow(i).data)[:num] * log[:num]) 112 | return psdcg / den 113 | 114 | 115 | get_psndcg_1 = partial(get_psndcg, top=1) 116 | get_psndcg_3 = partial(get_psndcg, top=3) 117 | get_psndcg_5 = partial(get_psndcg, top=5) 118 | get_psndcg_10 = partial(get_psndcg, top=10) 119 | -------------------------------------------------------------------------------- /deepxml/models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2018/12/9 5 | @author yrh 6 | 7 | """ 8 | 9 | import os 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | from collections import deque 14 | from torch.utils.data import DataLoader 15 | from tqdm import tqdm 16 | from logzero import logger 17 | from typing import Optional, Mapping, Tuple 18 | 19 | from deepxml.evaluation import get_p_5, get_n_5 20 | from deepxml.modules import * 21 | from deepxml.optimizers import * 22 | 23 | 24 | __all__ = ['Model', 'XMLModel'] 25 | 26 | 27 | class Model(object): 28 | """ 29 | 30 | """ 31 | def __init__(self, network, model_path, gradient_clip_value=5.0, device_ids=None, **kwargs): 32 | self.model = nn.DataParallel(network(**kwargs).cuda(), device_ids=device_ids) 33 | self.loss_fn = nn.BCEWithLogitsLoss() 34 | self.model_path, self.state = model_path, {} 35 | os.makedirs(os.path.split(self.model_path)[0], exist_ok=True) 36 | self.gradient_clip_value, self.gradient_norm_queue = gradient_clip_value, deque([np.inf], maxlen=5) 37 | self.optimizer = None 38 | 39 | def train_step(self, train_x: torch.Tensor, train_y: torch.Tensor): 40 | self.optimizer.zero_grad() 41 | self.model.train() 42 | scores = self.model(train_x) 43 | loss = self.loss_fn(scores, train_y) 44 | loss.backward() 45 | self.clip_gradient() 46 | self.optimizer.step(closure=None) 47 | return loss.item() 48 | 49 | def predict_step(self, data_x: torch.Tensor, k: int): 50 | self.model.eval() 51 | with torch.no_grad(): 52 | scores, labels = torch.topk(self.model(data_x), k) 53 | return torch.sigmoid(scores).cpu(), labels.cpu() 54 | 55 | def get_optimizer(self, **kwargs): 56 | self.optimizer = DenseSparseAdam(self.model.parameters(), **kwargs) 57 | 58 | def train(self, train_loader: DataLoader, valid_loader: DataLoader, opt_params: Optional[Mapping] = None, 59 | nb_epoch=100, step=100, k=5, early=50, verbose=True, swa_warmup=None, **kwargs): 60 | self.get_optimizer(**({} if opt_params is None else opt_params)) 61 | global_step, best_n5, e = 0, 0.0, 0 62 | for epoch_idx in range(nb_epoch): 63 | if epoch_idx == swa_warmup: 64 | self.swa_init() 65 | for i, (train_x, train_y) in enumerate(train_loader, 1): 66 | global_step += 1 67 | loss = self.train_step(train_x, train_y.cuda()) 68 | 
if global_step % step == 0: 69 | self.swa_step() 70 | self.swap_swa_params() 71 | labels = np.concatenate([self.predict_step(valid_x, k)[1] for valid_x in valid_loader]) 72 | targets = valid_loader.dataset.data_y 73 | p5, n5 = get_p_5(labels, targets), get_n_5(labels, targets) 74 | if n5 > best_n5: 75 | self.save_model() 76 | best_n5, e = n5, 0 77 | else: 78 | e += 1 79 | if early is not None and e > early: 80 | return 81 | self.swap_swa_params() 82 | if verbose: 83 | logger.info(F'{epoch_idx} {i * train_loader.batch_size} train loss: {round(loss, 5)} ' 84 | F'P@5: {round(p5, 5)} nDCG@5: {round(n5, 5)} early stop: {e}') 85 | 86 | def predict(self, data_loader: DataLoader, k=100, desc='Predict', **kwargs): 87 | self.load_model() 88 | scores_list, labels_list = zip(*(self.predict_step(data_x, k) 89 | for data_x in tqdm(data_loader, desc=desc, leave=False))) 90 | return np.concatenate(scores_list), np.concatenate(labels_list) 91 | 92 | def save_model(self): 93 | torch.save(self.model.module.state_dict(), self.model_path) 94 | 95 | def load_model(self): 96 | self.model.module.load_state_dict(torch.load(self.model_path)) 97 | 98 | def clip_gradient(self): 99 | if self.gradient_clip_value is not None: 100 | max_norm = max(self.gradient_norm_queue) 101 | total_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm * self.gradient_clip_value) 102 | self.gradient_norm_queue.append(min(total_norm, max_norm * 2.0, 1.0)) 103 | if total_norm > max_norm * self.gradient_clip_value: 104 | logger.warn(F'Clipping gradients with total norm {round(total_norm, 5)} ' 105 | F'and max norm {round(max_norm, 5)}') 106 | 107 | def swa_init(self): 108 | if 'swa' not in self.state: 109 | logger.info('SWA Initializing') 110 | swa_state = self.state['swa'] = {'models_num': 1} 111 | for n, p in self.model.named_parameters(): 112 | swa_state[n] = p.data.clone().detach() 113 | 114 | def swa_step(self): 115 | if 'swa' in self.state: 116 | swa_state = self.state['swa'] 117 | swa_state['models_num'] += 1 118 | beta = 1.0 / swa_state['models_num'] 119 | with torch.no_grad(): 120 | for n, p in self.model.named_parameters(): 121 | swa_state[n].mul_(1.0 - beta).add_(beta, p.data) 122 | 123 | def swap_swa_params(self): 124 | if 'swa' in self.state: 125 | swa_state = self.state['swa'] 126 | for n, p in self.model.named_parameters(): 127 | p.data, swa_state[n] = swa_state[n], p.data 128 | 129 | def disable_swa(self): 130 | if 'swa' in self.state: 131 | del self.state['swa'] 132 | 133 | 134 | class XMLModel(Model): 135 | """ 136 | 137 | """ 138 | def __init__(self, labels_num, hidden_size, device_ids=None, attn_device_ids=None, 139 | most_labels_parallel_attn=80000, **kwargs): 140 | parallel_attn = labels_num <= most_labels_parallel_attn 141 | super(XMLModel, self).__init__(hidden_size=hidden_size, device_ids=device_ids, labels_num=labels_num, 142 | parallel_attn=parallel_attn, **kwargs) 143 | self.network, self.attn_weights = self.model, nn.Sequential() 144 | if not parallel_attn: 145 | self.attn_weights = AttentionWeights(labels_num, hidden_size*2, attn_device_ids) 146 | self.model = nn.ModuleDict({'Network': self.network.module, 'AttentionWeights': self.attn_weights}) 147 | self.state['best'] = {} 148 | 149 | def train_step(self, train_x: Tuple[torch.Tensor, torch.Tensor], train_y: torch.Tensor): 150 | self.optimizer.zero_grad() 151 | train_x, candidates = train_x 152 | self.model.train() 153 | scores = self.network(train_x, candidates=candidates, attn_weights=self.attn_weights) 154 | loss = self.loss_fn(scores, 
train_y) 155 | loss.backward() 156 | self.clip_gradient() 157 | self.optimizer.step(closure=None) 158 | return loss.item() 159 | 160 | def predict_step(self, data_x: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], k): 161 | data_x, candidates, group_scores = data_x 162 | self.model.eval() 163 | with torch.no_grad(): 164 | scores = torch.sigmoid(self.network(data_x, candidates=candidates, attn_weights=self.attn_weights)) 165 | scores, labels = torch.topk(scores * group_scores.cuda(), k) 166 | return scores.cpu(), candidates[np.arange(len(data_x)).reshape(-1, 1), labels.cpu()] 167 | 168 | def train(self, *args, **kwargs): 169 | super(XMLModel, self).train(*args, **kwargs) 170 | self.save_model_to_disk() 171 | 172 | def save_model(self): 173 | model_dict = self.model.state_dict() 174 | for key in model_dict: 175 | self.state['best'][key] = model_dict[key].cpu().detach() 176 | 177 | def save_model_to_disk(self): 178 | model_dict = self.model.state_dict() 179 | for key in model_dict: 180 | model_dict[key][:] = self.state['best'][key] 181 | torch.save(self.model.state_dict(), self.model_path) 182 | 183 | def load_model(self): 184 | self.model.load_state_dict(torch.load(self.model_path)) 185 | -------------------------------------------------------------------------------- /deepxml/modules.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2018/12/29 5 | @author yrh 6 | 7 | """ 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | __all__ = ['Embedding', 'LSTMEncoder', 'MLAttention', 'AttentionWeights', 'FastMLAttention', 'MLLinear'] 16 | 17 | 18 | class Embedding(nn.Module): 19 | """ 20 | 21 | """ 22 | def __init__(self, vocab_size=None, emb_size=None, emb_init=None, emb_trainable=True, padding_idx=0, dropout=0.2): 23 | super(Embedding, self).__init__() 24 | if emb_init is not None: 25 | if vocab_size is not None: 26 | assert vocab_size == emb_init.shape[0] 27 | if emb_size is not None: 28 | assert emb_size == emb_init.shape[1] 29 | vocab_size, emb_size = emb_init.shape 30 | self.emb = nn.Embedding(vocab_size, emb_size, padding_idx=padding_idx, sparse=True, 31 | _weight=torch.from_numpy(emb_init).float() if emb_init is not None else None) 32 | self.emb.weight.requires_grad = emb_trainable 33 | self.dropout = nn.Dropout(dropout) 34 | self.padding_idx = padding_idx 35 | 36 | def forward(self, inputs): 37 | emb_out = self.dropout(self.emb(inputs)) 38 | lengths, masks = (inputs != self.padding_idx).sum(dim=-1), inputs != self.padding_idx 39 | return emb_out[:, :lengths.max()], lengths, masks[:, :lengths.max()] 40 | 41 | 42 | class LSTMEncoder(nn.Module): 43 | """ 44 | 45 | """ 46 | def __init__(self, input_size, hidden_size, layers_num, dropout): 47 | super(LSTMEncoder, self).__init__() 48 | self.lstm = nn.LSTM(input_size, hidden_size, layers_num, batch_first=True, bidirectional=True) 49 | self.init_state = nn.Parameter(torch.zeros(2*2*layers_num, 1, hidden_size)) 50 | self.dropout = nn.Dropout(dropout) 51 | 52 | def forward(self, inputs, lengths, **kwargs): 53 | self.lstm.flatten_parameters() 54 | init_state = self.init_state.repeat([1, inputs.size(0), 1]) 55 | cell_init, hidden_init = init_state[:init_state.size(0)//2], init_state[init_state.size(0)//2:] 56 | idx = torch.argsort(lengths, descending=True) 57 | packed_inputs = nn.utils.rnn.pack_padded_sequence(inputs[idx], lengths[idx], batch_first=True) 58 | outputs, _ = 
nn.utils.rnn.pad_packed_sequence( 59 | self.lstm(packed_inputs, (hidden_init, cell_init))[0], batch_first=True) 60 | return self.dropout(outputs[torch.argsort(idx)]) 61 | 62 | 63 | class MLAttention(nn.Module): 64 | """ 65 | 66 | """ 67 | def __init__(self, labels_num, hidden_size): 68 | super(MLAttention, self).__init__() 69 | self.attention = nn.Linear(hidden_size, labels_num, bias=False) 70 | nn.init.xavier_uniform_(self.attention.weight) 71 | 72 | def forward(self, inputs, masks): 73 | masks = torch.unsqueeze(masks, 1) # N, 1, L 74 | attention = self.attention(inputs).transpose(1, 2).masked_fill(1.0 - masks, -np.inf) # N, labels_num, L 75 | attention = F.softmax(attention, -1) 76 | return attention @ inputs # N, labels_num, hidden_size 77 | 78 | 79 | class AttentionWeights(nn.Module): 80 | """ 81 | 82 | """ 83 | def __init__(self, labels_num, hidden_size, device_ids=None): 84 | super(AttentionWeights, self).__init__() 85 | if device_ids is None: 86 | device_ids = list(range(1, torch.cuda.device_count())) 87 | assert labels_num >= len(device_ids) 88 | group_size, plus_num = labels_num // len(device_ids), labels_num % len(device_ids) 89 | self.group = [group_size + 1] * plus_num + [group_size] * (len(device_ids) - plus_num) 90 | assert sum(self.group) == labels_num 91 | self.emb = nn.ModuleList(nn.Embedding(size, hidden_size, sparse=True).cuda(device_ids[i]) 92 | for i, size in enumerate(self.group)) 93 | std = (6.0 / (labels_num + hidden_size)) ** 0.5 94 | with torch.no_grad(): 95 | for emb in self.emb: 96 | emb.weight.data.uniform_(-std, std) 97 | self.group_offset, self.hidden_size = np.cumsum([0] + self.group), hidden_size 98 | 99 | def forward(self, inputs: torch.Tensor): 100 | outputs = torch.zeros(*inputs.size(), self.hidden_size, device=inputs.device) 101 | for left, right, emb in zip(self.group_offset[:-1], self.group_offset[1:], self.emb): 102 | index = (left <= inputs) & (inputs < right) 103 | group_inputs = (inputs[index] - left).to(emb.weight.device) 104 | outputs[index] = emb(group_inputs).to(inputs.device) 105 | return outputs 106 | 107 | 108 | class FastMLAttention(nn.Module): 109 | """ 110 | 111 | """ 112 | def __init__(self, labels_num, hidden_size, parallel_attn=False): 113 | super(FastMLAttention, self).__init__() 114 | if parallel_attn: 115 | self.attention = nn.Embedding(labels_num + 1, hidden_size, sparse=True) 116 | nn.init.xavier_uniform_(self.attention.weight) 117 | 118 | def forward(self, inputs, masks, candidates, attn_weights: nn.Module): 119 | masks = torch.unsqueeze(masks, 1) # N, 1, L 120 | attn_inputs = inputs.transpose(1, 2) # N, hidden, L 121 | attn_weights = self.attention(candidates) if hasattr(self, 'attention') else attn_weights(candidates) 122 | attention = (attn_weights @ attn_inputs).masked_fill(1.0 - masks, -np.inf) # N, sampled_size, L 123 | attention = F.softmax(attention, -1) # N, sampled_size, L 124 | return attention @ inputs # N, sampled_size, hidden_size 125 | 126 | 127 | class MLLinear(nn.Module): 128 | """ 129 | 130 | """ 131 | def __init__(self, linear_size, output_size): 132 | super(MLLinear, self).__init__() 133 | self.linear = nn.ModuleList(nn.Linear(in_s, out_s) 134 | for in_s, out_s in zip(linear_size[:-1], linear_size[1:])) 135 | for linear in self.linear: 136 | nn.init.xavier_uniform_(linear.weight) 137 | self.output = nn.Linear(linear_size[-1], output_size) 138 | nn.init.xavier_uniform_(self.output.weight) 139 | 140 | def forward(self, inputs): 141 | linear_out = inputs 142 | for linear in self.linear: 143 | linear_out = 
F.relu(linear(linear_out)) 144 | return torch.squeeze(self.output(linear_out), -1) 145 | -------------------------------------------------------------------------------- /deepxml/networks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2018/12/9 5 | @author yrh 6 | 7 | """ 8 | 9 | import torch.nn as nn 10 | 11 | from deepxml.modules import * 12 | 13 | 14 | __all__ = ['AttentionRNN', 'FastAttentionRNN'] 15 | 16 | 17 | class Network(nn.Module): 18 | """ 19 | 20 | """ 21 | def __init__(self, emb_size, vocab_size=None, emb_init=None, emb_trainable=True, padding_idx=0, emb_dropout=0.2, 22 | **kwargs): 23 | super(Network, self).__init__() 24 | self.emb = Embedding(vocab_size, emb_size, emb_init, emb_trainable, padding_idx, emb_dropout) 25 | 26 | def forward(self, *args, **kwargs): 27 | raise NotImplementedError 28 | 29 | 30 | class AttentionRNN(Network): 31 | """ 32 | 33 | """ 34 | def __init__(self, labels_num, emb_size, hidden_size, layers_num, linear_size, dropout, **kwargs): 35 | super(AttentionRNN, self).__init__(emb_size, **kwargs) 36 | self.lstm = LSTMEncoder(emb_size, hidden_size, layers_num, dropout) 37 | self.attention = MLAttention(labels_num, hidden_size * 2) 38 | self.linear = MLLinear([hidden_size * 2] + linear_size, 1) 39 | 40 | def forward(self, inputs, **kwargs): 41 | emb_out, lengths, masks = self.emb(inputs, **kwargs) 42 | rnn_out = self.lstm(emb_out, lengths) # N, L, hidden_size * 2 43 | attn_out = self.attention(rnn_out, masks) # N, labels_num, hidden_size * 2 44 | return self.linear(attn_out) 45 | 46 | 47 | class FastAttentionRNN(Network): 48 | """ 49 | 50 | """ 51 | def __init__(self, labels_num, emb_size, hidden_size, layers_num, linear_size, dropout, parallel_attn, **kwargs): 52 | super(FastAttentionRNN, self).__init__(emb_size, **kwargs) 53 | self.lstm = LSTMEncoder(emb_size, hidden_size, layers_num, dropout) 54 | self.attention = FastMLAttention(labels_num, hidden_size * 2, parallel_attn) 55 | self.linear = MLLinear([hidden_size * 2] + linear_size, 1) 56 | 57 | def forward(self, inputs, candidates, attn_weights: nn.Module, **kwargs): 58 | emb_out, lengths, masks = self.emb(inputs, **kwargs) 59 | rnn_out = self.lstm(emb_out, lengths) # N, L, hidden_size * 2 60 | attn_out = self.attention(rnn_out, masks, candidates, attn_weights) # N, sampled_size, hidden_size * 2 61 | return self.linear(attn_out) 62 | -------------------------------------------------------------------------------- /deepxml/optimizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2019/3/7 5 | @author yrh 6 | 7 | """ 8 | 9 | import math 10 | import torch 11 | from torch.optim.optimizer import Optimizer 12 | 13 | 14 | __all__ = ['DenseSparseAdam'] 15 | 16 | 17 | class DenseSparseAdam(Optimizer): 18 | """ 19 | 20 | """ 21 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0): 22 | if not 0.0 <= lr: 23 | raise ValueError("Invalid learning rate: {}".format(lr)) 24 | if not 0.0 <= eps: 25 | raise ValueError("Invalid epsilon value: {}".format(eps)) 26 | if not 0.0 <= betas[0] < 1.0: 27 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 28 | if not 0.0 <= betas[1] < 1.0: 29 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 30 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 31 | 
super(DenseSparseAdam, self).__init__(params, defaults) 32 | 33 | def step(self, closure=None): 34 | """ 35 | Performs a single optimization step. 36 | 37 | Parameters 38 | ---------- 39 | closure : ``callable``, optional. 40 | A closure that reevaluates the model and returns the loss. 41 | """ 42 | loss = None 43 | if closure is not None: 44 | loss = closure() 45 | 46 | for group in self.param_groups: 47 | for p in group['params']: 48 | if p.grad is None: 49 | continue 50 | grad = p.grad.data 51 | 52 | state = self.state[p] 53 | 54 | # State initialization 55 | if 'step' not in state: 56 | state['step'] = 0 57 | if 'exp_avg' not in state: 58 | # Exponential moving average of gradient values 59 | state['exp_avg'] = torch.zeros_like(p.data) 60 | if 'exp_avg_sq' not in state: 61 | # Exponential moving average of squared gradient values 62 | state['exp_avg_sq'] = torch.zeros_like(p.data) 63 | 64 | state['step'] += 1 65 | 66 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 67 | beta1, beta2 = group['betas'] 68 | 69 | weight_decay = group['weight_decay'] 70 | 71 | if grad.is_sparse: 72 | grad = grad.coalesce() # the update is non-linear so indices must be unique 73 | grad_indices = grad._indices() 74 | grad_values = grad._values() 75 | size = grad.size() 76 | 77 | def make_sparse(values): 78 | constructor = grad.new 79 | if grad_indices.dim() == 0 or values.dim() == 0: 80 | return constructor().resize_as_(grad) 81 | return constructor(grad_indices, values, size) 82 | 83 | # Decay the first and second moment running average coefficient 84 | # old <- b * old + (1 - b) * new 85 | # <==> old += (1 - b) * (new - old) 86 | old_exp_avg_values = exp_avg.sparse_mask(grad)._values() 87 | exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1) 88 | exp_avg.add_(make_sparse(exp_avg_update_values)) 89 | old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values() 90 | exp_avg_sq_update_values = grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2) 91 | exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values)) 92 | 93 | # Dense addition again is intended, avoiding another sparse_mask 94 | numer = exp_avg_update_values.add_(old_exp_avg_values) 95 | exp_avg_sq_update_values.add_(old_exp_avg_sq_values) 96 | denom = exp_avg_sq_update_values.sqrt_().add_(group['eps']) 97 | del exp_avg_update_values, exp_avg_sq_update_values 98 | 99 | bias_correction1 = 1 - beta1 ** state['step'] 100 | bias_correction2 = 1 - beta2 ** state['step'] 101 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 102 | 103 | p.data.add_(make_sparse(-step_size * numer.div_(denom))) 104 | if weight_decay > 0.0: 105 | p.data.add_(-group['lr'] * weight_decay, p.data.sparse_mask(grad)) 106 | else: 107 | # Decay the first and second moment running average coefficient 108 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 109 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 110 | denom = exp_avg_sq.sqrt().add_(group['eps']) 111 | 112 | bias_correction1 = 1 - beta1 ** state['step'] 113 | bias_correction2 = 1 - beta2 ** state['step'] 114 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 115 | 116 | p.data.addcdiv_(-step_size, exp_avg, denom) 117 | if weight_decay > 0.0: 118 | p.data.add_(-group['lr'] * weight_decay, p.data) 119 | 120 | return loss 121 | -------------------------------------------------------------------------------- /deepxml/tree.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 
-*- coding: utf-8 3 | """ 4 | Created on 2019/2/26 5 | @author yrh 6 | 7 | """ 8 | 9 | import os 10 | import time 11 | import numpy as np 12 | import torch 13 | from multiprocessing import Process 14 | from scipy.sparse import csr_matrix 15 | from torch.utils.data import DataLoader 16 | from tqdm import tqdm 17 | from logzero import logger 18 | 19 | from deepxml.data_utils import get_word_emb 20 | from deepxml.dataset import MultiLabelDataset, XMLDataset 21 | from deepxml.models import Model, XMLModel 22 | from deepxml.cluster import build_tree_by_level 23 | from deepxml.networks import * 24 | 25 | 26 | __all__ = ['FastAttentionXML'] 27 | 28 | 29 | class FastAttentionXML(object): 30 | """ 31 | 32 | """ 33 | def __init__(self, labels_num, data_cnf, model_cnf, tree_id=''): 34 | self.data_cnf, self.model_cnf = data_cnf.copy(), model_cnf.copy() 35 | model_name, data_name = model_cnf['name'], data_cnf['name'] 36 | self.model_path = os.path.join(model_cnf['path'], F'{model_name}-{data_name}{tree_id}') 37 | self.emb_init, self.level = get_word_emb(data_cnf['embedding']['emb_init']), model_cnf['level'] 38 | self.labels_num, self.models = labels_num, {} 39 | self.inter_group_size, self.top = model_cnf['k'], model_cnf['top'] 40 | self.groups_path = os.path.join(model_cnf['path'], F'{model_name}-{data_name}{tree_id}-cluster') 41 | 42 | @staticmethod 43 | def get_mapping_y(groups, labels_num, *args): 44 | mapping = np.empty(labels_num + 1, dtype=np.long) 45 | for idx, labels_list in enumerate(groups): 46 | mapping[labels_list] = idx 47 | mapping[labels_num] = len(groups) 48 | return (FastAttentionXML.get_group_y(mapping, y, len(groups)) for y in args) 49 | 50 | @staticmethod 51 | def get_group_y(mapping: np.ndarray, data_y: csr_matrix, groups_num): 52 | r, c, d = [], [], [] 53 | for i in range(data_y.shape[0]): 54 | g = np.unique(mapping[data_y.indices[data_y.indptr[i]: data_y.indptr[i + 1]]]) 55 | r += [i] * len(g) 56 | c += g.tolist() 57 | d += [1] * len(g) 58 | return csr_matrix((d, (r, c)), shape=(data_y.shape[0], groups_num)) 59 | 60 | def train_level(self, level, train_x, train_y, valid_x, valid_y): 61 | model_cnf, data_cnf = self.model_cnf, self.data_cnf 62 | if level == 0: 63 | while not os.path.exists(F'{self.groups_path}-Level-{level}.npy'): 64 | time.sleep(30) 65 | groups = np.load(F'{self.groups_path}-Level-{level}.npy') 66 | train_y, valid_y = self.get_mapping_y(groups, self.labels_num, train_y, valid_y) 67 | labels_num = len(groups) 68 | train_loader = DataLoader(MultiLabelDataset(train_x, train_y), 69 | model_cnf['train'][level]['batch_size'], num_workers=4, shuffle=True) 70 | valid_loader = DataLoader(MultiLabelDataset(valid_x, valid_y, training=False), 71 | model_cnf['valid']['batch_size'], num_workers=4) 72 | model = Model(AttentionRNN, labels_num=labels_num, model_path=F'{self.model_path}-Level-{level}', 73 | emb_init=self.emb_init, **data_cnf['model'], **model_cnf['model']) 74 | if not os.path.exists(model.model_path): 75 | logger.info(F'Training Level-{level}, Number of Labels: {labels_num}') 76 | model.train(train_loader, valid_loader, **model_cnf['train'][level]) 77 | model.optimizer = None 78 | logger.info(F'Finish Training Level-{level}') 79 | self.models[level] = model 80 | logger.info(F'Generating Candidates for Level-{level+1}, ' 81 | F'Number of Labels: {labels_num}, Top: {self.top}') 82 | train_loader = DataLoader(MultiLabelDataset(train_x), model_cnf['valid']['batch_size'], num_workers=4) 83 | return train_y, model.predict(train_loader, k=self.top), 
model.predict(valid_loader, k=self.top) 84 | else: 85 | train_group_y, train_group, valid_group = self.train_level(level - 1, train_x, train_y, valid_x, valid_y) 86 | torch.cuda.empty_cache() 87 | 88 | logger.info('Getting Candidates') 89 | _, group_labels = train_group 90 | group_candidates = np.empty((len(train_x), self.top), dtype=np.int) 91 | for i, labels in tqdm(enumerate(group_labels), leave=False, desc='Parents'): 92 | ys, ye = train_group_y.indptr[i], train_group_y.indptr[i + 1] 93 | positive = set(train_group_y.indices[ys: ye]) 94 | if self.top >= len(positive): 95 | candidates = positive 96 | for la in labels: 97 | if len(candidates) == self.top: 98 | break 99 | if la not in candidates: 100 | candidates.add(la) 101 | else: 102 | candidates = set() 103 | for la in labels: 104 | if la in positive: 105 | candidates.add(la) 106 | if len(candidates) == self.top: 107 | break 108 | if len(candidates) < self.top: 109 | candidates = (list(candidates) + list(positive - candidates))[:self.top] 110 | group_candidates[i] = np.asarray(list(candidates)) 111 | 112 | if level < self.level - 1: 113 | while not os.path.exists(F'{self.groups_path}-Level-{level}.npy'): 114 | time.sleep(30) 115 | groups = np.load(F'{self.groups_path}-Level-{level}.npy') 116 | train_y, valid_y = self.get_mapping_y(groups, self.labels_num, train_y, valid_y) 117 | labels_num, last_groups = len(groups), self.get_inter_groups(len(groups)) 118 | else: 119 | groups, labels_num = None, train_y.shape[1] 120 | last_groups = np.load(F'{self.groups_path}-Level-{level-1}.npy') 121 | 122 | train_loader = DataLoader(XMLDataset(train_x, train_y, labels_num=labels_num, 123 | groups=last_groups, group_labels=group_candidates), 124 | model_cnf['train'][level]['batch_size'], num_workers=4, shuffle=True) 125 | group_scores, group_labels = valid_group 126 | valid_loader = DataLoader(XMLDataset(valid_x, valid_y, training=False, labels_num=labels_num, 127 | groups=last_groups, group_labels=group_labels, 128 | group_scores=group_scores), 129 | model_cnf['valid']['batch_size'], num_workers=4) 130 | model = XMLModel(network=FastAttentionRNN, labels_num=labels_num, emb_init=self.emb_init, 131 | model_path=F'{self.model_path}-Level-{level}', **data_cnf['model'], **model_cnf['model']) 132 | if not os.path.exists(model.model_path): 133 | logger.info(F'Loading parameters of Level-{level} from Level-{level-1}') 134 | last_model = self.get_last_models(level - 1) 135 | model.network.module.emb.load_state_dict(last_model.module.emb.state_dict()) 136 | model.network.module.lstm.load_state_dict(last_model.module.lstm.state_dict()) 137 | model.network.module.linear.load_state_dict(last_model.module.linear.state_dict()) 138 | logger.info(F'Training Level-{level}, ' 139 | F'Number of Labels: {labels_num}, ' 140 | F'Candidates Number: {train_loader.dataset.candidates_num}') 141 | model.train(train_loader, valid_loader, **model_cnf['train'][level]) 142 | model.optimizer = model.state = None 143 | logger.info(F'Finish Training Level-{level}') 144 | self.models[level] = model 145 | if level == self.level - 1: 146 | return 147 | logger.info(F'Generating Candidates for Level-{level+1}, ' 148 | F'Number of Labels: {labels_num}, Top: {self.top}') 149 | group_scores, group_labels = train_group 150 | train_loader = DataLoader(XMLDataset(train_x, labels_num=labels_num, 151 | groups=last_groups, group_labels=group_labels, 152 | group_scores=group_scores), 153 | model_cnf['valid']['batch_size'], num_workers=4) 154 | return train_y, model.predict(train_loader, 
k=self.top), model.predict(valid_loader, k=self.top) 155 | 156 | def get_last_models(self, level): 157 | return self.models[level].model if level == 0 else self.models[level].network 158 | 159 | def predict_level(self, level, test_x, k, labels_num): 160 | data_cnf, model_cnf = self.data_cnf, self.model_cnf 161 | model = self.models.get(level, None) 162 | if level == 0: 163 | logger.info(F'Predicting Level-{level}, Top: {k}') 164 | if model is None: 165 | model = Model(AttentionRNN, labels_num=labels_num, model_path=F'{self.model_path}-Level-{level}', 166 | emb_init=self.emb_init, **data_cnf['model'], **model_cnf['model']) 167 | test_loader = DataLoader(MultiLabelDataset(test_x), model_cnf['predict']['batch_size'], 168 | num_workers=4) 169 | return model.predict(test_loader, k=k) 170 | else: 171 | if level == self.level - 1: 172 | groups = np.load(F'{self.groups_path}-Level-{level-1}.npy') 173 | else: 174 | groups = self.get_inter_groups(labels_num) 175 | group_scores, group_labels = self.predict_level(level - 1, test_x, self.top, len(groups)) 176 | torch.cuda.empty_cache() 177 | logger.info(F'Predicting Level-{level}, Top: {k}') 178 | if model is None: 179 | model = XMLModel(network=FastAttentionRNN, labels_num=labels_num, 180 | model_path=F'{self.model_path}-Level-{level}', 181 | emb_init=self.emb_init, **data_cnf['model'], **model_cnf['model']) 182 | test_loader = DataLoader(XMLDataset(test_x, labels_num=labels_num, 183 | groups=groups, group_labels=group_labels, group_scores=group_scores), 184 | model_cnf['predict']['batch_size'], num_workers=4) 185 | return model.predict(test_loader, k=k) 186 | 187 | def get_inter_groups(self, labels_num): 188 | assert labels_num % self.inter_group_size == 0 189 | return np.asarray([list(range(i, i + self.inter_group_size)) 190 | for i in range(0, labels_num, self.inter_group_size)]) 191 | 192 | def train(self, train_x, train_y, valid_x, valid_y, mlb): 193 | self.model_cnf['cluster']['groups_path'] = self.groups_path 194 | cluster_process = Process(target=build_tree_by_level, 195 | args=(self.data_cnf['train']['sparse'], self.data_cnf['train']['labels'], mlb), 196 | kwargs=self.model_cnf['cluster']) 197 | cluster_process.start() 198 | self.train_level(self.level - 1, train_x, train_y, valid_x, valid_y) 199 | cluster_process.join() 200 | cluster_process.close() 201 | 202 | def predict(self, test_x): 203 | return self.predict_level(self.level - 1, test_x, self.model_cnf['predict'].get('k', 100), self.labels_num) 204 | -------------------------------------------------------------------------------- /ensemble.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2019/6/11 5 | @author yrh 6 | 7 | """ 8 | 9 | import click 10 | import numpy as np 11 | from collections import defaultdict 12 | from tqdm import tqdm 13 | 14 | 15 | @click.command() 16 | @click.option('-p', '--prefix', help='Prefix of results.') 17 | @click.option('-t', '--trees', type=click.INT, help='The number of results using for ensemble.') 18 | def main(prefix, trees): 19 | labels, scores = [], [] 20 | for i in range(trees): 21 | labels.append(np.load(F'{prefix}-Tree-{i}-labels.npy')) 22 | scores.append(np.load(F'{prefix}-Tree-{i}-scores.npy')) 23 | ensemble_labels, ensemble_scores = [], [] 24 | for i in tqdm(range(len(labels[0]))): 25 | s = defaultdict(float) 26 | for j in range(len(labels[0][i])): 27 | for k in range(trees): 28 | s[labels[k][i][j]] += scores[k][i][j] 29 | s = 
sorted(s.items(), key=lambda x: x[1], reverse=True) 30 | ensemble_labels.append([x[0] for x in s[:len(labels[0][i])]]) 31 | ensemble_scores.append([x[1] for x in s[:len(labels[0][i])]]) 32 | np.save(F'{prefix}-Ensemble-labels', np.asarray(ensemble_labels)) 33 | np.save(F'{prefix}-Ensemble-scores', np.asarray(ensemble_scores)) 34 | 35 | 36 | if __name__ == '__main__': 37 | main() 38 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2019/8/21 5 | @author yrh 6 | 7 | """ 8 | 9 | import warnings 10 | warnings.filterwarnings('ignore') 11 | 12 | import click 13 | import numpy as np 14 | from sklearn.preprocessing import MultiLabelBinarizer 15 | 16 | from deepxml.evaluation import * 17 | 18 | 19 | @click.command() 20 | @click.option('-r', '--results', type=click.Path(exists=True), help='Path of results.') 21 | @click.option('-t', '--targets', type=click.Path(exists=True), help='Path of targets.') 22 | @click.option('--train-labels', type=click.Path(exists=True), default=None, help='Path of labels for training set.') 23 | @click.option('-a', type=click.FLOAT, default=0.55, help='Parameter A for propensity score.') 24 | @click.option('-b', type=click.FLOAT, default=1.5, help='Parameter B for propensity score.') 25 | def main(results, targets, train_labels, a, b): 26 | res, targets = np.load(results), np.load(targets) 27 | mlb = MultiLabelBinarizer(sparse_output=True) 28 | targets = mlb.fit_transform(targets) 29 | print('Precision@1,3,5:', get_p_1(res, targets, mlb), get_p_3(res, targets, mlb), get_p_5(res, targets, mlb)) 30 | print('nDCG@1,3,5:', get_n_1(res, targets, mlb), get_n_3(res, targets, mlb), get_n_5(res, targets, mlb)) 31 | if train_labels is not None: 32 | train_labels = np.load(train_labels) 33 | inv_w = get_inv_propensity(mlb.transform(train_labels), a, b) 34 | print('PSPrecision@1,3,5:', get_psp_1(res, targets, inv_w, mlb), get_psp_3(res, targets, inv_w, mlb), 35 | get_psp_5(res, targets, inv_w, mlb)) 36 | print('PSnDCG@1,3,5:', get_psndcg_1(res, targets, inv_w, mlb), get_psndcg_3(res, targets, inv_w, mlb), 37 | get_psndcg_5(res, targets, inv_w, mlb)) 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | Created on 2018/12/9 5 | @author yrh 6 | 7 | """ 8 | 9 | import os 10 | import click 11 | import numpy as np 12 | from pathlib import Path 13 | from ruamel.yaml import YAML 14 | from sklearn.model_selection import train_test_split 15 | from torch.utils.data import DataLoader 16 | from logzero import logger 17 | 18 | from deepxml.dataset import MultiLabelDataset 19 | from deepxml.data_utils import get_data, get_mlb, get_word_emb, output_res 20 | from deepxml.models import Model 21 | from deepxml.tree import FastAttentionXML 22 | from deepxml.networks import AttentionRNN 23 | 24 | 25 | @click.command() 26 | @click.option('-d', '--data-cnf', type=click.Path(exists=True), help='Path of dataset configure yaml.') 27 | @click.option('-m', '--model-cnf', type=click.Path(exists=True), help='Path of model configure yaml.') 28 | @click.option('--mode', type=click.Choice(['train', 'eval']), default=None) 29 | @click.option('-t', '--tree-id', type=click.INT, 
default=None) 30 | def main(data_cnf, model_cnf, mode, tree_id): 31 | tree_id = F'-Tree-{tree_id}' if tree_id is not None else '' 32 | yaml = YAML(typ='safe') 33 | data_cnf, model_cnf = yaml.load(Path(data_cnf)), yaml.load(Path(model_cnf)) 34 | model, model_name, data_name = None, model_cnf['name'], data_cnf['name'] 35 | model_path = os.path.join(model_cnf['path'], F'{model_name}-{data_name}{tree_id}') 36 | emb_init = get_word_emb(data_cnf['embedding']['emb_init']) 37 | logger.info(F'Model Name: {model_name}') 38 | 39 | if mode is None or mode == 'train': 40 | logger.info('Loading Training and Validation Set') 41 | train_x, train_labels = get_data(data_cnf['train']['texts'], data_cnf['train']['labels']) 42 | if 'size' in data_cnf['valid']: 43 | random_state = data_cnf['valid'].get('random_state', 1240) 44 | train_x, valid_x, train_labels, valid_labels = train_test_split(train_x, train_labels, 45 | test_size=data_cnf['valid']['size'], 46 | random_state=random_state) 47 | else: 48 | valid_x, valid_labels = get_data(data_cnf['valid']['texts'], data_cnf['valid']['labels']) 49 | mlb = get_mlb(data_cnf['labels_binarizer'], np.hstack((train_labels, valid_labels))) 50 | train_y, valid_y = mlb.transform(train_labels), mlb.transform(valid_labels) 51 | labels_num = len(mlb.classes_) 52 | logger.info(F'Number of Labels: {labels_num}') 53 | logger.info(F'Size of Training Set: {len(train_x)}') 54 | logger.info(F'Size of Validation Set: {len(valid_x)}') 55 | 56 | logger.info('Training') 57 | if 'cluster' not in model_cnf: 58 | train_loader = DataLoader(MultiLabelDataset(train_x, train_y), 59 | model_cnf['train']['batch_size'], shuffle=True, num_workers=4) 60 | valid_loader = DataLoader(MultiLabelDataset(valid_x, valid_y, training=False), 61 | model_cnf['valid']['batch_size'], num_workers=4) 62 | model = Model(network=AttentionRNN, labels_num=labels_num, model_path=model_path, emb_init=emb_init, 63 | **data_cnf['model'], **model_cnf['model']) 64 | model.train(train_loader, valid_loader, **model_cnf['train']) 65 | else: 66 | model = FastAttentionXML(labels_num, data_cnf, model_cnf, tree_id) 67 | model.train(train_x, train_y, valid_x, valid_y, mlb) 68 | logger.info('Finish Training') 69 | 70 | if mode is None or mode == 'eval': 71 | logger.info('Loading Test Set') 72 | mlb = get_mlb(data_cnf['labels_binarizer']) 73 | labels_num = len(mlb.classes_) 74 | test_x, _ = get_data(data_cnf['test']['texts'], None) 75 | logger.info(F'Size of Test Set: {len(test_x)}') 76 | 77 | logger.info('Predicting') 78 | if 'cluster' not in model_cnf: 79 | test_loader = DataLoader(MultiLabelDataset(test_x), model_cnf['predict']['batch_size'], 80 | num_workers=4) 81 | if model is None: 82 | model = Model(network=AttentionRNN, labels_num=labels_num, model_path=model_path, emb_init=emb_init, 83 | **data_cnf['model'], **model_cnf['model']) 84 | scores, labels = model.predict(test_loader, k=model_cnf['predict'].get('k', 100)) 85 | else: 86 | if model is None: 87 | model = FastAttentionXML(labels_num, data_cnf, model_cnf, tree_id) 88 | scores, labels = model.predict(test_x) 89 | logger.info('Finish Predicting') 90 | labels = mlb.classes_[labels] 91 | output_res(data_cnf['output']['res'], F'{model_name}-{data_name}{tree_id}', scores, labels) 92 | 93 | 94 | if __name__ == '__main__': 95 | main() 96 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 3 | """ 4 | 
Created on 2019/1/20 5 | @author yrh 6 | 7 | """ 8 | 9 | import os 10 | import re 11 | import click 12 | import numpy as np 13 | from nltk.tokenize import word_tokenize 14 | from tqdm import tqdm 15 | from logzero import logger 16 | 17 | from deepxml.data_utils import * 18 | 19 | 20 | def tokenize(sentence: str, sep='/SEP/'): 21 | # We add a /SEP/ token between the title and description for datasets such as Amazon. 22 | return [token.lower() if token != sep else token for token in word_tokenize(sentence) 23 | if len(re.sub(r'[^\w]', '', token)) > 0] 24 | 25 | 26 | @click.command() 27 | @click.option('--text-path', type=click.Path(exists=True), help='Path of text.') 28 | @click.option('--tokenized-path', type=click.Path(), default=None, help='Path of tokenized text.') 29 | @click.option('--label-path', type=click.Path(exists=True), default=None, help='Path of labels.') 30 | @click.option('--vocab-path', type=click.Path(), default=None, 31 | help='Path of vocab; if it doesn\'t exist, build one and save it.') 32 | @click.option('--emb-path', type=click.Path(), default=None, help='Path of word embedding.') 33 | @click.option('--w2v-model', type=click.Path(), default=None, help='Path of Gensim Word2Vec Model.') 34 | @click.option('--vocab-size', type=click.INT, default=500000, help='Size of vocab.') 35 | @click.option('--max-len', type=click.INT, default=500, help='Truncated length.') 36 | def main(text_path, tokenized_path, label_path, vocab_path, emb_path, w2v_model, vocab_size, max_len): 37 | if tokenized_path is not None: 38 | logger.info(F'Tokenizing Text. {text_path}') 39 | with open(text_path) as fp, open(tokenized_path, 'w') as fout: 40 | for line in tqdm(fp, desc='Tokenizing'): 41 | print(*tokenize(line), file=fout) 42 | text_path = tokenized_path 43 | 44 | if not os.path.exists(vocab_path): 45 | logger.info(F'Building Vocab. 
{text_path}') 46 | with open(text_path) as fp: 47 | vocab, emb_init = build_vocab(fp, w2v_model, vocab_size=vocab_size) 48 | np.save(vocab_path, vocab) 49 | np.save(emb_path, emb_init) 50 | vocab = {word: i for i, word in enumerate(np.load(vocab_path))} 51 | logger.info(F'Vocab Size: {len(vocab)}') 52 | 53 | logger.info(F'Getting Dataset: {text_path} Max Length: {max_len}') 54 | texts, labels = convert_to_binary(text_path, label_path, max_len, vocab) 55 | logger.info(F'Size of Samples: {len(texts)}') 56 | np.save(os.path.splitext(text_path)[0], texts) 57 | if labels is not None: 58 | assert len(texts) == len(labels) 59 | np.save(os.path.splitext(label_path)[0], labels) 60 | 61 | 62 | if __name__ == '__main__': 63 | main() 64 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click==7.0 2 | ruamel.yaml==0.16.5 3 | numpy==1.16.2 4 | scipy==1.3.1 5 | scikit-learn==0.21.2 6 | gensim==3.4.0 7 | torch==1.0.1 8 | nltk==3.4 9 | tqdm==4.31.1 10 | joblib==0.13.2 11 | logzero==1.5.0 12 | -------------------------------------------------------------------------------- /scripts/run_amazon.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA=Amazon-670K 4 | MODEL=FastAttentionXML 5 | 6 | ./scripts/run_preprocess.sh $DATA 7 | ./scripts/run_xml.sh $DATA $MODEL 8 | 9 | python evaluation.py \ 10 | --results results/$MODEL-$DATA-Ensemble-labels.npy \ 11 | --targets data/$DATA/test_labels.npy \ 12 | --train-labels data/$DATA/train_labels.npy \ 13 | -a 0.6 \ 14 | -b 2.6 15 | -------------------------------------------------------------------------------- /scripts/run_amazon3m.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA=Amazon-3M 4 | MODEL=FastAttentionXML 5 | 6 | ./scripts/run_preprocess.sh $DATA 7 | ./scripts/run_xml.sh $DATA $MODEL 8 | 9 | python evaluation.py \ 10 | --results results/$MODEL-$DATA-Ensemble-labels.npy \ 11 | --targets data/$DATA/test_labels.npy \ 12 | --train-labels data/$DATA/train_labels.npy \ 13 | -a 0.6 \ 14 | -b 2.6 15 | -------------------------------------------------------------------------------- /scripts/run_amazoncat.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA=AmazonCat-13K 4 | MODEL=AttentionXML 5 | 6 | ./scripts/run_preprocess.sh $DATA 7 | ./scripts/run_xml.sh $DATA $MODEL 8 | 9 | python evaluation.py \ 10 | --results results/$MODEL-$DATA-Ensemble-labels.npy \ 11 | --targets data/$DATA/test_labels.npy \ 12 | --train-labels data/$DATA/train_labels.npy 13 | -------------------------------------------------------------------------------- /scripts/run_eurlex.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA=EUR-Lex 4 | MODEL=AttentionXML 5 | 6 | ./scripts/run_preprocess.sh $DATA 7 | ./scripts/run_xml.sh $DATA $MODEL 8 | 9 | python evaluation.py \ 10 | --results results/$MODEL-$DATA-Ensemble-labels.npy \ 11 | --targets data/$DATA/test_labels.npy \ 12 | --train-labels data/$DATA/train_labels.npy 13 | -------------------------------------------------------------------------------- /scripts/run_preprocess.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ $1 == "EUR-Lex" ]; then 4 | 
TRAIN_TEXT="--text-path data/$1/train_texts.txt" 5 | TEST_TEXT="--text-path data/$1/test_texts.txt" 6 | else 7 | TRAIN_TEXT="--text-path data/$1/train_raw_texts.txt --tokenized-path data/$1/train_texts.txt" 8 | TEST_TEXT="--text-path data/$1/test_raw_texts.txt --tokenized-path data/$1/test_texts.txt" 9 | fi 10 | 11 | if [ ! -f data/$1/train_texts.npy ]; then 12 | python preprocess.py $TRAIN_TEXT --label-path data/$1/train_labels.txt --vocab-path data/$1/vocab.npy --emb-path data/$1/emb_init.npy --w2v-model data/glove.840B.300d.gensim 13 | fi 14 | if [ ! -f data/$1/test_texts.npy ]; then 15 | python preprocess.py $TEST_TEXT --label-path data/$1/test_labels.txt --vocab-path data/$1/vocab.npy 16 | fi 17 | -------------------------------------------------------------------------------- /scripts/run_wiki.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA=Wiki-500K 4 | MODEL=FastAttentionXML 5 | 6 | ./scripts/run_preprocess.sh $DATA 7 | ./scripts/run_xml.sh $DATA $MODEL 8 | 9 | python evaluation.py \ 10 | --results results/$MODEL-$DATA-Ensemble-labels.npy \ 11 | --targets data/$DATA/test_labels.npy \ 12 | --train-labels data/$DATA/train_labels.npy \ 13 | -a 0.5 \ 14 | -b 0.4 15 | -------------------------------------------------------------------------------- /scripts/run_wiki10.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | DATA=Wiki10-31K 4 | MODEL=AttentionXML 5 | 6 | ./scripts/run_preprocess.sh $DATA 7 | ./scripts/run_xml.sh $DATA $MODEL 8 | 9 | python evaluation.py \ 10 | --results results/$MODEL-$DATA-Ensemble-labels.npy \ 11 | --targets data/$DATA/test_labels.npy \ 12 | --train-labels data/$DATA/train_labels.npy 13 | -------------------------------------------------------------------------------- /scripts/run_xml.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python main.py --data-cnf configure/datasets/$1.yaml --model-cnf configure/models/$2-$1.yaml -t 0 4 | python main.py --data-cnf configure/datasets/$1.yaml --model-cnf configure/models/$2-$1.yaml -t 1 5 | python main.py --data-cnf configure/datasets/$1.yaml --model-cnf configure/models/$2-$1.yaml -t 2 6 | python ensemble.py -p results/$2-$1 -t 3 7 | --------------------------------------------------------------------------------
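Note on the evaluation commands in the scripts above: the -a/-b options of evaluation.py set the A and B parameters of the inverse label propensity weights used for the PSP@k and PSnDCG@k metrics. Below is a minimal sketch of the standard propensity model of Jain et al. (2016), which the get_inv_propensity helper in deepxml/evaluation.py is assumed to follow; the standalone function and its name here are illustrative only and are not part of the repository:

    import numpy as np
    from scipy.sparse import csr_matrix

    def inv_propensity(train_y: csr_matrix, a: float = 0.55, b: float = 1.5) -> np.ndarray:
        """Per-label inverse propensity weights: 1/p_l = 1 + C * (N_l + B)^(-A)."""
        n = train_y.shape[0]                             # N: number of training samples
        freq = np.asarray(train_y.sum(axis=0)).ravel()   # N_l: training frequency of each label
        c = (np.log(n) - 1) * (b + 1) ** a               # C = (log N - 1) * (B + 1)^A
        return 1.0 + c * (freq + b) ** (-a)

Rare (tail) labels receive larger weights, so PSP@k rewards correct tail-label predictions more than plain P@k does. The dataset-specific values (-a 0.6 -b 2.6 for the Amazon datasets, -a 0.5 -b 0.4 for Wiki-500K) match the parameters commonly used for these benchmarks in the extreme multi-label classification literature.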