├── .gitignore ├── LICENSE ├── README.md ├── deepxml ├── configs │ ├── DeepXML-ANNS │ │ └── EURLex-4K.json │ ├── DeepXML-OVA │ │ └── EURLex-4K.json │ └── DeepXML │ │ ├── Amazon-670K.json │ │ ├── AmazonTitles-3M.json │ │ ├── AmazonTitles-670K.json │ │ ├── EURLex-4K.json │ │ ├── LF-Amazon-131K.json │ │ ├── LF-AmazonTitles-1.3M.json │ │ ├── LF-AmazonTitles-131K.json │ │ ├── LF-WikiSeeAlso-320K.json │ │ ├── LF-WikiSeeAlsoTitles-320K.json │ │ ├── LF-WikiTitles-500K.json │ │ ├── WikiSeeAlsoTitles-350K.json │ │ └── WikiTitles-500K.json ├── libs │ ├── __init__.py │ ├── collate_fn.py │ ├── dataset.py │ ├── dataset_base.py │ ├── dist_utils.py │ ├── features.py │ ├── labels.py │ ├── lookup.py │ ├── loss.py │ ├── model.py │ ├── model_base.py │ ├── optimizer.py │ ├── parameters.py │ ├── parameters_base.py │ ├── sampling.py │ ├── shortlist.py │ ├── shortlist_handler.py │ ├── tracking.py │ └── utils.py ├── main.py ├── models │ ├── __init__.py │ ├── astec.py │ ├── embedding_layer.py │ ├── linear_layer.py │ ├── mlp.py │ ├── network.py │ ├── residual_layer.py │ └── transform_layer.py ├── run_scripts │ ├── Astec.json │ ├── Identity.json │ ├── RNN.json │ ├── run_datasets.sh │ └── run_main.sh ├── runner.py └── tools │ ├── convert_format.pl │ ├── evaluate.py │ ├── evaluate_ensemble.py │ └── surrogate_mapping.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 
91 | #Pipfile.lock
92 |
93 | # celery beat schedule file
94 | celerybeat-schedule
95 |
96 | # SageMath parsed files
97 | *.sage.py
98 |
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 |
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 |
112 | # Rope project settings
113 | .ropeproject
114 |
115 | # mkdocs documentation
116 | /site
117 |
118 | # mypy
119 | .mypy_cache/
120 | .dmypy.json
121 | dmypy.json
122 |
123 | # Pyre type checker
124 | .pyre/
125 | .vscode
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Kunal Dahiya
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepXML
2 |
3 | Code for _DeepXML: A Deep Extreme Multi-Label Learning Framework Applied to Short Text Documents_
4 |
5 | ---
6 |
7 | ## Architectures and algorithms
8 |
9 | DeepXML supports multiple feature architectures such as Bag-of-embedding/Astec, RNN, and CNN. The code uses a json file to construct the feature architecture. Features can be computed using the following encoders:
10 |
11 | * Bag-of-embedding/Astec: As used in the DeepXML paper [1].
12 | * RNN: RNN-based sequential models. Support for RNN, GRU, and LSTM.
13 | * XML-CNN: CNN architecture as proposed in the XML-CNN paper [4].
14 |
15 | ---
16 |
17 | ## Best practices for feature creation
18 |
19 | ---
20 |
21 | * Adding sub-words on top of unigrams to the vocabulary can help in training more accurate embeddings and classifiers.
22 |
23 | ---
24 |
25 | ## Setting up
26 |
27 | ---
28 |
29 | ### Expected directory structure
30 |
31 | ```txt
32 | +-- <work_dir>
33 | |  +-- programs
34 | |  |  +-- deepxml
35 | |  |    +-- deepxml
36 | |  +-- data
37 | |    +-- <dataset>
38 | |  +-- models
39 | |  +-- results
40 |
41 | ```
42 |
43 | ### Download data for Astec
44 |
45 | ```txt
46 | * Download the BoW features (zipped file) from the XML repository.
47 | * Extract the zipped file into the data directory.
48 | * The following files should be available in <work_dir>/data/<dataset> for datasets already in the new format (ignore the next step):
49 |   - trn_X_Xf.txt
50 |   - trn_X_Y.txt
51 |   - tst_X_Xf.txt
52 |   - tst_X_Y.txt
53 |   - fasttextB_embeddings_300d.npy or fasttextB_embeddings_512d.npy
54 | * The following files should be available in <work_dir>/data/<dataset> if the dataset is in the old format (please refer to the next step to convert the data to the new format):
55 |   - train.txt
56 |   - test.txt
57 |   - fasttextB_embeddings_300d.npy or fasttextB_embeddings_512d.npy
58 | ```
59 |
60 | ### Convert to new data format
61 |
62 | ```perl
63 | # A perl script is provided (in deepxml/tools) to convert the data into the new format expected by Astec
64 | # Either set the $data_dir variable to the data directory of a particular dataset or replace it with the path
65 | perl convert_format.pl $data_dir/train.txt $data_dir/trn_X_Xf.txt $data_dir/trn_X_Y.txt
66 | perl convert_format.pl $data_dir/test.txt $data_dir/tst_X_Xf.txt $data_dir/tst_X_Y.txt
67 | ```
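### Sanity-check the converted files (optional)

The sketch below is a minimal, illustrative check and is not part of the toolkit: it assumes the converted files follow the standard XC repository sparse-text layout (a `num_rows num_cols` header followed by one line of space-separated `index:value` pairs per document) and the file names produced by the conversion step above.

```python
# Minimal sketch (assumed layout: "num_rows num_cols" header, then
# "idx:val idx:val ..." per row) to verify the converted files.
from scipy.sparse import csr_matrix


def read_sparse_text(path):
    """Read a converted file (e.g., trn_X_Xf.txt or trn_X_Y.txt) as a CSR matrix."""
    with open(path) as fp:
        num_rows, num_cols = map(int, fp.readline().split())
        indptr, indices, data = [0], [], []
        for line in fp:
            for pair in line.split():
                idx, val = pair.rsplit(':', 1)
                indices.append(int(idx))
                data.append(float(val))
            indptr.append(len(indices))
    return csr_matrix((data, indices, indptr), shape=(num_rows, num_cols))


if __name__ == "__main__":
    X = read_sparse_text("trn_X_Xf.txt")  # BoW features
    Y = read_sparse_text("trn_X_Y.txt")   # ground-truth labels
    print("features:", X.shape, "labels:", Y.shape)
    print("avg. features/doc: %.1f, avg. labels/doc: %.1f"
          % (X.getnnz(axis=1).mean(), Y.getnnz(axis=1).mean()))
```

The pyxclib library [2] also provides utilities to read and write files in this format, so the same check can be done with its data utilities.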
68 |
69 | ## Example use cases
70 |
71 | ---
72 |
73 | ### A single learner with DeepXML framework
74 |
75 | The DeepXML framework can be utilized as follows. A json file is used to specify the architecture and other arguments. Please refer to the full documentation below for more details.
76 |
77 | ```bash
78 | ./run_main.sh 0 DeepXML EURLex-4K 0 108
79 | ```
80 |
81 | ### An ensemble of multiple learners with DeepXML framework
82 |
83 | An ensemble can be trained as follows. A json file is used to specify the architecture and other arguments.
84 |
85 | ```bash
86 | ./run_main.sh 0 DeepXML EURLex-4K 0 108,666,786
87 | ```
88 |
89 | ## Full Documentation
90 |
91 | ```txt
92 | ./run_main.sh <gpu_id> <framework> <dataset> <version> <seed>
93 |
94 | * gpu_id: Run the program on this GPU.
95 |
96 | * framework
97 |   - DeepXML: Divides the XML problem into 4 modules as proposed in the paper.
98 |   - DeepXML-OVA: Trains the architecture in 1-vs-all fashion [4][5], i.e., loss is computed for each label in each iteration.
99 |   - DeepXML-ANNS: Trains the architecture using a label shortlist. Support is available for a fixed graph or periodic training of the ANNS graph.
100 |
101 | * dataset
102 |   - Name of the dataset.
103 |   - Astec expects the following files in <work_dir>/data/<dataset>:
104 |     - trn_X_Xf.txt
105 |     - trn_X_Y.txt
106 |     - tst_X_Xf.txt
107 |     - tst_X_Y.txt
108 |     - fasttextB_embeddings_300d.npy or fasttextB_embeddings_512d.npy
109 |   - You can set 'embedding_dims' in the config file to switch between 300d and 512d embeddings.
110 |
111 | * version
112 |   - Different runs can be managed via version and seed.
113 |   - Models and results are stored using this argument.
114 |
115 | * seed
116 |   - Seed value as used by numpy and PyTorch.
117 |   - An ensemble is learned if multiple comma-separated values are passed.
118 | ```
119 |
120 | ### Notes
121 |
122 | ```txt
123 | * Other file formats such as npy, npz, and pickle are also supported.
124 | * Initializing with token embeddings (computed from FastText) leads to a noticeable accuracy gain in Astec. Please ensure that the token embedding file is available in the data directory if 'init=token_embeddings'; otherwise an error will be thrown.
125 | * Config files are made available in deepxml/configs/<framework>/<dataset>.json for datasets in the XC repository. You can use them when trying out Astec/DeepXML on new datasets.
126 | * We conducted our experiments on a 24-core Intel Xeon 2.6 GHz machine with 440GB RAM and a single Nvidia P40 GPU. 128GB memory should suffice for most datasets.
127 | * Astec makes use of the CPU (mainly for nmslib) as well as the GPU.
128 | ``` 129 | 130 | ## Cite as 131 | 132 | ```bib 133 | @InProceedings{Dahiya21, 134 | author = "Dahiya, K. and Saini, D. and Mittal, A. and Shaw, A. and Dave, K. and Soni, A. and Jain, H. and Agarwal, S. and Varma, M.", 135 | title = "DeepXML: A Deep Extreme Multi-Label Learning Framework Applied to Short Text Documents", 136 | booktitle = "Proceedings of the ACM International Conference on Web Search and Data Mining", 137 | month = "March", 138 | year = "2021" 139 | } 140 | ``` 141 | 142 | 143 | ## YOU MAY ALSO LIKE 144 | - [DECAF: Deep Extreme Classification with Label Features](https://github.com/Extreme-classification/DECAF) 145 | - [GalaXC: Graph Neural Networks with Labelwise Attention for Extreme Classification](https://github.com/Extreme-classification/GalaXC) 146 | - [ECLARE: Extreme Classification with Label Graph Correlations](https://github.com/Extreme-classification/ECLARE) 147 | 148 | ## References 149 | 150 | --- 151 | [1] K. Dahiya, D. Saini, A. Mittal, A. Shaw, K. Dave, A. Soni, H. Jain, S. Agarwal, and M. Varma. Deepxml: A deep extreme multi-label learning framework applied to short text documents. In WSDM, 2021. 152 | 153 | [2] pyxclib: 154 | 155 | [3] H. Jain, V. Balasubramanian, B. Chunduri and M. Varma, Slice: Scalable linear extreme classifiers trained on 100 million labels for related searches, In WSDM 2019. 156 | 157 | [4] J. Liu, W.-C. Chang, Y. Wu and Y. Yang, XML-CNN: Deep Learning for Extreme Multi-label Text Classification, In SIGIR 2017. 158 | 159 | [5] R. Babbar, and B. Schölkopf, DiSMEC - Distributed Sparse Machines for Extreme Multi-label Classification In WSDM, 2017. 160 | 161 | [6] P., Bojanowski, E. Grave, A. Joulin, and T. Mikolov. Enriching word vectors with subword information. In TACL, 2017. 162 | -------------------------------------------------------------------------------- /deepxml/configs/DeepXML-ANNS/EURLex-4K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "EURLex-4K", 4 | "feature_type": "sparse", 5 | "num_labels": 3993, 6 | "arch": "Astec", 7 | "A": 0.55, 8 | "B": 1.5, 9 | "use_reranker": true, 10 | "beta": 0.5, 11 | "surrogate_threshold": 0, 12 | "surrogate_method": 0, 13 | "embedding_dims": 300, 14 | "top_k": 100, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": "tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | "extreme": { 24 | "num_epochs": 30, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.007, 27 | "batch_size": 255, 28 | "dlr_step": 14, 29 | "ns_method": "ensemble", 30 | "num_centroids": 1, 31 | "efC": 300, 32 | "efS": 400, 33 | "M": 100, 34 | "num_nbrs": 500, 35 | "ann_threads": 18, 36 | "beta": 0.5, 37 | "retrain_hnsw_after": 5, 38 | "use_intermediate_for_shorty": true, 39 | "update_shortlist": true, 40 | "surrogate_mapping": null, 41 | "num_clf_partitions": 1, 42 | "optim": "Adam", 43 | "freeze_intermediate": false, 44 | "validate": true, 45 | "model_method": "shortlist", 46 | "normalize": true, 47 | "shortlist_method": "hybrid", 48 | "init": "token_embeddings", 49 | "use_shortlist": true, 50 | "embeddings": "fasttextB_embeddings_300d.npy" 51 | } 52 | } -------------------------------------------------------------------------------- /deepxml/configs/DeepXML-OVA/EURLex-4K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 
3 | "dataset": "EURLex-4K", 4 | "feature_type": "sparse", 5 | "num_labels": 3993, 6 | "arch": "Astec", 7 | "A": 0.55, 8 | "B": 1.5, 9 | "use_reranker": true, 10 | "surrogate_threshold": 0, 11 | "surrogate_method": 0, 12 | "embedding_dims": 300, 13 | "top_k": 100, 14 | "save_predictions": true, 15 | "trn_label_fname": "trn_X_Y.txt", 16 | "val_label_fname": "tst_X_Y.txt", 17 | "tst_label_fname": "tst_X_Y.txt", 18 | "trn_feat_fname": "trn_X_Xf.txt", 19 | "val_feat_fname": "tst_X_Xf.txt", 20 | "tst_feat_fname": "tst_X_Xf.txt" 21 | }, 22 | "extreme": { 23 | "num_epochs": 30, 24 | "dlr_factor": 0.5, 25 | "learning_rate": 0.01, 26 | "batch_size": 255, 27 | "dlr_step": 14, 28 | "surrogate_mapping": null, 29 | "num_clf_partitions": 1, 30 | "optim": "Adam", 31 | "freeze_intermediate": false, 32 | "validate": true, 33 | "model_method": "full", 34 | "normalize": true, 35 | "init": "token_embeddings", 36 | "use_shortlist": false, 37 | "embeddings": "fasttextB_embeddings_300d.npy" 38 | } 39 | } -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/Amazon-670K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "Amazon-670K", 4 | "feature_type": "sparse", 5 | "num_labels": 670091, 6 | "arch": "Astec", 7 | "A": 0.6, 8 | "B": 2.6, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 512, 13 | "top_k": 250, 14 | "beta": 0.10, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": "tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | "surrogate": { 24 | "num_epochs": 20, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.02, 27 | "batch_size": 255, 28 | "dlr_step": 14, 29 | "normalize": true, 30 | "optim": "Adam", 31 | "init": "token_embeddings", 32 | "embeddings": "fasttextB_embeddings_512d.npy", 33 | "validate": true, 34 | "save_intermediate": true 35 | }, 36 | "extreme": { 37 | "num_epochs": 20, 38 | "dlr_factor": 0.5, 39 | "learning_rate": 0.0005, 40 | "batch_size": 255, 41 | "dlr_step": 14, 42 | "ns_method": "ensemble", 43 | "num_centroids": 1, 44 | "efC": 300, 45 | "efS": 400, 46 | "M": 100, 47 | "num_nbrs": 500, 48 | "ann_threads": 18, 49 | "beta": 0.5, 50 | "surrogate_mapping": null, 51 | "num_clf_partitions": 1, 52 | "optim": "Adam", 53 | "freeze_intermediate": true, 54 | "validate": true, 55 | "model_method": "shortlist", 56 | "normalize": true, 57 | "shortlist_method": "hybrid", 58 | "init": "intermediate", 59 | "use_shortlist": true, 60 | "use_intermediate_for_shorty": true 61 | }, 62 | "reranker": { 63 | "num_epochs": 15, 64 | "dlr_factor": 0.5, 65 | "learning_rate": 0.001, 66 | "batch_size": 255, 67 | "dlr_step": 8, 68 | "beta": 0.5, 69 | "num_clf_partitions": 1, 70 | "optim": "Adam", 71 | "validate": true, 72 | "model_method": "reranker", 73 | "shortlist_method": "static", 74 | "surrogate_mapping": null, 75 | "normalize": true, 76 | "use_shortlist": true, 77 | "init": "token_embeddings", 78 | "save_intermediate": false, 79 | "keep_invalid": true, 80 | "freeze_intermediate": false, 81 | "update_shortlist": false, 82 | "use_pretrained_shortlist": true, 83 | "embeddings": "fasttextB_embeddings_512d.npy" 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/AmazonTitles-3M.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "AmazonTitles-3M", 4 | "feature_type": "sparse", 5 | "num_labels": 2812281, 6 | "arch": "Astec", 7 | "A": 0.6, 8 | "B": 2.6, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 300, 13 | "beta": 0.10, 14 | "top_k": 300, 15 | "save_top_k": 100, 16 | "save_predictions": true, 17 | "trn_label_fname": "trn_X_Y.txt", 18 | "val_label_fname": "tst_X_Y.txt", 19 | "tst_label_fname": "tst_X_Y.txt", 20 | "trn_feat_fname": "trn_X_Xf.txt", 21 | "val_feat_fname": "tst_X_Xf.txt", 22 | "tst_feat_fname": "tst_X_Xf.txt" 23 | }, 24 | "surrogate": { 25 | "num_epochs": 20, 26 | "dlr_factor": 0.5, 27 | "learning_rate": 0.003, 28 | "batch_size": 255, 29 | "dlr_step": 14, 30 | "normalize": true, 31 | "optim": "Adam", 32 | "init": "token_embeddings", 33 | "embeddings": "fasttextB_embeddings_300d.npy", 34 | "validate": true, 35 | "save_intermediate": true 36 | }, 37 | "extreme": { 38 | "num_epochs": 15, 39 | "dlr_factor": 0.5, 40 | "learning_rate": 0.0005, 41 | "batch_size": 255, 42 | "dlr_step": 14, 43 | "ns_method": "ensemble", 44 | "num_centroids": 300, 45 | "efC": 300, 46 | "efS": 400, 47 | "M": 100, 48 | "num_nbrs": 500, 49 | "ann_threads": 18, 50 | "beta": 0.5, 51 | "surrogate_mapping": null, 52 | "num_clf_partitions": 1, 53 | "optim": "Adam", 54 | "freeze_intermediate": true, 55 | "validate": true, 56 | "model_method": "shortlist", 57 | "normalize": true, 58 | "shortlist_method": "hybrid", 59 | "init": "intermediate", 60 | "use_shortlist": true, 61 | "use_intermediate_for_shorty": true 62 | }, 63 | "reranker": { 64 | "num_epochs": 10, 65 | "dlr_factor": 0.5, 66 | "learning_rate": 0.002, 67 | "batch_size": 255, 68 | "dlr_step": 7, 69 | "beta": 0.6, 70 | "num_clf_partitions": 1, 71 | "optim": "Adam", 72 | "validate": true, 73 | "model_method": "reranker", 74 | "shortlist_method": "static", 75 | "surrogate_mapping": null, 76 | "normalize": true, 77 | "use_shortlist": true, 78 | "init": "token_embeddings", 79 | "save_intermediate": false, 80 | "keep_invalid": true, 81 | "freeze_intermediate": false, 82 | "update_shortlist": false, 83 | "use_pretrained_shortlist": true, 84 | "embeddings": "fasttextB_embeddings_300d.npy" 85 | } 86 | } -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/AmazonTitles-670K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "AmazonTitles-670K", 4 | "feature_type": "sparse", 5 | "num_labels": 670091, 6 | "arch": "Astec", 7 | "A": 0.6, 8 | "B": 2.6, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 300, 13 | "top_k": 200, 14 | "beta": 0.10, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": "tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | "surrogate": { 24 | "num_epochs": 20, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.02, 27 | "batch_size": 255, 28 | "dlr_step": 10, 29 | "normalize": true, 30 | "optim": "Adam", 31 | "init": "token_embeddings", 32 | "embeddings": "fasttextB_embeddings_300d.npy", 33 | "validate": true, 34 | "save_intermediate": true 35 | }, 36 | "extreme": { 37 | "num_epochs": 15, 38 | "dlr_factor": 0.5, 39 | "learning_rate": 0.001, 
40 | "batch_size": 255, 41 | "dlr_step": 14, 42 | "ns_method": "ensemble", 43 | "num_centroids": 1, 44 | "efC": 300, 45 | "efS": 400, 46 | "M": 100, 47 | "num_nbrs": 500, 48 | "ann_threads": 18, 49 | "beta": 0.5, 50 | "surrogate_mapping": null, 51 | "num_clf_partitions": 1, 52 | "optim": "Adam", 53 | "freeze_intermediate": true, 54 | "validate": true, 55 | "model_method": "shortlist", 56 | "normalize": true, 57 | "shortlist_method": "hybrid", 58 | "init": "intermediate", 59 | "use_shortlist": true, 60 | "use_intermediate_for_shorty": true 61 | }, 62 | "reranker": { 63 | "num_epochs": 12, 64 | "dlr_factor": 0.5, 65 | "learning_rate": 0.002, 66 | "batch_size": 255, 67 | "dlr_step": 8, 68 | "beta": 0.6, 69 | "num_clf_partitions": 1, 70 | "optim": "Adam", 71 | "validate": true, 72 | "model_method": "reranker", 73 | "shortlist_method": "static", 74 | "surrogate_mapping": null, 75 | "normalize": true, 76 | "use_shortlist": true, 77 | "init": "token_embeddings", 78 | "save_intermediate": false, 79 | "keep_invalid": true, 80 | "freeze_intermediate": false, 81 | "update_shortlist": false, 82 | "use_pretrained_shortlist": true, 83 | "embeddings": "fasttextB_embeddings_300d.npy" 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/EURLex-4K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "EURLex-4K", 4 | "feature_type": "sparse", 5 | "num_labels": 3993, 6 | "arch": "Astec", 7 | "A": 0.55, 8 | "B": 1.5, 9 | "use_reranker": true, 10 | "surrogate_threshold": 1024, 11 | "surrogate_method": 1, 12 | "embedding_dims": 300, 13 | "top_k": 100, 14 | "beta": 0.3, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": "tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | "surrogate": { 24 | "num_epochs": 20, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.01, 27 | "batch_size": 255, 28 | "dlr_step": 14, 29 | "normalize": true, 30 | "init": "token_embeddings", 31 | "optim": "Adam", 32 | "embeddings": "fasttextB_embeddings_300d.npy", 33 | "validate": true, 34 | "save_intermediate": true 35 | }, 36 | "extreme": { 37 | "num_epochs": 20, 38 | "dlr_factor": 0.5, 39 | "learning_rate": 0.007, 40 | "batch_size": 255, 41 | "dlr_step": 14, 42 | "ns_method": "ensemble", 43 | "num_centroids": 1, 44 | "efC": 300, 45 | "efS": 400, 46 | "M": 100, 47 | "num_nbrs": 500, 48 | "ann_threads": 18, 49 | "beta": 0.5, 50 | "surrogate_mapping": null, 51 | "num_clf_partitions": 1, 52 | "optim": "Adam", 53 | "freeze_intermediate": true, 54 | "validate": true, 55 | "model_method": "shortlist", 56 | "normalize": true, 57 | "shortlist_method": "hybrid", 58 | "init": "intermediate", 59 | "use_shortlist": true, 60 | "use_intermediate_for_shorty": true 61 | }, 62 | "reranker": { 63 | "num_epochs": 15, 64 | "dlr_factor": 0.5, 65 | "learning_rate": 0.005, 66 | "batch_size": 255, 67 | "dlr_step": 10, 68 | "beta": 0.6, 69 | "num_clf_partitions": 1, 70 | "optim": "Adam", 71 | "validate": true, 72 | "model_method": "reranker", 73 | "shortlist_method": "static", 74 | "surrogate_mapping": null, 75 | "normalize": true, 76 | "use_shortlist": true, 77 | "init": "token_embeddings", 78 | "save_intermediate": false, 79 | "keep_invalid": true, 80 | "freeze_intermediate": false, 81 | "update_shortlist": false, 82 | "use_pretrained_shortlist": true, 83 
| "embeddings": "fasttextB_embeddings_300d.npy" 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/LF-Amazon-131K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "LF-Amazon-131K", 4 | "feature_type": "sparse", 5 | "num_labels": 131073, 6 | "arch": "Astec", 7 | "A": 0.6, 8 | "B": 2.6, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 512, 13 | "top_k": 250, 14 | "beta": 0.10, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": "tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | "surrogate": { 24 | "num_epochs": 20, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.02, 27 | "batch_size": 255, 28 | "dlr_step": 14, 29 | "normalize": true, 30 | "optim": "Adam", 31 | "init": "token_embeddings", 32 | "embeddings": "fasttextB_embeddings_512d.npy", 33 | "validate": true, 34 | "save_intermediate": true 35 | }, 36 | "extreme": { 37 | "num_epochs": 20, 38 | "dlr_factor": 0.5, 39 | "learning_rate": 0.0005, 40 | "batch_size": 255, 41 | "dlr_step": 14, 42 | "ns_method": "ensemble", 43 | "num_centroids": 1, 44 | "efC": 300, 45 | "efS": 400, 46 | "M": 100, 47 | "num_nbrs": 500, 48 | "ann_threads": 18, 49 | "beta": 0.5, 50 | "surrogate_mapping": null, 51 | "num_clf_partitions": 1, 52 | "optim": "Adam", 53 | "freeze_intermediate": true, 54 | "validate": true, 55 | "model_method": "shortlist", 56 | "normalize": true, 57 | "shortlist_method": "hybrid", 58 | "init": "intermediate", 59 | "use_shortlist": true, 60 | "use_intermediate_for_shorty": true 61 | }, 62 | "reranker": { 63 | "num_epochs": 15, 64 | "dlr_factor": 0.5, 65 | "learning_rate": 0.001, 66 | "batch_size": 255, 67 | "dlr_step": 8, 68 | "beta": 0.5, 69 | "num_clf_partitions": 1, 70 | "optim": "Adam", 71 | "validate": true, 72 | "model_method": "reranker", 73 | "shortlist_method": "static", 74 | "surrogate_mapping": null, 75 | "normalize": true, 76 | "use_shortlist": true, 77 | "init": "token_embeddings", 78 | "save_intermediate": false, 79 | "keep_invalid": true, 80 | "freeze_intermediate": false, 81 | "update_shortlist": false, 82 | "use_pretrained_shortlist": true, 83 | "embeddings": "fasttextB_embeddings_512d.npy" 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/LF-AmazonTitles-1.3M.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "LF-AmazonTitles-1.3M", 4 | "feature_type": "sparse", 5 | "num_labels": 1305265, 6 | "arch": "Astec", 7 | "A": 0.6, 8 | "B": 2.6, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 300, 13 | "top_k": 350, 14 | "save_top_k": 100, 15 | "beta": 0.60, 16 | "save_predictions": true, 17 | "trn_label_fname": "trn_X_Y.txt", 18 | "val_label_fname": "tst_X_Y.txt", 19 | "tst_label_fname": "tst_X_Y.txt", 20 | "trn_feat_fname": "trn_X_Xf.txt", 21 | "val_feat_fname": "tst_X_Xf.txt", 22 | "tst_feat_fname": "tst_X_Xf.txt" 23 | }, 24 | "surrogate": { 25 | "num_epochs": 20, 26 | "dlr_factor": 0.5, 27 | "learning_rate": 0.003, 28 | "batch_size": 255, 29 | "dlr_step": 14, 30 | "normalize": true, 31 | "optim": "Adam", 32 | "init": "token_embeddings", 33 | 
"embeddings": "fasttextB_embeddings_300d.npy", 34 | "validate": true, 35 | "save_intermediate": true 36 | }, 37 | "extreme": { 38 | "num_epochs": 15, 39 | "dlr_factor": 0.5, 40 | "learning_rate": 0.0005, 41 | "batch_size": 255, 42 | "dlr_step": 14, 43 | "ns_method": "ensemble", 44 | "num_centroids": 300, 45 | "efC": 300, 46 | "efS": 400, 47 | "M": 100, 48 | "num_nbrs": 500, 49 | "ann_threads": 18, 50 | "beta": 0.5, 51 | "surrogate_mapping": null, 52 | "num_clf_partitions": 1, 53 | "optim": "Adam", 54 | "freeze_intermediate": true, 55 | "validate": true, 56 | "model_method": "shortlist", 57 | "normalize": true, 58 | "shortlist_method": "hybrid", 59 | "init": "intermediate", 60 | "use_shortlist": true, 61 | "use_intermediate_for_shorty": true 62 | }, 63 | "reranker": { 64 | "num_epochs": 10, 65 | "dlr_factor": 0.5, 66 | "learning_rate": 0.002, 67 | "batch_size": 255, 68 | "dlr_step": 7, 69 | "beta": 0.5, 70 | "num_clf_partitions": 1, 71 | "optim": "Adam", 72 | "validate": true, 73 | "model_method": "reranker", 74 | "shortlist_method": "static", 75 | "surrogate_mapping": null, 76 | "normalize": true, 77 | "use_shortlist": true, 78 | "init": "token_embeddings", 79 | "save_intermediate": false, 80 | "keep_invalid": true, 81 | "freeze_intermediate": false, 82 | "update_shortlist": false, 83 | "use_pretrained_shortlist": true, 84 | "embeddings": "fasttextB_embeddings_300d.npy" 85 | } 86 | } 87 | -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/LF-AmazonTitles-131K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "LF-AmazonTitles-131K", 4 | "feature_type": "sparse", 5 | "num_labels": 131073, 6 | "arch": "Astec", 7 | "A": 0.6, 8 | "B": 2.6, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 300, 13 | "top_k": 200, 14 | "beta": 0.10, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": "tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | "surrogate": { 24 | "num_epochs": 20, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.03, 27 | "batch_size": 255, 28 | "dlr_step": 14, 29 | "normalize": true, 30 | "optim": "Adam", 31 | "init": "token_embeddings", 32 | "embeddings": "fasttextB_embeddings_300d.npy", 33 | "validate": true, 34 | "save_intermediate": true 35 | }, 36 | "extreme": { 37 | "num_epochs": 15, 38 | "dlr_factor": 0.5, 39 | "learning_rate": 0.001, 40 | "batch_size": 255, 41 | "dlr_step": 14, 42 | "ns_method": "ensemble", 43 | "num_centroids": 1, 44 | "efC": 300, 45 | "efS": 400, 46 | "M": 100, 47 | "num_nbrs": 500, 48 | "ann_threads": 18, 49 | "beta": 0.5, 50 | "surrogate_mapping": null, 51 | "num_clf_partitions": 1, 52 | "optim": "Adam", 53 | "freeze_intermediate": true, 54 | "validate": true, 55 | "model_method": "shortlist", 56 | "normalize": true, 57 | "shortlist_method": "hybrid", 58 | "init": "intermediate", 59 | "use_shortlist": true, 60 | "use_intermediate_for_shorty": true 61 | }, 62 | "reranker": { 63 | "num_epochs": 10, 64 | "dlr_factor": 0.5, 65 | "learning_rate": 0.002, 66 | "batch_size": 255, 67 | "dlr_step": 8, 68 | "beta": 0.6, 69 | "num_clf_partitions": 1, 70 | "optim": "Adam", 71 | "validate": true, 72 | "model_method": "reranker", 73 | "shortlist_method": "static", 74 | "surrogate_mapping": null, 75 | "normalize": true, 
76 | "use_shortlist": true, 77 | "init": "token_embeddings", 78 | "save_intermediate": false, 79 | "keep_invalid": true, 80 | "freeze_intermediate": false, 81 | "update_shortlist": false, 82 | "use_pretrained_shortlist": true, 83 | "embeddings": "fasttextB_embeddings_300d.npy" 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/LF-WikiSeeAlso-320K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "LF-WikiSeeAlso-320K", 4 | "feature_type": "sparse", 5 | "num_labels": 312330, 6 | "arch": "Astec", 7 | "A": 0.55, 8 | "B": 1.5, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 512, 13 | "top_k": 250, 14 | "beta": 0.30, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": "tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | "surrogate": { 24 | "num_epochs": 20, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.005, 27 | "batch_size": 255, 28 | "dlr_step": 14, 29 | "normalize": true, 30 | "optim": "Adam", 31 | "init": "token_embeddings", 32 | "embeddings": "fasttextB_embeddings_512d.npy", 33 | "validate": true, 34 | "save_intermediate": true 35 | }, 36 | "extreme": { 37 | "num_epochs": 15, 38 | "dlr_factor": 0.5, 39 | "learning_rate": 0.002, 40 | "batch_size": 255, 41 | "dlr_step": 14, 42 | "ns_method": "ensemble", 43 | "num_centroids": 1, 44 | "efC": 300, 45 | "efS": 400, 46 | "M": 100, 47 | "num_nbrs": 500, 48 | "ann_threads": 18, 49 | "beta": 0.5, 50 | "surrogate_mapping": null, 51 | "num_clf_partitions": 1, 52 | "optim": "Adam", 53 | "freeze_intermediate": true, 54 | "validate": true, 55 | "model_method": "shortlist", 56 | "normalize": true, 57 | "shortlist_method": "hybrid", 58 | "init": "intermediate", 59 | "use_shortlist": true, 60 | "use_intermediate_for_shorty": true 61 | }, 62 | "reranker": { 63 | "num_epochs": 12, 64 | "dlr_factor": 0.5, 65 | "learning_rate": 0.002, 66 | "batch_size": 255, 67 | "dlr_step": 7, 68 | "beta": 0.6, 69 | "num_clf_partitions": 1, 70 | "optim": "Adam", 71 | "validate": true, 72 | "model_method": "reranker", 73 | "shortlist_method": "static", 74 | "surrogate_mapping": null, 75 | "normalize": true, 76 | "use_shortlist": true, 77 | "init": "token_embeddings", 78 | "save_intermediate": false, 79 | "keep_invalid": true, 80 | "freeze_intermediate": false, 81 | "update_shortlist": false, 82 | "use_pretrained_shortlist": true, 83 | "embeddings": "fasttextB_embeddings_512d.npy" 84 | } 85 | } -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/LF-WikiSeeAlsoTitles-320K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "LF-WikiSeeAlsoTitles-320K", 4 | "feature_type": "sparse", 5 | "num_labels": 312330, 6 | "arch": "Astec", 7 | "A": 0.55, 8 | "B": 1.5, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 300, 13 | "top_k": 200, 14 | "beta": 0.10, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": "tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | 
"surrogate": { 24 | "num_epochs": 20, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.005, 27 | "batch_size": 255, 28 | "dlr_step": 14, 29 | "normalize": true, 30 | "optim": "Adam", 31 | "init": "token_embeddings", 32 | "embeddings": "fasttextB_embeddings_300d.npy", 33 | "validate": true, 34 | "save_intermediate": true 35 | }, 36 | "extreme": { 37 | "num_epochs": 15, 38 | "dlr_factor": 0.5, 39 | "learning_rate": 0.002, 40 | "batch_size": 255, 41 | "dlr_step": 14, 42 | "ns_method": "ensemble", 43 | "num_centroids": 1, 44 | "efC": 300, 45 | "efS": 400, 46 | "M": 100, 47 | "num_nbrs": 500, 48 | "ann_threads": 18, 49 | "beta": 0.5, 50 | "surrogate_mapping": null, 51 | "num_clf_partitions": 1, 52 | "optim": "Adam", 53 | "freeze_intermediate": true, 54 | "validate": true, 55 | "model_method": "shortlist", 56 | "normalize": true, 57 | "shortlist_method": "hybrid", 58 | "init": "intermediate", 59 | "use_shortlist": true, 60 | "use_intermediate_for_shorty": true 61 | }, 62 | "reranker": { 63 | "num_epochs": 10, 64 | "dlr_factor": 0.5, 65 | "learning_rate": 0.002, 66 | "batch_size": 255, 67 | "dlr_step": 8, 68 | "beta": 0.6, 69 | "num_clf_partitions": 1, 70 | "optim": "Adam", 71 | "validate": true, 72 | "model_method": "reranker", 73 | "shortlist_method": "static", 74 | "surrogate_mapping": null, 75 | "normalize": true, 76 | "use_shortlist": true, 77 | "init": "token_embeddings", 78 | "save_intermediate": false, 79 | "keep_invalid": true, 80 | "freeze_intermediate": false, 81 | "update_shortlist": false, 82 | "use_pretrained_shortlist": true, 83 | "embeddings": "fasttextB_embeddings_300d.npy" 84 | } 85 | } -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/LF-WikiTitles-500K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "LF-WikiTitles-500K", 4 | "feature_type": "sparse", 5 | "num_labels": 501070, 6 | "arch": "Astec", 7 | "A": 0.5, 8 | "B": 0.4, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 300, 13 | "top_k": 300, 14 | "beta": 0.60, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": "tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | "surrogate": { 24 | "num_epochs": 20, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.005, 27 | "batch_size": 255, 28 | "dlr_step": 14, 29 | "normalize": true, 30 | "optim": "Adam", 31 | "init": "token_embeddings", 32 | "embeddings": "fasttextB_embeddings_300d.npy", 33 | "save_intermediate": true 34 | }, 35 | "extreme": { 36 | "num_epochs": 15, 37 | "dlr_factor": 0.5, 38 | "learning_rate": 0.0005, 39 | "batch_size": 255, 40 | "dlr_step": 10, 41 | "ns_method": "ensemble", 42 | "num_centroids": 300, 43 | "efC": 300, 44 | "efS": 400, 45 | "M": 100, 46 | "num_nbrs": 500, 47 | "ann_threads": 12, 48 | "beta": 0.5, 49 | "surrogate_mapping": null, 50 | "num_clf_partitions": 1, 51 | "optim": "Adam", 52 | "freeze_intermediate": true, 53 | "validate": true, 54 | "model_method": "shortlist", 55 | "normalize": true, 56 | "shortlist_method": "hybrid", 57 | "init": "intermediate", 58 | "use_shortlist": true, 59 | "use_intermediate_for_shorty": true 60 | }, 61 | "reranker": { 62 | "num_epochs": 10, 63 | "dlr_factor": 0.5, 64 | "learning_rate": 0.002, 65 | "batch_size": 255, 66 | "dlr_step": 8, 67 | "beta": 0.6, 68 | 
"num_clf_partitions": 1, 69 | "optim": "Adam", 70 | "validate": true, 71 | "model_method": "reranker", 72 | "shortlist_method": "static", 73 | "surrogate_mapping": null, 74 | "normalize": true, 75 | "use_shortlist": true, 76 | "init": "token_embeddings", 77 | "save_intermediate": false, 78 | "keep_invalid": true, 79 | "freeze_intermediate": false, 80 | "update_shortlist": false, 81 | "use_pretrained_shortlist": true, 82 | "embeddings": "fasttextB_embeddings_300d.npy" 83 | } 84 | } -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/WikiSeeAlsoTitles-350K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "WikiSeeAlsoTitles-350K", 4 | "feature_type": "sparse", 5 | "num_labels": 352072, 6 | "arch": "Astec", 7 | "A": 0.55, 8 | "B": 1.5, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 300, 13 | "top_k": 200, 14 | "beta": 0.30, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": "tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | "surrogate": { 24 | "num_epochs": 20, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.005, 27 | "batch_size": 255, 28 | "dlr_step": 14, 29 | "normalize": true, 30 | "optim": "Adam", 31 | "init": "token_embeddings", 32 | "embeddings": "fasttextB_embeddings_300d.npy", 33 | "validate": true, 34 | "save_intermediate": true 35 | }, 36 | "extreme": { 37 | "num_epochs": 15, 38 | "dlr_factor": 0.5, 39 | "learning_rate": 0.002, 40 | "batch_size": 255, 41 | "dlr_step": 14, 42 | "ns_method": "ensemble", 43 | "num_centroids": 1, 44 | "efC": 300, 45 | "efS": 400, 46 | "M": 100, 47 | "num_nbrs": 500, 48 | "ann_threads": 18, 49 | "beta": 0.5, 50 | "surrogate_mapping": null, 51 | "num_clf_partitions": 1, 52 | "optim": "Adam", 53 | "freeze_intermediate": true, 54 | "validate": true, 55 | "model_method": "shortlist", 56 | "normalize": true, 57 | "shortlist_method": "hybrid", 58 | "init": "intermediate", 59 | "use_shortlist": true, 60 | "use_intermediate_for_shorty": true 61 | }, 62 | "reranker": { 63 | "num_epochs": 12, 64 | "dlr_factor": 0.5, 65 | "learning_rate": 0.002, 66 | "batch_size": 255, 67 | "dlr_step": 7, 68 | "beta": 0.6, 69 | "num_clf_partitions": 1, 70 | "optim": "Adam", 71 | "validate": true, 72 | "model_method": "reranker", 73 | "shortlist_method": "static", 74 | "surrogate_mapping": null, 75 | "normalize": true, 76 | "use_shortlist": true, 77 | "init": "token_embeddings", 78 | "save_intermediate": false, 79 | "keep_invalid": true, 80 | "freeze_intermediate": false, 81 | "update_shortlist": false, 82 | "use_pretrained_shortlist": true, 83 | "embeddings": "fasttextB_embeddings_300d.npy" 84 | } 85 | } -------------------------------------------------------------------------------- /deepxml/configs/DeepXML/WikiTitles-500K.json: -------------------------------------------------------------------------------- 1 | { 2 | "global": { 3 | "dataset": "WikiTitles-500K", 4 | "feature_type": "sparse", 5 | "num_labels": 501070, 6 | "arch": "Astec", 7 | "A": 0.5, 8 | "B": 0.4, 9 | "use_reranker": true, 10 | "surrogate_threshold": 65536, 11 | "surrogate_method": 1, 12 | "embedding_dims": 300, 13 | "beta": 0.60, 14 | "top_k": 300, 15 | "save_predictions": true, 16 | "trn_label_fname": "trn_X_Y.txt", 17 | "val_label_fname": 
"tst_X_Y.txt", 18 | "tst_label_fname": "tst_X_Y.txt", 19 | "trn_feat_fname": "trn_X_Xf.txt", 20 | "val_feat_fname": "tst_X_Xf.txt", 21 | "tst_feat_fname": "tst_X_Xf.txt" 22 | }, 23 | "surrogate": { 24 | "num_epochs": 20, 25 | "dlr_factor": 0.5, 26 | "learning_rate": 0.005, 27 | "batch_size": 255, 28 | "dlr_step": 14, 29 | "normalize": true, 30 | "optim": "Adam", 31 | "init": "token_embeddings", 32 | "embeddings": "fasttextB_embeddings_300d.npy", 33 | "validate": true, 34 | "save_intermediate": true 35 | }, 36 | "extreme": { 37 | "num_epochs": 12, 38 | "dlr_factor": 0.5, 39 | "learning_rate": 0.0005, 40 | "batch_size": 255, 41 | "dlr_step": 10, 42 | "ns_method": "ensemble", 43 | "num_centroids": 300, 44 | "efC": 300, 45 | "efS": 400, 46 | "M": 100, 47 | "num_nbrs": 500, 48 | "ann_threads": 18, 49 | "beta": 0.5, 50 | "surrogate_mapping": null, 51 | "num_clf_partitions": 1, 52 | "optim": "Adam", 53 | "freeze_intermediate": true, 54 | "validate": true, 55 | "model_method": "shortlist", 56 | "normalize": true, 57 | "shortlist_method": "hybrid", 58 | "init": "intermediate", 59 | "use_shortlist": true, 60 | "use_intermediate_for_shorty": true 61 | }, 62 | "reranker": { 63 | "num_epochs": 10, 64 | "dlr_factor": 0.5, 65 | "learning_rate": 0.002, 66 | "batch_size": 255, 67 | "dlr_step": 7, 68 | "beta": 0.5, 69 | "num_clf_partitions": 1, 70 | "optim": "Adam", 71 | "validate": true, 72 | "model_method": "reranker", 73 | "shortlist_method": "static", 74 | "surrogate_mapping": null, 75 | "normalize": true, 76 | "use_shortlist": true, 77 | "init": "token_embeddings", 78 | "save_intermediate": false, 79 | "keep_invalid": true, 80 | "freeze_intermediate": false, 81 | "update_shortlist": false, 82 | "use_pretrained_shortlist": true, 83 | "embeddings": "fasttextB_embeddings_300d.npy" 84 | } 85 | } -------------------------------------------------------------------------------- /deepxml/libs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /deepxml/libs/collate_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn.utils.rnn import pad_sequence 4 | 5 | 6 | def pad_and_collate(x, pad_val=0, dtype=torch.FloatTensor): 7 | """ 8 | A generalized function for padding batch using utils.rnn.pad_sequence 9 | * pad as per the maximum length in the batch 10 | * returns a collated tensor 11 | 12 | Arguments: 13 | --------- 14 | x: iterator 15 | iterator over np.ndarray that needs to be converted to 16 | tensors and padded 17 | pad_val: float 18 | pad tensor with this value 19 | will cast the value as per the data type 20 | dtype: datatype, optional (default=torch.FloatTensor) 21 | tensor should be of this type 22 | """ 23 | return pad_sequence([torch.from_numpy(z) for z in x], 24 | batch_first=True, padding_value=pad_val).type(dtype) 25 | 26 | 27 | def collate_dense(x, dtype=torch.FloatTensor): 28 | """ 29 | Collate dense documents/labels and returns 30 | 31 | Arguments: 32 | --------- 33 | x: iterator 34 | iterator over np.ndarray that needs to be converted to 35 | tensors and padded 36 | dtype: datatype, optional (default=torch.FloatTensor) 37 | features should be of this type 38 | """ 39 | return torch.stack([torch.from_numpy(z) for z in x], 0).type(dtype) 40 | 41 | 42 | def collate_sparse(x, pad_val=0.0, has_weight=False, dtype=torch.FloatTensor): 43 | """ 44 | Collate sparse documents 45 | * Can 
handle with or without weights 46 | * Expects an iterator over tuples if has_weight=True 47 | 48 | Arguments: 49 | --------- 50 | x: iterator 51 | iterator over data points which can be 52 | np.array or tuple of np.ndarray depending on has_weight 53 | pad_val: list or float, optional, default=(0.0) 54 | padding value for indices and weights 55 | * expects a list when has_weight=True 56 | has_weight: bool, optional, default=False 57 | If entries have weights 58 | * True: objects are tuples of np.ndarrays 59 | 0: indices, 1: weights 60 | * False: objects are np.ndarrays 61 | dtypes: list or dtype, optional (default=torch.FloatTensor) 62 | dtypes of indices and values 63 | * expects a list when has_weight=True 64 | """ 65 | weights = None 66 | if has_weight: 67 | x = list(x) 68 | indices = pad_and_collate(map(lambda z: z[0], x), pad_val[0], dtype[0]) 69 | weights = pad_and_collate(map(lambda z: z[1], x), pad_val[1], dtype[1]) 70 | else: 71 | indices = pad_and_collate(x, pad_val, dtype) 72 | return indices, weights 73 | 74 | 75 | def get_iterator(x, ind=None): 76 | if ind is None: 77 | return map(lambda z: z, x) 78 | else: 79 | return map(lambda z: z[ind], x) 80 | 81 | 82 | def construct_collate_fn(feature_type, classifier_type, num_partitions=1): 83 | def _collate_fn_dense_full(batch): 84 | return collate_fn_dense_full(batch, num_partitions) 85 | 86 | def _collate_fn_dense(batch): 87 | return collate_fn_dense(batch) 88 | 89 | def _collate_fn_sparse(batch): 90 | return collate_fn_sparse(batch) 91 | 92 | def _collate_fn_dense_sl(batch): 93 | return collate_fn_dense_sl(batch, num_partitions) 94 | 95 | def _collate_fn_sparse_full(batch): 96 | return collate_fn_sparse_full(batch, num_partitions) 97 | 98 | def _collate_fn_sparse_sl(batch): 99 | return collate_fn_sparse_sl(batch, num_partitions) 100 | 101 | if feature_type == 'dense': 102 | if classifier_type == 'None': 103 | return _collate_fn_dense 104 | elif classifier_type == 'shortlist': 105 | return _collate_fn_dense_sl 106 | else: 107 | return _collate_fn_dense_full 108 | else: 109 | if classifier_type == 'None': 110 | return _collate_fn_sparse 111 | elif classifier_type == 'shortlist': 112 | return _collate_fn_sparse_sl 113 | else: 114 | return _collate_fn_sparse_full 115 | 116 | 117 | def collate_fn_sparse_sl(batch, num_partitions): 118 | """ 119 | Combine each sample in a batch with shortlist 120 | For sparse features 121 | """ 122 | _is_partitioned = True if num_partitions > 1 else False 123 | batch_data = {'batch_size': len(batch), 'X_ind': None} 124 | batch_data['batch_size'] = len(batch) 125 | batch_data['X_ind'], batch_data['X'] = collate_sparse( 126 | get_iterator(batch, 0), pad_val=[0, 0.0], has_weight=True, 127 | dtype=[torch.LongTensor, torch.FloatTensor]) 128 | 129 | z = list(get_iterator(batch, 1)) 130 | if _is_partitioned: 131 | batch_data['Y_s'] = [collate_dense( 132 | get_iterator(get_iterator(z, 0), idx), dtype=torch.LongTensor) 133 | for idx in range(num_partitions)] 134 | batch_data['Y'] = [collate_dense( 135 | get_iterator(get_iterator(z, 1), idx), dtype=torch.FloatTensor) 136 | for idx in range(num_partitions)] 137 | batch_data['Y_sim'] = [collate_dense( 138 | get_iterator(get_iterator(z, 2), idx), dtype=torch.FloatTensor) 139 | for idx in range(num_partitions)] 140 | batch_data['Y_mask'] = [collate_dense( 141 | get_iterator(get_iterator(z, 3), idx), dtype=torch.BoolTensor) 142 | for idx in range(num_partitions)] 143 | batch_data['Y_map'] = collate_dense( 144 | get_iterator(z, 4), dtype=torch.LongTensor) 145 | else: 146 
| batch_data['Y_s'] = collate_dense(
147 | get_iterator(z, 0), dtype=torch.LongTensor)
148 | batch_data['Y'] = collate_dense(
149 | get_iterator(z, 1), dtype=torch.FloatTensor)
150 | batch_data['Y_sim'] = collate_dense(
151 | get_iterator(z, 2), dtype=torch.FloatTensor)
152 | batch_data['Y_mask'] = collate_dense(
153 | get_iterator(z, 3), dtype=torch.BoolTensor)
154 | return batch_data
155 |
156 |
157 | def collate_fn_dense_sl(batch, num_partitions):
158 | """
159 | Combine each sample in a batch with shortlist
160 | For dense features
161 | """
162 | _is_partitioned = True if num_partitions > 1 else False
163 | batch_data = {'batch_size': len(batch), 'X_ind': None}
164 | batch_data['X'] = collate_dense(get_iterator(batch, 0))
165 |
166 | z = list(get_iterator(batch, 1))
167 | if _is_partitioned:
168 | batch_data['Y_s'] = [collate_dense(
169 | get_iterator(get_iterator(z, 0), idx), dtype=torch.LongTensor)
170 | for idx in range(num_partitions)]
171 | batch_data['Y'] = [collate_dense(
172 | get_iterator(get_iterator(z, 1), idx), dtype=torch.FloatTensor)
173 | for idx in range(num_partitions)]
174 | batch_data['Y_sim'] = [collate_dense(
175 | get_iterator(get_iterator(z, 2), idx), dtype=torch.FloatTensor)
176 | for idx in range(num_partitions)]
177 | batch_data['Y_mask'] = [collate_dense(
178 | get_iterator(get_iterator(z, 3), idx), dtype=torch.BoolTensor)
179 | for idx in range(num_partitions)]
180 | batch_data['Y_map'] = collate_dense(
181 | get_iterator(z, 4), dtype=torch.LongTensor)
182 | else:
183 | batch_data['Y_s'] = collate_dense(
184 | get_iterator(z, 0), dtype=torch.LongTensor)
185 | batch_data['Y'] = collate_dense(
186 | get_iterator(z, 1), dtype=torch.FloatTensor)
187 | batch_data['Y_sim'] = collate_dense(
188 | get_iterator(z, 2), dtype=torch.FloatTensor)
189 | batch_data['Y_mask'] = collate_dense(
190 | get_iterator(z, 3), dtype=torch.BoolTensor)
191 | return batch_data
192 |
193 |
194 | def collate_fn_dense_full(batch, num_partitions):
195 | """
196 | Combine each sample in a batch
197 | For dense features
198 | """
199 | _is_partitioned = True if num_partitions > 1 else False
200 | batch_data = {'batch_size': len(batch), 'X_ind': None}
201 | batch_data['X'] = collate_dense(get_iterator(batch, 0))
202 | if _is_partitioned:
203 | batch_data['Y'] = [collate_dense(
204 | get_iterator(get_iterator(batch, 1), idx))
205 | for idx in range(num_partitions)]
206 | else:
207 | batch_data['Y'] = collate_dense(get_iterator(batch, 1))
208 | return batch_data
209 |
210 |
211 | def collate_fn_sparse_full(batch, num_partitions):
212 | """
213 | Combine each sample in a batch
214 | For sparse features
215 | """
216 | _is_partitioned = True if num_partitions > 1 else False
217 | batch_data = {'batch_size': len(batch), 'X_ind': None}
218 | batch_data['X_ind'], batch_data['X'] = collate_sparse(
219 | get_iterator(batch, 0), pad_val=[0, 0.0], has_weight=True,
220 | dtype=[torch.LongTensor, torch.FloatTensor])
221 | if _is_partitioned:
222 | batch_data['Y'] = [collate_dense(
223 | get_iterator(get_iterator(batch, 1), idx))
224 | for idx in range(num_partitions)]
225 | else:
226 | batch_data['Y'] = collate_dense(get_iterator(batch, 1))
227 | return batch_data
228 |
229 |
230 | def collate_fn_sparse(batch):
231 | """
232 | Combine each sample in a batch
233 | For sparse features
234 | """
235 | batch_data = {'batch_size': len(batch), 'X_ind': None}
236 | batch_data['X_ind'], batch_data['X'] = collate_sparse(
237 | get_iterator(batch), pad_val=[0, 0.0], has_weight=True,
238 | dtype=[torch.LongTensor, torch.FloatTensor])
239 | return batch_data
240 |
241 |
242 | def collate_fn_dense(batch):
243 | """
244 | Combine each sample in a batch
245 | For dense features
246 | """
247 | batch_data = {'batch_size': len(batch), 'X_ind': None}
248 | batch_data['X'] = collate_dense(get_iterator(batch))
249 | return batch_data
250 |
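# Usage sketch (illustrative): the collate functions above are meant to be
# passed to torch.utils.data.DataLoader via construct_collate_fn, e.g.
#
#   collate_fn = construct_collate_fn(feature_type='sparse',
#                                     classifier_type='shortlist')
#   loader = torch.utils.data.DataLoader(dataset, batch_size=255,
#                                        collate_fn=collate_fn)
#
# Each batch is then a dict with keys such as 'batch_size', 'X'/'X_ind' and,
# depending on the classifier type, 'Y', 'Y_s', 'Y_sim' and 'Y_mask'.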
--------------------------------------------------------------------------------
/deepxml/libs/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from .dataset_base import DatasetBase, DatasetTensor
4 | from .dist_utils import Partitioner
5 | from xclib.utils.sparse import _map
6 | from .shortlist_handler import construct_handler
7 |
8 |
9 | def construct_dataset(data_dir, fname_features, fname_labels, data=None,
10 | model_dir='', mode='train', size_shortlist=-1,
11 | normalize_features=True, normalize_labels=True,
12 | keep_invalid=False, feature_type='sparse',
13 | num_clf_partitions=1, feature_indices=None,
14 | label_indices=None, shortlist_method='static',
15 | shorty=None, surrogate_mapping=None, _type='full',
16 | pretrained_shortlist=None):
17 | if _type == 'full': # with OVA classifier
18 | return DatasetFull(
19 | data_dir, fname_features, fname_labels, data, model_dir, mode,
20 | feature_indices, label_indices, keep_invalid, normalize_features,
21 | normalize_labels, num_clf_partitions, feature_type, surrogate_mapping)
22 | elif _type == 'shortlist': # with a shortlist
23 | # Construct dataset for sparse data
24 | return DatasetShortlist(
25 | data_dir, fname_features, fname_labels, data, model_dir, mode,
26 | feature_indices, label_indices, keep_invalid, normalize_features,
27 | normalize_labels, num_clf_partitions, size_shortlist,
28 | feature_type, shortlist_method, shorty, surrogate_mapping,
29 | pretrained_shortlist=pretrained_shortlist)
30 | elif _type == 'tensor':
31 | return DatasetTensor(
32 | data_dir, fname_features, data, feature_indices,
33 | normalize_features, feature_type)
34 | else:
35 | raise NotImplementedError("Unknown dataset type")
36 |
37 |
38 | class DatasetFull(DatasetBase):
39 | """Dataset to load and use XML-Datasets with full output space only
40 |
41 | Arguments
42 | ---------
43 | data_dir: str
44 | data files are stored in this directory
45 | fname_features: str
46 | feature file (libsvm or pickle)
47 | fname_labels: str
48 | labels file (libsvm or pickle)
49 | data: dict, optional, default=None
50 | Read data directly from this obj rather than files
51 | Files are ignored if this is not None
52 | Keys: 'X', 'Y'
53 | model_dir: str, optional, default=''
54 | Dump data like valid labels here
55 | mode: str, optional, default='train'
56 | Mode of the dataset
57 | feature_indices: str or None, optional, default=None
58 | Train with selected features only (read from file)
59 | label_indices: str or None, optional, default=None
60 | Train for selected labels only (read from file)
61 | keep_invalid: bool, optional, default=False
62 | Don't touch data points or labels
63 | normalize_features: bool, optional, default=True
64 | Normalize data points to unit norm
65 | normalize_labels: bool, optional, default=False
66 | Normalize labels to convert them into probabilities
67 | Useful in case of non-binary labels
68 | num_clf_partitions: int, optional, default=1
69 | Partition classifier in multiple parts
70 | Support for multiple GPUs
71 | feature_type: str, optional, default='sparse'
72 | sparse or dense features
73 | label_type: str, optional, default='dense'
74 | sparse (i.e. with shortlist) or dense (OVA) labels
75 | surrogate_mapping: str, optional, default=None
76 | Re-map clusters as per given mapping
77 | e.g. when labels are clustered
78 | """
79 |
80 | def __init__(self, data_dir, fname_features, fname_labels, data=None,
81 | model_dir='', mode='train', feature_indices=None,
82 | label_indices=None, keep_invalid=False,
83 | normalize_features=True, normalize_labels=False,
84 | num_clf_partitions=1, feature_type='sparse',
85 | surrogate_mapping=None, label_type='dense'):
86 | super().__init__(data_dir, fname_features, fname_labels, data,
87 | model_dir, mode, feature_indices, label_indices,
88 | keep_invalid, normalize_features, normalize_labels,
89 | feature_type, label_type)
90 | if self.mode == 'train':
91 | # Remove samples w/o any feature or label
92 | self._remove_samples_wo_features_and_labels()
93 | if not keep_invalid and self.labels._valid:
94 | # Remove labels w/o any positive instance
95 | self._process_labels(model_dir, surrogate_mapping)
96 | self.feature_type = feature_type
97 | self.partitioner = None
98 | self.num_clf_partitions = 1
99 | if self.labels._valid: # partition only if labels are available
100 | self.num_clf_partitions = num_clf_partitions
101 | if self.mode == 'train':
102 | assert self.labels._valid, "Labels can not be None while training."
103 | if self.num_clf_partitions > 1:
104 | self.partitioner = Partitioner(
105 | self.num_labels, self.num_clf_partitions,
106 | padding=False, contiguous=True)
107 | self.partitioner.save(os.path.join(
108 | self.model_dir, 'partitionar.pkl'))
109 | else:
110 | if self.num_clf_partitions > 1:
111 | self.partitioner = Partitioner(
112 | self.num_labels, self.num_clf_partitions,
113 | padding=False, contiguous=True)
114 | self.partitioner.load(os.path.join(
115 | self.model_dir, 'partitionar.pkl'))
116 |
117 | # TODO Take care of this select and padding index
118 | self.label_padding_index = self.num_labels
119 |
120 | def _process_labels(self, model_dir, surrogate_mapping):
121 | super()._process_labels(model_dir)
122 | # if surrogate task is clustered labels
123 | if surrogate_mapping is not None:
124 | print("Surrogate mapping is not None, mapping labels")
125 | surrogate_mapping = np.loadtxt(surrogate_mapping, dtype=np.int64)
126 | _num_labels = len(np.unique(surrogate_mapping))
127 | mapping = dict(
128 | zip(range(len(surrogate_mapping)), surrogate_mapping))
129 | self.labels.Y = _map(self.labels.Y, mapping=mapping,
130 | shape=(self.num_instances, _num_labels),
131 | axis=1)
132 | self.labels.binarize()
133 |
134 | def __getitem__(self, index):
135 | """
136 | Get features and labels for index
137 | Args:
138 | index: for this sample
139 | Returns:
140 | features: non-zero entries
141 | labels: numpy array
142 |
143 | """
144 | x = self.features[index]
145 | y = self.labels[index]
146 | if self.partitioner is not None: # Split if required
147 | y = self.partitioner.split(y)
148 | return x, y
149 |
150 |
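# Usage sketch (illustrative): datasets are normally built through
# construct_dataset above, which dispatches on _type, e.g.
#
#   trn_dataset = construct_dataset(data_dir, 'trn_X_Xf.txt', 'trn_X_Y.txt',
#                                   mode='train', _type='full')
#
# returns a DatasetFull instance for the OVA (full output space) setting.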
like valid labels here 168 | mode: str, optional, default='train' 169 | Mode of the dataset 170 | feature_indices: str or None, optional, default=None 171 | Train with selected features only (read from file) 172 | label_indices: str or None, optional, default=None 173 | Train for selected labels only (read from file) 174 | keep_invalid: bool, optional, default=False 175 | Don't touch data points or labels 176 | normalize_features: bool, optional, default=True 177 | Normalize data points to unit norm 178 | normalize_labels: bool, optional, default=False 179 | Normalize labels to convert them into probabilities 180 | Useful in case of non-binary labels 181 | num_clf_partitions: int, optional, default=1 182 | Partition classifier in multiple parts 183 | Support for multiple GPUs 184 | feature_type: str, optional, default='sparse' 185 | sparse or dense features 186 | shortlist_method: str, optional, default='static' 187 | type of shortlist (static or dynamic) 188 | shorty: obj, optional, default=None 189 | Useful in case of dynamic shortlist 190 | surrogate_mapping: str, optional, default=None 191 | Re-map clusters as per given mapping 192 | e.g. when labels are clustered 193 | label_type: str, optional, default='sparse' 194 | sparse (i.e. with shortlist) or dense (OVA) labels 195 | shortlist_in_memory: boolean, optional, default=True 196 | Keep shortlist in memory if True, otherwise keep it on disk 197 | pretrained_shortlist: None or str, optional, default=None 198 | Pre-trained shortlist (useful in a re-ranker) 199 | """ 200 | 201 | def __init__(self, data_dir, fname_features, fname_labels, data=None, 202 | model_dir='', mode='train', feature_indices=None, 203 | label_indices=None, keep_invalid=False, 204 | normalize_features=True, normalize_labels=False, 205 | num_clf_partitions=1, size_shortlist=-1, 206 | feature_type='sparse', shortlist_method='static', 207 | shorty=None, surrogate_mapping=None, label_type='sparse', 208 | shortlist_in_memory=True, pretrained_shortlist=None): 209 | super().__init__(data_dir, fname_features, fname_labels, data, 210 | model_dir, mode, feature_indices, label_indices, 211 | keep_invalid, normalize_features, normalize_labels, 212 | feature_type, label_type) 213 | if self.labels is None: 214 | raise NotImplementedError( 215 | "No support for shortlist w/o any label, \ 216 | consider using dense dataset.") 217 | self.feature_type = feature_type 218 | self.num_clf_partitions = num_clf_partitions 219 | self.shortlist_in_memory = shortlist_in_memory 220 | self.size_shortlist = size_shortlist 221 | self.shortlist_method = shortlist_method 222 | if self.mode == 'train': 223 | # Remove samples w/o any feature or label 224 | if pretrained_shortlist is None: 225 | self._remove_samples_wo_features_and_labels() 226 | 227 | if not keep_invalid: 228 | # Remove labels w/o any positive instance 229 | self._process_labels(model_dir, surrogate_mapping) 230 | 231 | self.shortlist = construct_handler( 232 | shortlist_type=shortlist_method, 233 | num_instances=self.num_instances, 234 | num_labels=self.num_labels, 235 | model_dir=model_dir, 236 | shorty=shorty, 237 | mode=mode, 238 | size_shortlist=size_shortlist, 239 | label_mapping=None, 240 | in_memory=shortlist_in_memory, 241 | corruption=150, 242 | fname=pretrained_shortlist) 243 | self.use_shortlist = True if self.size_shortlist > 0 else False 244 | self.label_padding_index = self.num_labels 245 | 246 | def _process_labels(self, model_dir, surrogate_mapping): 247 | super()._process_labels(model_dir) 248 | # if surrogate task is clustered labels 249 |
if surrogate_mapping is not None: 250 | surrogate_mapping = np.loadtxt(surrogate_mapping, dtype=np.int64) 251 | _num_labels = len(np.unique(surrogate_mapping)) 252 | mapping = dict( 253 | zip(range(len(surrogate_mapping)), surrogate_mapping)) 254 | self.labels.Y = _map(self.labels.Y, mapping=mapping, 255 | shape=(self.num_instances, _num_labels), 256 | axis=1) 257 | self.labels.binarize() 258 | 259 | def update_shortlist(self, ind, sim, fname='tmp', idx=-1): 260 | """Update label shortlist for each instance 261 | """ 262 | self.shortlist.update_shortlist(ind, sim, fname) 263 | 264 | def save_shortlist(self, fname): 265 | """Save label shortlist and distance for each instance 266 | """ 267 | self.shortlist.save_shortlist(fname) 268 | 269 | def load_shortlist(self, fname): 270 | """Load label shortlist and distance for each instance 271 | """ 272 | self.shortlist.load_shortlist(fname) 273 | 274 | def get_shortlist(self, index): 275 | """ 276 | Get data with shortlist for given data index 277 | """ 278 | pos_labels, _ = self.labels[index] 279 | return self.shortlist.get_shortlist(index, pos_labels) 280 | 281 | def __getitem__(self, index): 282 | """Get features and labels for index 283 | Arguments 284 | --------- 285 | index: int 286 | data for this index 287 | Returns 288 | ------- 289 | features: np.ndarray or tuple 290 | for dense: np.ndarray 291 | for sparse: feature indices and their weights 292 | labels: tuple 293 | shortlist: label indices in the shortlist 294 | labels_mask: 1 for relevant; 0 otherwise 295 | dist: distance (used during prediction only) 296 | """ 297 | x = self.features[index] 298 | y = self.get_shortlist(index) 299 | return x, y 300 | -------------------------------------------------------------------------------- /deepxml/libs/dataset_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import os 4 | import numpy as np 5 | from .features import construct as construct_f 6 | from .labels import construct as construct_l 7 | 8 | 9 | class DatasetTensor(torch.utils.data.Dataset): 10 | """Dataset to load and use sparse/dense matrix 11 | * Support npz, pickle, npy or libsvm file format 12 | * Useful when iterating over features or labels 13 | 14 | Arguments 15 | --------- 16 | data_dir: str 17 | data files are stored in this directory 18 | fname: str 19 | file name (libsvm or npy or npz or pkl) 20 | will use 'X' key in case of pickle 21 | data: scipy.sparse or np.ndarray, optional, default=None 22 | Read data directly from this obj rather than files 23 | Files are ignored if this is not None 24 | indices: None or str, optional, default=None 25 | Use only these indices in the given list 26 | normalize: bool, optional, default=True 27 | Normalize the rows to unit norm 28 | _type: str, optional, default='sparse' 29 | Type of data (sparse/dense) 30 | """ 31 | 32 | def __init__(self, data_dir, fname, data=None, indices=None, 33 | normalize=True, _type='sparse'): 34 | self.data = self.construct( 35 | data_dir, fname, data, indices, normalize, _type) 36 | 37 | def construct(self, data_dir, fname, data, indices, normalize, _type): 38 | data = construct_f(data_dir, fname, data, normalize, _type) 39 | if indices is not None: 40 | indices = np.loadtxt(indices, dtype=np.int64) 41 | data._index_select(indices) 42 | return data 43 | 44 | def __len__(self): 45 | return self.num_instances 46 | 47 | @property 48 | def num_instances(self): 49 | return self.data.num_instances 50 | 51 | def __getitem__(self, index):
52 | """Get data for a given index 53 | Arguments 54 | --------- 55 | index: int 56 | data for this index 57 | Returns 58 | ------- 59 | features: tuple 60 | feature indices and their weights 61 | """ 62 | return self.data[index] 63 | 64 | 65 | class DatasetBase(torch.utils.data.Dataset): 66 | """Dataset to load and use XML-Datasets 67 | Support pickle or libsvm file format 68 | 69 | Arguments 70 | --------- 71 | data_dir: str 72 | data files are stored in this directory 73 | fname_features: str 74 | feature file (libsvm or pickle) 75 | fname_labels: str 76 | labels file (libsvm or pickle) 77 | data: dict, optional, default=None 78 | Read data directly from this obj rather than files 79 | Files are ignored if this is not None 80 | Keys: 'X', 'Y' 81 | model_dir: str, optional, default='' 82 | Dump data like valid labels here 83 | mode: str, optional, default='train' 84 | Mode of the dataset 85 | feature_indices: str or None, optional, default=None 86 | Train with selected features only (read from file) 87 | label_indices: str or None, optional, default=None 88 | Train for selected labels only (read from file) 89 | keep_invalid: bool, optional, default=False 90 | Don't touch data points or labels 91 | normalize_features: bool, optional, default=True 92 | Normalize data points to unit norm 93 | normalize_lables: bool, optional, default=False 94 | Normalize labels to convert in probabilities 95 | Useful in-case on non-binary labels 96 | feature_type: str, optional, default='sparse' 97 | sparse or dense features 98 | label_type: str, optional, default='dense' 99 | sparse (i.e. with shortlist) or dense (OVA) labels 100 | """ 101 | 102 | def __init__(self, data_dir, fname_features, fname_labels, 103 | data=None, model_dir='', mode='train', 104 | feature_indices=None, label_indices=None, 105 | keep_invalid=False, normalize_features=True, 106 | normalize_lables=False, feature_type='sparse', 107 | label_type='dense'): 108 | self.data_dir = data_dir 109 | self.mode = mode 110 | self.features, self.labels = self.load_data( 111 | data_dir, fname_features, fname_labels, data, 112 | normalize_features, normalize_lables, 113 | feature_type, label_type) 114 | self._split = None 115 | self.index_select(feature_indices, label_indices) 116 | self.model_dir = model_dir 117 | self.data_dir = data_dir 118 | self.label_padding_index = self.num_labels 119 | 120 | def _remove_samples_wo_features_and_labels(self): 121 | """Remove instances if they don't have any feature or label 122 | """ 123 | indices = self.features.get_valid_indices(axis=1) 124 | if self.labels is not None: 125 | indices_labels = self.labels.get_valid_indices(axis=1) 126 | indices = np.intersect1d(indices, indices_labels) 127 | self.labels._index_select(indices, axis=0) 128 | self.features._index_select(indices, axis=0) 129 | 130 | def index_select(self, feature_indices, label_indices): 131 | """Transform feature and label matrix to specified 132 | features/labels only 133 | """ 134 | def _get_split_id(fname): 135 | """Split ID (or quantile) from file name 136 | """ 137 | idx = fname.split("_")[-1].split(".")[0] 138 | return idx 139 | if label_indices is not None: 140 | self._split = _get_split_id(label_indices) 141 | label_indices = np.loadtxt(label_indices, dtype=np.int32) 142 | self.labels._index_select(label_indices, axis=1) 143 | if feature_indices is not None: 144 | self._split = _get_split_id(feature_indices) 145 | feature_indices = np.loadtxt(feature_indices, dtype=np.int32) 146 | self.features._index_select(feature_indices, axis=1) 147 | 
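    # --- Editor's note: illustrative sketch, not part of the original file ---
    # index_select() above is driven by plain-text index files whose names
    # encode a split id; the file name below is a hypothetical example.
    # With label_indices='label_indices_4.txt' (one integer label id per
    # line), _get_split_id() recovers "4" as the split id and the label
    # matrix is restricted to the listed columns, e.g.
    #     dataset.index_select(feature_indices=None,
    #                          label_indices='label_indices_4.txt')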
148 | def load_features(self, data_dir, fname, X, 149 | normalize_features, feature_type): 150 | """Load features from given file 151 | Features can also be supplied directly 152 | """ 153 | return construct_f(data_dir, fname, X, 154 | normalize_features, feature_type) 155 | 156 | def load_labels(self, data_dir, fname, Y, normalize_labels, label_type): 157 | """Load labels from given file 158 | Labels can also be supplied directly 159 | """ 160 | labels = construct_l(data_dir, fname, Y, normalize_labels, 161 | label_type) # Pass dummy labels if required 162 | if normalize_labels: 163 | if self.mode == 'train': # Handle non-binary labels 164 | print("Non-binary labels encountered in train; Normalizing.") 165 | labels.normalize(norm='max', copy=False) 166 | else: 167 | print("Non-binary labels encountered in test/val; Binarizing.") 168 | labels.binarize() 169 | return labels 170 | 171 | def load_data(self, data_dir, fname_f, fname_l, data, 172 | normalize_features=True, normalize_labels=False, 173 | feature_type='sparse', label_type='dense'): 174 | """Load features and labels from file in libsvm format or pickle 175 | """ 176 | features = self.load_features( 177 | data_dir, fname_f, data['X'], normalize_features, feature_type) 178 | labels = self.load_labels( 179 | data_dir, fname_l, data['Y'], normalize_labels, label_type) 180 | return features, labels 181 | 182 | @property 183 | def num_instances(self): 184 | return self.features.num_instances 185 | 186 | @property 187 | def num_features(self): 188 | return self.features.num_features 189 | 190 | @property 191 | def num_labels(self): 192 | return self.labels.num_labels 193 | 194 | def get_stats(self): 195 | """Get dataset statistics 196 | """ 197 | return self.num_instances, self.num_features, self.num_labels 198 | 199 | def _process_labels_train(self, data_obj): 200 | """Process labels for train data 201 | - Remove labels without any training instance 202 | """ 203 | data_obj['num_labels'] = self.num_labels 204 | valid_labels = self.labels.remove_invalid() 205 | data_obj['valid_labels'] = valid_labels 206 | 207 | def _process_labels_predict(self, data_obj): 208 | """Process labels for test data 209 | Only use valid labels i.e. 
which had atleast one training 210 | example 211 | """ 212 | valid_labels = data_obj['valid_labels'] 213 | self.labels._index_select(valid_labels, axis=1) 214 | 215 | def _process_labels(self, model_dir): 216 | """Process labels to handle labels without any training instance; 217 | """ 218 | data_obj = {} 219 | fname = os.path.join( 220 | model_dir, 'labels_params.pkl' if self._split is None else 221 | "labels_params_split_{}.pkl".format(self._split)) 222 | if self.mode == 'train': 223 | self._process_labels_train(data_obj) 224 | pickle.dump(data_obj, open(fname, 'wb')) 225 | else: 226 | data_obj = pickle.load(open(fname, 'rb')) 227 | self._process_labels_predict(data_obj) 228 | 229 | def __len__(self): 230 | return self.num_instances 231 | 232 | def __getitem__(self, index): 233 | """Get features and labels for index 234 | Arguments 235 | --------- 236 | index: int 237 | data for this index 238 | Returns 239 | ------- 240 | features: np.ndarray or tuple 241 | for dense: np.ndarray 242 | for sparse: feature indices and their weights 243 | labels: np.ndarray 244 | 1 when relevant; 0 otherwise 245 | """ 246 | x = self.features[index] 247 | y = self.labels[index] 248 | return x, y 249 | -------------------------------------------------------------------------------- /deepxml/libs/dist_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import _pickle as pickle 3 | 4 | 5 | class Partitioner(object): 6 | """Utility to distribute an array 7 | Indices support: contiguous or otherwise (e.g. shortlist) 8 | * useful in distributed training of classifier 9 | 10 | Arguments: 11 | ----------- 12 | size: int 13 | size of data 14 | num_patitions: int 15 | Divide data in these many parittions 16 | padding: boolean, optional, default=False 17 | Padding index (Not handeled #TODO) 18 | contiguous: boolean, optional, default=True 19 | whether data is contiguous or not 20 | (non-contiguous not supported as of now) 21 | """ 22 | 23 | def __init__(self, size, num_patitions, padding=False, contiguous=True): 24 | # TODO: Handle padding 25 | self.num_patitions = num_patitions 26 | self.size = size 27 | self.contiguous = contiguous 28 | self._partitions = self._create_partitions() 29 | self.mapping_to_original, \ 30 | self.mapping_to_partition = self._create_mapping() 31 | self.partition_boundaries = self._create_partition_boundaries() 32 | 33 | def get_padding_indices(self): 34 | return [item.size for item in self._partitions] 35 | 36 | def _create_partition_boundaries(self): 37 | """Split array at these points 38 | """ 39 | _last = 0 40 | partition_boundaries = [] 41 | for item in self._partitions[:-1]: 42 | partition_boundaries.append(_last+item.size) 43 | _last = partition_boundaries[-1] 44 | return partition_boundaries 45 | 46 | def _create_partitions(self): 47 | """Create partitions 48 | """ 49 | return np.array_split(np.arange(self.size), self.num_patitions) 50 | 51 | def _create_mapping(self): 52 | """Mapping to map indices original<->partitioned 53 | """ 54 | mapping_to_original = [] 55 | mapping_to_partition = [] 56 | for _, _partition in enumerate(self._partitions): 57 | mapping_to_original.append( 58 | dict(zip(np.arange(_partition.size), _partition))) 59 | mapping_to_partition.append( 60 | dict(zip(_partition, np.arange(_partition.size)))) 61 | return mapping_to_original, mapping_to_partition 62 | 63 | def _map(self, fun, array): 64 | return np.fromiter(map(fun, array), dtype=array.dtype) 65 | 66 | def map_to_original(self, array, 
idx=None): 67 | return self._map(self.mapping_to_original[idx].get, array) 68 | 69 | # def map_to_partition(self, array): 70 | # return self.map(array, self.fun_map_to_partition) 71 | 72 | def get_partition_index(self, index): 73 | """ 74 | In which partition this index falls into? 75 | """ 76 | _last = 0 77 | for idx, _current in enumerate(self.partition_boundaries): 78 | if index >= _last and index < _current: 79 | return idx 80 | return self.num_patitions-1 81 | 82 | def split_indices_with_data(self, indices, data): 83 | """Split given indices and data (Shortlist) 84 | i.e. get the partition and map them accordingly 85 | """ 86 | out_ind = [[] for _ in range(self.num_patitions)] 87 | out_vals = [[] for _ in range(self.num_patitions)] 88 | for key, val in zip(indices, data): 89 | part = self.get_partition_index(key) 90 | ind = self.mapping_to_partition[part][key] 91 | out_ind[part].append(ind) 92 | out_vals[part].append(val) 93 | return out_ind, out_vals 94 | 95 | def split_indices(self, indices): 96 | """Split given indices (Shortlist) 97 | i.e. get the partition and map them accordingly 98 | """ 99 | out_ind = [[] for _ in range(self.num_patitions)] 100 | for key in indices: 101 | part = self.get_partition_index(key) 102 | ind = self.mapping_to_partition[part][key] 103 | out_ind[part].append(ind) 104 | return out_ind 105 | 106 | def split(self, array): # Split as per bondaries 107 | """Split given array in partitions (For contiguous indices) 108 | """ 109 | if self.contiguous: 110 | return np.hsplit(array, self.partition_boundaries) 111 | else: 112 | pass 113 | 114 | def merge(self, arrays): 115 | if self.contiguous: 116 | return np.hstack(arrays) 117 | else: 118 | pass 119 | 120 | def save(self, fname): 121 | pickle.dump(self.__dict__, open(fname, 'wb')) 122 | 123 | def load(self, fname): 124 | self.__dict__ = pickle.load(open(fname, 'rb')) 125 | 126 | def get_indices(self, part): 127 | # Indices for given partition 128 | return self._partitions[part] 129 | 130 | def __repr__(self): 131 | return "({}) #Size: {}, #partitions: {}".format( 132 | self.__class__.__name__, self.size, self.num_patitions) 133 | -------------------------------------------------------------------------------- /deepxml/libs/features.py: -------------------------------------------------------------------------------- 1 | from xclib.data.features import FeaturesBase, DenseFeatures, SparseFeatures 2 | import numpy as np 3 | from operator import itemgetter 4 | 5 | 6 | def construct(data_dir, fname, X=None, normalize=False, _type='sparse'): 7 | """Construct feature class based on given parameters 8 | 9 | Arguments 10 | ---------- 11 | data_dir: str 12 | data directory 13 | fname: str 14 | load data from this file 15 | X: csr_matrix or None, optional, default=None 16 | data is already provided 17 | normalize: boolean, optional, default=False 18 | Normalize the data or not 19 | _type: str, optional, default=sparse 20 | -sparse 21 | -dense 22 | -sequential 23 | """ 24 | if _type == 'sparse': 25 | return _SparseFeatures(data_dir, fname, X, normalize) 26 | elif _type == 'dense': 27 | return DenseFeatures(data_dir, fname, X, normalize) 28 | elif _type == 'sequential': 29 | return SequentialFeatures(data_dir, fname, X) 30 | else: 31 | raise NotImplementedError("Unknown feature type") 32 | 33 | 34 | class _SparseFeatures(SparseFeatures): 35 | """Class for sparse features 36 | 37 | * Difference: treat 0 as padding index 38 | 39 | Arguments 40 | ---------- 41 | data_dir: str 42 | data directory 43 | fname: str 44 | load data 
from this file 45 | X: csr_matrix or None, optional, default=None 46 | data is already provided 47 | normalize: boolean, optional, default=False 48 | Normalize the data or not 49 | """ 50 | 51 | def __init__(self, data_dir, fname, X=None, normalize=False): 52 | super().__init__(data_dir, fname, X, normalize) 53 | 54 | def __getitem__(self, index): 55 | # Treat idx:0 as Padding 56 | x = self.X[index].indices + 1 57 | w = self.X[index].data 58 | return x, w 59 | 60 | 61 | class SequentialFeatures(FeaturesBase): 62 | """Class for sequential features 63 | 64 | Arguments 65 | ---------- 66 | data_dir: str 67 | data directory 68 | fname: str 69 | load data from this file 70 | X: csr_matrix or None, optional, default=None 71 | data is already provided 72 | normalize: boolean, optional, default=False 73 | Normalize the data or not 74 | """ 75 | 76 | def __init__(self, data_dir, fname, X=None): 77 | super().__init__(data_dir, fname, X) 78 | 79 | def frequency(self, axis=0): 80 | pass 81 | 82 | def load(self, data_dir, fname, X): 83 | if X is not None: 84 | super().load(data_dir, fname, X) 85 | else: 86 | raise NotImplementedError 87 | 88 | def _select_instances(self, indices): 89 | self.X = list(itemgetter(*indices)(self.X)) 90 | 91 | def get_valid(self, axis=0): 92 | return np.arange(len(self.X)) 93 | 94 | def __getitem__(self, index): 95 | return np.array(self.X[index]) 96 | -------------------------------------------------------------------------------- /deepxml/libs/labels.py: -------------------------------------------------------------------------------- 1 | from xclib.data.labels import DenseLabels, SparseLabels, LabelsBase 2 | 3 | 4 | def construct(data_dir, fname, Y=None, normalize=False, _type='sparse'): 5 | """Construct label class based on given parameters 6 | 7 | Support for: 8 | * pkl file: Key 'Y' is used to access the labels 9 | * txt file: sparse libsvm format with header 10 | * npz file: numpy's sparse format 11 | 12 | Arguments 13 | ---------- 14 | data_dir: str 15 | data directory 16 | fname: str 17 | load data from this file 18 | Y: csr_matrix or None, optional, default=None 19 | data is already provided 20 | normalize: boolean, optional, default=False 21 | Normalize the labels or not 22 | Useful in case of non binary labels 23 | _type: str, optional, default=sparse 24 | -sparse or dense 25 | """ 26 | if fname is None and Y is None: # No labels are provided 27 | return LabelsBase(data_dir, fname, Y) 28 | else: 29 | if _type == 'sparse': 30 | return SparseLabels(data_dir, fname, Y, normalize) 31 | elif _type == 'dense': 32 | return DenseLabels(data_dir, fname, Y, normalize) 33 | else: 34 | raise NotImplementedError("Unknown label type") 35 | 36 | -------------------------------------------------------------------------------- /deepxml/libs/lookup.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import _pickle as pickle 3 | import h5py 4 | 5 | 6 | class Table(object): 7 | """Maintain a lookup table 8 | Supports in-memory and memmap file 9 | 10 | Arguments 11 | ---------- 12 | _type: str, optional, default='memory' 13 | keep data in-memory or on-disk 14 | _dtype: str, optional, default=np.float32 15 | datatype of the incoming data 16 | """ 17 | 18 | def __init__(self, _type='memory', _dtype=np.float32): 19 | self._type = _type 20 | self._dtype = _dtype 21 | self._shape = None 22 | self.data = None 23 | self._file = None 24 | 25 | def _get_fname(self, fname, mode='data'): 26 | if mode == 'data': 27 | return fname + 
".dat.npy" 28 | else: 29 | return fname + ".metadata" 30 | 31 | def create(self, _data, _fname, *args, **kwargs): 32 | """ 33 | Create a file 34 | Will not copy data if data-types are same 35 | """ 36 | _data = np.asarray(_data, dtype=self._dtype) 37 | self._shape = _data.shape 38 | if self._type == 'memory': 39 | self.data = _data 40 | elif self._type == 'memmap': 41 | self.data = np.memmap(self._get_fname( 42 | _fname), dtype=self._dtype, mode='w+', shape=self._shape) 43 | self._file = self.data 44 | self.data[:] = _data[:] 45 | self.data.flush() 46 | del _data # Save to disk and delete object in write mode 47 | else: 48 | raise NotImplementedError("Unknown type!") 49 | 50 | def query(self, indices): 51 | return self.data[indices] 52 | 53 | def save(self, _fname): 54 | obj = {'_type': self._type, 55 | '_dtype': self._dtype, 56 | '_shape': self._shape} 57 | pickle.dump(obj, open(self._get_fname(_fname, 'metadata'), 'wb')) 58 | # Save numpy array; others are already on disk 59 | if self._type == 'memory': 60 | np.save(self._get_fname(_fname), self.data) 61 | # Not expected to work when filenames are same 62 | elif self._type == 'memmap': 63 | _file = np.memmap(self._get_fname(_fname), 64 | dtype=self._dtype, mode='w+', shape=self._shape) 65 | _file[:] = self.data[:] 66 | _file.flush() 67 | else: 68 | raise NotImplementedError("Unknown type!") 69 | 70 | def load(self, _fname): 71 | obj = pickle.load(open(self._get_fname(_fname, 'metadata'), 'rb')) 72 | self._type = obj['_type'] 73 | self._dtype = obj['_dtype'] 74 | self._shape = obj['_shape'] 75 | if self._type == 'memory': 76 | self.data = np.load(self._get_fname(_fname), allow_pickle=True) 77 | elif self._type == 'memmap': 78 | self.data = np.memmap(self._get_fname(_fname), 79 | mode='r+', shape=self._shape, 80 | dtype=self._dtype) 81 | else: 82 | raise NotImplementedError("Unknown type!") 83 | 84 | def __del__(self): 85 | del self.data 86 | 87 | @property 88 | def data_init(self): 89 | return True if self.data is not None else False 90 | 91 | 92 | class PartitionedTable(object): 93 | """Maintain multiple lookup tables 94 | Supports in-memory and memmap file 95 | 96 | Arguments 97 | --------- 98 | num_partitions: int, optional, default=1 99 | #tables to maintain 100 | _type: str, optional, default='memory' 101 | keep data in-memory or on-disk 102 | _dtype: str, optional, default=np.float32 103 | datatype of the incoming data 104 | """ 105 | 106 | def __init__(self, num_partitions=1, _type='memory', _dtype=np.float32): 107 | self.num_partitions = num_partitions 108 | self.data = [] 109 | for _ in range(self.num_partitions): 110 | self.data.append(Table(_type, _dtype)) 111 | 112 | def _create_one(self, _data, _fname, idx): # Create a specific graph only 113 | # TODO: Add condition to check for invalid idx 114 | self.data[idx].create(_data, _fname + ".{}".format(idx)) 115 | 116 | def create(self, _data, _fname, idx=-1): 117 | """ 118 | Create a file 119 | Will copy data 120 | """ 121 | if idx != -1: 122 | self._create_one(_data, _fname, idx) 123 | else: 124 | for idx in range(self.num_partitions): 125 | self._create_one(_data[idx], _fname, idx) 126 | 127 | def query(self, indices): 128 | """ 129 | Query indices will be fine as per each table 130 | No need to re-map here 131 | """ 132 | out = [] 133 | for idx in range(self.num_partitions): 134 | out.append(self.data[idx].query(indices)) 135 | return out 136 | 137 | def save(self, _fname): 138 | pickle.dump( 139 | {'num_partitions': self.num_partitions}, 140 | open(_fname+".metadata", "wb")) 
141 | for idx in range(self.num_partitions): 142 | self.data[idx].save(_fname + ".{}".format(idx)) 143 | 144 | def load(self, _fname): 145 | self.num_partitions = pickle.load( 146 | open(_fname+".metadata", "rb"))['num_partitions'] 147 | for idx in range(self.num_partitions): 148 | self.data[idx].load(_fname + ".{}".format(idx)) 149 | 150 | @property 151 | def data_init(self): 152 | status = [item.data_init for item in self.data] 153 | return True if all(status) else False 154 | -------------------------------------------------------------------------------- /deepxml/libs/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class _Loss(torch.nn.Module): 6 | def __init__(self, reduction='mean', pad_ind=None): 7 | super(_Loss, self).__init__() 8 | self.reduction = reduction 9 | self.pad_ind = pad_ind 10 | 11 | def _reduce(self, loss): 12 | if self.reduction == 'none': 13 | return loss 14 | elif self.reduction == 'mean': 15 | return loss.mean() 16 | else: 17 | return loss.sum() 18 | 19 | def _mask_at_pad(self, loss): 20 | """ 21 | Mask the loss at padding index, i.e., make it zero 22 | """ 23 | if self.pad_ind is not None: 24 | loss[:, self.pad_ind] = 0.0 25 | return loss 26 | 27 | def _mask(self, loss, mask=None): 28 | """ 29 | Mask the loss at padding index, i.e., make it zero 30 | * Mask should be a boolean array with 1 where loss needs 31 | to be considered. 32 | * it'll make it zero where value is 0 33 | """ 34 | if mask is not None: 35 | loss = loss.masked_fill(~mask, 0.0) 36 | return loss 37 | 38 | 39 | def _convert_labels_for_svm(y): 40 | """ 41 | Convert labels from {0, 1} to {-1, 1} 42 | """ 43 | return 2.*y - 1.0 44 | 45 | 46 | class HingeLoss(_Loss): 47 | r""" Hinge loss 48 | * it'll automatically convert target to +1/-1 as required by hinge loss 49 | 50 | Arguments: 51 | ---------- 52 | margin: float, optional (default=1.0) 53 | the margin in hinge loss 54 | reduction: string, optional (default='mean') 55 | Specifies the reduction to apply to the output: 56 | * 'none': no reduction will be applied 57 | * 'mean' or 'sum': mean or sum of loss terms 58 | pad_ind: int/int64 or None (default=None) 59 | ignore loss values at this index 60 | useful when some index has to be used as padding index 61 | """ 62 | 63 | def __init__(self, margin=1.0, reduction='mean', pad_ind=None): 64 | super(HingeLoss, self).__init__(reduction, pad_ind) 65 | self.margin = margin 66 | 67 | def forward(self, input, target, mask=None): 68 | """ 69 | Arguments: 70 | --------- 71 | input: torch.FloatTensor 72 | real number pred matrix of size: batch_size x output_size 73 | typically logits from the neural network 74 | target: torch.FloatTensor 75 | 0/1 ground truth matrix of size: batch_size x output_size 76 | * it'll automatically convert to +1/-1 as required by hinge loss 77 | mask: torch.BoolTensor or None, optional (default=None) 78 | ignore entries [won't contribute to loss] where mask value is zero 79 | 80 | Returns: 81 | ------- 82 | loss: torch.FloatTensor 83 | dimension is defined based on reduction 84 | """ 85 | loss = F.relu(self.margin - _convert_labels_for_svm(target)*input) 86 | loss = self._mask_at_pad(loss) 87 | loss = self._mask(loss, mask) 88 | return self._reduce(loss) 89 | 90 | 91 | class SquaredHingeLoss(_Loss): 92 | r""" Squared Hinge loss 93 | * it'll automatically convert target to +1/-1 as required by hinge loss 94 | 95 | Arguments: 96 | ---------- 97 | margin: float, optional 
(default=1.0) 98 | the margin in squared hinge loss 99 | reduction: string, optional (default='mean') 100 | Specifies the reduction to apply to the output: 101 | * 'none': no reduction will be applied 102 | * 'mean' or 'sum': mean or sum of loss terms 103 | pad_ind: int/int64 or None (default=None) 104 | ignore loss values at this index 105 | useful when some index has to be used as padding index 106 | """ 107 | 108 | def __init__(self, margin=1.0, reduction='mean', pad_ind=None): 109 | super(SquaredHingeLoss, self).__init__(reduction, pad_ind) 110 | self.margin = margin 111 | 112 | def forward(self, input, target, mask=None): 113 | """ 114 | Arguments: 115 | --------- 116 | input: torch.FloatTensor 117 | real number pred matrix of size: batch_size x output_size 118 | typically logits from the neural network 119 | target: torch.FloatTensor 120 | 0/1 ground truth matrix of size: batch_size x output_size 121 | * it'll automatically convert to +1/-1 as required by hinge loss 122 | mask: torch.BoolTensor or None, optional (default=None) 123 | ignore entries [won't contribute to loss] where mask value is zero 124 | 125 | Returns: 126 | ------- 127 | loss: torch.FloatTensor 128 | dimension is defined based on reduction 129 | """ 130 | loss = F.relu(self.margin - _convert_labels_for_svm(target)*input) 131 | loss = loss**2 132 | loss = self._mask_at_pad(loss) 133 | loss = self._mask(loss, mask) 134 | return self._reduce(loss) 135 | 136 | 137 | class BCEWithLogitsLoss(_Loss): 138 | r""" BCE loss (expects logits; numercial stable) 139 | This loss combines a `Sigmoid` layer and the `BCELoss` in one single 140 | class. This version is more numerically stable than using a plain `Sigmoid` 141 | followed by a `BCELoss` as, by combining the operations into one layer, 142 | we take advantage of the log-sum-exp trick for numerical stability. 143 | 144 | Arguments: 145 | ---------- 146 | weight: torch.Tensor or None, optional (default=None)) 147 | a manual rescaling weight given to the loss of each batch element. 148 | If given, has to be a Tensor of size batch_size 149 | reduction: string, optional (default='mean') 150 | Specifies the reduction to apply to the output: 151 | * 'none': no reduction will be applied 152 | * 'mean' or 'sum': mean or sum of loss terms 153 | pos_weight: torch.Tensor or None, optional (default=None) 154 | a weight of positive examples. 155 | it must be a vector with length equal to the number of classes. 
156 | pad_ind: int/int64 or None (default=None) 157 | ignore loss values at this index 158 | useful when some index has to be used as padding index 159 | """ 160 | __constants__ = ['weight', 'pos_weight', 'reduction'] 161 | 162 | def __init__(self, weight=None, reduction='mean', 163 | pos_weight=None, pad_ind=None): 164 | super(BCEWithLogitsLoss, self).__init__(reduction, pad_ind) 165 | self.register_buffer('weight', weight) 166 | self.register_buffer('pos_weight', pos_weight) 167 | 168 | def forward(self, input, target, mask=None): 169 | """ 170 | Arguments: 171 | --------- 172 | input: torch.FloatTensor 173 | real number pred matrix of size: batch_size x output_size 174 | typically logits from the neural network 175 | target: torch.FloatTensor 176 | 0/1 ground truth matrix of size: batch_size x output_size 177 | mask: torch.BoolTensor or None, optional (default=None) 178 | ignore entries [won't contribute to loss] where mask value is zero 179 | 180 | Returns: 181 | ------- 182 | loss: torch.FloatTensor 183 | dimension is defined based on reduction 184 | """ 185 | loss = F.binary_cross_entropy_with_logits(input, target, 186 | self.weight, 187 | pos_weight=self.pos_weight, 188 | reduction='none') 189 | loss = self._mask_at_pad(loss) 190 | loss = self._mask(loss, mask) 191 | return self._reduce(loss) 192 | -------------------------------------------------------------------------------- /deepxml/libs/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import models.transform_layer as transform_layer 4 | 5 | 6 | class Optimizer(object): 7 | """Wrapper for pytorch optimizer class to handle 8 | mixture of sparse and dense parameters 9 | * Infers sparse/dense from 'sparse' attribute 10 | * Best results with Adam optimizer 11 | * Uses _modules() method by default; User may choose to define 12 | modules_() to change the behaviour 13 | 14 | Arguments 15 | ---------- 16 | opt_type: str, optional, default='Adam' 17 | optimizer to use 18 | learning_rate: float, optional, default=0.01 19 | learning rate for the optimizer 20 | momentum: float, optional, default=0.9 21 | momentum (valid for SGD only) 22 | weight_decay: float, optional, default=0.0 23 | l2-regularization cofficient 24 | nesterov: boolean, optional, default=True 25 | Use nesterov method (useful in SGD only) 26 | freeze_embeddings: boolean, optional, default=False 27 | Don't update embedding layer 28 | """ 29 | 30 | def __init__(self, opt_type='Adam', learning_rate=0.01, 31 | momentum=0.9, weight_decay=0.0, nesterov=True): 32 | self.opt_type = opt_type 33 | self.optimizer = [] 34 | self.weight_decay = weight_decay 35 | self.learning_rate = learning_rate 36 | self.momentum = momentum 37 | self.nesterov = nesterov 38 | 39 | def _get_opt(self, params, is_sparse): 40 | if self.opt_type == 'SGD': 41 | if is_sparse: 42 | return torch.optim.SGD( 43 | params, 44 | lr=self.learning_rate, 45 | momentum=self.momentum, 46 | ) 47 | else: 48 | return torch.optim.SGD( 49 | params, 50 | lr=self.learning_rate, 51 | momentum=self.momentum, 52 | weight_decay=self.weight_decay 53 | ) 54 | elif self.opt_type == 'Adam': 55 | if is_sparse: 56 | return torch.optim.SparseAdam( 57 | params, 58 | lr=self.learning_rate 59 | ) 60 | else: 61 | return torch.optim.Adam( 62 | params, 63 | lr=self.learning_rate, 64 | weight_decay=self.weight_decay 65 | ) 66 | else: 67 | raise NotImplementedError("Unknown optimizer!") 68 | 69 | def construct(self, model): 70 | """ 71 | Get optimizer. 
72 | Args: 73 | model: torch.nn.Module: network 74 | params: : parameters 75 | Returns: 76 | optimizer: torch.optim: optimizer as per given specifications 77 | """ 78 | model_params, is_sparse = self.get_params(model) 79 | for _, item in enumerate(zip(model_params, is_sparse)): 80 | if item[0]: 81 | self.optimizer.append(self._get_opt( 82 | params=item[0], is_sparse=item[1])) 83 | else: 84 | self.optimizer.append(None) 85 | 86 | def adjust_lr(self, dlr_factor): 87 | """ 88 | Adjust learning rate 89 | 90 | Arguments: 91 | --------- 92 | dlr_factor: float 93 | lr = lr * dlr_factor 94 | """ 95 | for opt in self.optimizer: # for each group 96 | if opt: 97 | for param_group in opt.param_groups: 98 | param_group['lr'] *= dlr_factor 99 | 100 | def step(self): 101 | for opt in self.optimizer: 102 | if opt: 103 | opt.step() 104 | 105 | def load_state_dict(self, sd): 106 | for idx, item in enumerate(sd): 107 | if item: 108 | self.optimizer[idx].load_state_dict(item) 109 | 110 | def state_dict(self): 111 | out_states = [] 112 | for item in self.optimizer: 113 | if item: 114 | out_states.append(item.state_dict()) 115 | else: 116 | out_states.append(None) 117 | return out_states 118 | 119 | def _modules(self, net): 120 | """ 121 | Get modules (nn.Module) in the network 122 | 123 | * _modules() method by default 124 | * user may choose to define modules_() to change the behaviour 125 | """ 126 | # Useful when parameters are tied etc. 127 | if hasattr(net, 'modules_'): 128 | return net.modules_.items() 129 | else: 130 | return net._modules.items() 131 | 132 | def _sparse(self, item): 133 | """ 134 | * infer (sparse attribute) if parameter group is sparse or dense 135 | * assume dense of the sparse attribute is unavailable 136 | """ 137 | try: 138 | return item.sparse 139 | except AttributeError: 140 | return False 141 | 142 | def _parameters(self, item): 143 | """ 144 | Return the parameters and is_sparse for each group 145 | * Uses parameters() method for a nn.Module object 146 | * traverse down the tree until a nn.Module object is found 147 | Unknown behaviour for infinite loop 148 | """ 149 | if isinstance(item, transform_layer.Transform): 150 | return self._parameters(item.transform) 151 | elif isinstance(item, nn.Sequential): 152 | params = [] 153 | is_sparse = [] 154 | for _item in item: 155 | _p, _s = self._parameters(_item) 156 | params.append(_p) 157 | is_sparse.append(_s) 158 | return params, is_sparse 159 | elif isinstance(item, nn.Module): 160 | return item.parameters(), self._sparse(item) 161 | else: 162 | raise NotImplementedError("Unknown module class!") 163 | 164 | def _get_params(self, _params, _sparse, params): 165 | if isinstance(_params, list): 166 | for p, s in zip(_params, _sparse): 167 | self._get_params(p, s, params) 168 | else: # Should be generator 169 | for p in _params: 170 | if p.requires_grad: 171 | if _sparse: 172 | params['sparse'].append({"params": p}) 173 | else: 174 | params['dense'].append({"params": p}) 175 | 176 | def get_params(self, net): 177 | self.net_params = {'sparse': [], 'dense': []} 178 | for _, val in self._modules(net): 179 | p, s = self._parameters(val) 180 | self._get_params(p, s, self.net_params) 181 | return [self.net_params['sparse'], self.net_params['dense']], \ 182 | [True, False] 183 | -------------------------------------------------------------------------------- /deepxml/libs/parameters.py: -------------------------------------------------------------------------------- 1 | from libs.parameters_base import ParametersBase 2 | 3 | 4 | class 
Parameters(ParametersBase): 5 | """ 6 | Parameter class for XML Classifiers 7 | """ 8 | 9 | def __init__(self, description): 10 | super().__init__(description) 11 | self._construct() 12 | 13 | def _construct(self): 14 | super()._construct() 15 | self.parser.add_argument( 16 | '--arch', 17 | dest='arch', 18 | default='astec', 19 | type=str, 20 | action='store', 21 | help='which network to use') 22 | self.parser.add_argument( 23 | '--lr', 24 | dest='learning_rate', 25 | default=0.1, 26 | action='store', 27 | type=float, 28 | help='Learning rate') 29 | self.parser.add_argument( 30 | '--surrogate_mapping', 31 | dest='surrogate_mapping', 32 | default=None, 33 | action='store', 34 | type=str, 35 | help='surrogate_mapping') 36 | self.parser.add_argument( 37 | '--dlr_step', 38 | dest='dlr_step', 39 | default=7, 40 | action='store', 41 | type=int, 42 | help='dlr_step') 43 | self.parser.add_argument( 44 | '--last_saved_epoch', 45 | dest='last_epoch', 46 | default=0, 47 | action='store', 48 | type=int, 49 | help='Last saved model at this epoch!') 50 | self.parser.add_argument( 51 | '--last_epoch', 52 | dest='last_epoch', 53 | default=0, 54 | action='store', 55 | type=int, 56 | help='Start training from here') 57 | self.parser.add_argument( 58 | '--shortlist_method', 59 | dest='shortlist_method', 60 | default='static', 61 | action='store', 62 | type=str, 63 | help='Shortlist method (static/dynamic/hybrid)') 64 | self.parser.add_argument( 65 | '--model_method', 66 | dest='model_method', 67 | default='full', 68 | action='store', 69 | type=str, 70 | help='Model method (full/shortlist/ns)') 71 | self.parser.add_argument( 72 | '--ns_method', 73 | dest='ns_method', 74 | default='kcentroid', 75 | action='store', 76 | type=str, 77 | help='Sample negatives using this method') 78 | self.parser.add_argument( 79 | '--ann_method', 80 | dest='ann_method', 81 | default='hnsw', 82 | action='store', 83 | type=str, 84 | help='Approximate nearest neighbor method') 85 | self.parser.add_argument( 86 | '--seed', 87 | dest='seed', 88 | default=22, 89 | action='store', 90 | type=int, 91 | help='seed values') 92 | self.parser.add_argument( 93 | '--top_k', 94 | dest='top_k', 95 | default=50, 96 | action='store', 97 | type=int, 98 | help='#labels to predict for each document') 99 | self.parser.add_argument( 100 | '--num_workers', 101 | dest='num_workers', 102 | default=6, 103 | action='store', 104 | type=int, 105 | help='#workers in data loader') 106 | self.parser.add_argument( 107 | '--ann_threads', 108 | dest='ann_threads', 109 | default=12, 110 | action='store', 111 | type=int, 112 | help='HSNW params') 113 | self.parser.add_argument( 114 | '--num_clf_partitions', 115 | dest='num_clf_partitions', 116 | default=1, 117 | action='store', 118 | type=int, 119 | help='#Partitioned classifier') 120 | self.parser.add_argument( 121 | '--label_indices', 122 | dest='label_indices', 123 | default=None, 124 | action='store', 125 | type=str, 126 | help='Use these labels only') 127 | self.parser.add_argument( 128 | '--feature_indices', 129 | dest='feature_indices', 130 | default=None, 131 | action='store', 132 | type=str, 133 | help='Use these features only') 134 | self.parser.add_argument( 135 | '--efC', 136 | dest='efC', 137 | action='store', 138 | default=300, 139 | type=int, 140 | help='efC') 141 | self.parser.add_argument( 142 | '--num_nbrs', 143 | dest='num_nbrs', 144 | action='store', 145 | default=300, 146 | type=int, 147 | help='num_nbrs') 148 | self.parser.add_argument( 149 | '--efS', 150 | dest='efS', 151 | action='store', 152 | 
default=300, 153 | type=int, 154 | help='efS') 155 | self.parser.add_argument( 156 | '--M', 157 | dest='M', 158 | action='store', 159 | default=100, 160 | type=int, 161 | help='M') 162 | self.parser.add_argument( 163 | '--retrain_hnsw_after', 164 | action='store', 165 | default=1, 166 | type=int, 167 | help='Retrain HSNW after these many epochs!') 168 | self.parser.add_argument( 169 | '--num_labels', 170 | dest='num_labels', 171 | default=-1, 172 | action='store', 173 | type=int, 174 | help='#labels') 175 | self.parser.add_argument( 176 | '--vocabulary_dims', 177 | dest='vocabulary_dims', 178 | default=-1, 179 | action='store', 180 | type=int, 181 | help='#features') 182 | self.parser.add_argument( 183 | '--padding_idx', 184 | dest='padding_idx', 185 | default=0, 186 | action='store', 187 | type=int, 188 | help='padding_idx') 189 | self.parser.add_argument( 190 | '--out_fname', 191 | dest='out_fname', 192 | default='out', 193 | action='store', 194 | type=str, 195 | help='prediction file name') 196 | self.parser.add_argument( 197 | '--dlr_factor', 198 | dest='dlr_factor', 199 | default=0.5, 200 | action='store', 201 | type=float, 202 | help='dlr_factor') 203 | self.parser.add_argument( 204 | '--m', 205 | dest='momentum', 206 | default=0.9, 207 | action='store', 208 | type=float, 209 | help='momentum') 210 | self.parser.add_argument( 211 | '--w', 212 | dest='weight_decay', 213 | default=0.0, 214 | action='store', 215 | type=float, 216 | help='weight decay parameter') 217 | self.parser.add_argument( 218 | '--dropout', 219 | dest='dropout', 220 | default=0.5, 221 | action='store', 222 | type=float, 223 | help='Dropout') 224 | self.parser.add_argument( 225 | '--optim', 226 | dest='optim', 227 | default='SGD', 228 | action='store', 229 | type=str, 230 | help='Optimizer') 231 | self.parser.add_argument( 232 | '--embedding_dims', 233 | dest='embedding_dims', 234 | default=300, 235 | action='store', 236 | type=int, 237 | help='embedding dimensions') 238 | self.parser.add_argument( 239 | '--embeddings', 240 | dest='embeddings', 241 | default='fasttextB_embeddings_300d.npy', 242 | action='store', 243 | type=str, 244 | help='embedding file name') 245 | self.parser.add_argument( 246 | '--validate_after', 247 | dest='validate_after', 248 | default=5, 249 | action='store', 250 | type=int, 251 | help='Validate after these many epochs.') 252 | self.parser.add_argument( 253 | '--num_epochs', 254 | dest='num_epochs', 255 | default=20, 256 | action='store', 257 | type=int, 258 | help='num epochs') 259 | self.parser.add_argument( 260 | '--batch_size', 261 | dest='batch_size', 262 | default=64, 263 | action='store', 264 | type=int, 265 | help='batch size') 266 | self.parser.add_argument( 267 | '--num_centroids', 268 | dest='num_centroids', 269 | default=1, 270 | type=int, 271 | action='store', 272 | help='#Centroids (Use multiple for ext head if more than 1)') 273 | self.parser.add_argument( 274 | '--beta', 275 | dest='beta', 276 | default=0.2, 277 | type=float, 278 | action='store', 279 | help='weight of classifier') 280 | self.parser.add_argument( 281 | '--res_init', 282 | dest='res_init', 283 | default='eye', 284 | type=str, 285 | action='store', 286 | help='eye or random') 287 | self.parser.add_argument( 288 | '--label_padding_index', 289 | dest='label_padding_index', 290 | default=None, 291 | type=int, 292 | action='store', 293 | help='Pad with this') 294 | self.parser.add_argument( 295 | '--mode', 296 | dest='mode', 297 | default='train', 298 | type=str, 299 | action='store', 300 | help='train or 
predict') 301 | self.parser.add_argument( 302 | '--init', 303 | dest='init', 304 | default='token_embeddings', 305 | type=str, 306 | action='store', 307 | help='initialize model parameters using') 308 | self.parser.add_argument( 309 | '--keep_invalid', 310 | action='store_true', 311 | help='Keep labels which do not have any training instance!.') 312 | self.parser.add_argument( 313 | '--freeze_intermediate', 314 | action='store_true', 315 | help='Do not train intermediate rep.') 316 | self.parser.add_argument( 317 | '--use_shortlist', 318 | action='store_true', 319 | help='Use shortlist or full') 320 | self.parser.add_argument( 321 | '--save_intermediate', 322 | action='store_true', 323 | help='Save model for intermediate rep.') 324 | self.parser.add_argument( 325 | '--use_pretrained_shortlist', 326 | action='store_true', 327 | help='Load shortlist from disk') 328 | self.parser.add_argument( 329 | '--validate', 330 | action='store_true', 331 | help='Validate or just train') 332 | self.parser.add_argument( 333 | '--bias', 334 | action='store', 335 | default=True, 336 | type=bool, 337 | help='Use bias term or not!') 338 | self.parser.add_argument( 339 | '--shuffle', 340 | action='store', 341 | default=True, 342 | type=bool, 343 | help='Shuffle data during training!') 344 | self.parser.add_argument( 345 | '--devices', 346 | action='store', 347 | default=['cuda:0'], 348 | nargs='+', 349 | help='Device for embeddings' 350 | ) 351 | self.parser.add_argument( 352 | '--normalize', 353 | action='store_true', 354 | help='Normalize features or not!') 355 | self.parser.add_argument( 356 | '--nbn_rel', 357 | action='store_true', 358 | help='Non binary label relevanxe') 359 | self.parser.add_argument( 360 | '--update_shortlist', 361 | action='store_true', 362 | help='Update shortlist while predicting' 363 | ) 364 | self.parser.add_argument( 365 | '--huge_dataset', 366 | action='store_true', 367 | help='Is it a really large dataset?' 368 | ) 369 | self.parser.add_argument( 370 | '--use_intermediate_for_shorty', 371 | action='store_true', 372 | help='Use intermediate representation for shortlist' 373 | ) 374 | self.parser.add_argument( 375 | '--get_only', 376 | nargs='+', 377 | type=str, 378 | default=['knn', 'clf', 'combined'], 379 | help='What do you have to output?' 
380 | ) 381 | -------------------------------------------------------------------------------- /deepxml/libs/parameters_base.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | 5 | __author__ = 'X' 6 | 7 | 8 | class ParametersBase(): 9 | """ 10 | Base class for parameters in XML 11 | """ 12 | def __init__(self, description): 13 | self.parser = argparse.ArgumentParser(description) 14 | self.params = None 15 | 16 | def _construct(self): 17 | self.parser.add_argument( 18 | '--dataset', 19 | dest='dataset', 20 | action='store', 21 | type=str, 22 | help='dataset name') 23 | self.parser.add_argument( 24 | '--data_dir', 25 | dest='data_dir', 26 | action='store', 27 | type=str, 28 | help='path to main data directory') 29 | self.parser.add_argument( 30 | '--model_dir', 31 | dest='model_dir', 32 | action='store', 33 | type=str, 34 | help='directory to store models') 35 | self.parser.add_argument( 36 | '--result_dir', 37 | dest='result_dir', 38 | action='store', 39 | type=str, 40 | help='directory to store results') 41 | self.parser.add_argument( 42 | '--model_fname', 43 | dest='model_fname', 44 | default='model', 45 | action='store', 46 | type=str, 47 | help='model file name') 48 | self.parser.add_argument( 49 | '--pred_fname', 50 | dest='pred_fname', 51 | default='predictions', 52 | action='store', 53 | type=str, 54 | help='prediction file name') 55 | self.parser.add_argument( 56 | '--trn_feat_fname', 57 | dest='trn_feat_fname', 58 | default='trn_X_Xf.txt', 59 | action='store', 60 | type=str, 61 | help='training feature file name') 62 | self.parser.add_argument( 63 | '--val_feat_fname', 64 | dest='val_feat_fname', 65 | default='tst_X_Xf.txt', 66 | action='store', 67 | type=str, 68 | help='validation feature file name') 69 | self.parser.add_argument( 70 | '--tst_feat_fname', 71 | dest='tst_feat_fname', 72 | default='tst_X_Xf.txt', 73 | action='store', 74 | type=str, 75 | help='test feature file name') 76 | self.parser.add_argument( 77 | '--trn_label_fname', 78 | dest='trn_label_fname', 79 | default='trn_X_Y.txt', 80 | action='store', 81 | type=str, 82 | help='training label file name') 83 | self.parser.add_argument( 84 | '--val_label_fname', 85 | dest='val_label_fname', 86 | default='tst_X_Y.txt', 87 | action='store', 88 | type=str, 89 | help='validation label file name') 90 | self.parser.add_argument( 91 | '--feature_type', 92 | dest='feature_type', 93 | default='sparse', 94 | action='store', 95 | type=str, 96 | help='feature type sequential/dense/sparse') 97 | self.parser.add_argument( 98 | '--tst_label_fname', 99 | dest='tst_label_fname', 100 | default='tst_X_Y.txt', 101 | action='store', 102 | type=str, 103 | help='test label file name') 104 | 105 | def parse_args(self): 106 | self.params, _ = self.parser.parse_known_args() 107 | 108 | def update(self, _dict): 109 | self.params.__dict__.update(_dict) 110 | 111 | def load(self, fname): 112 | vars(self.params).update(json.load(open(fname))) 113 | 114 | def save(self, fname): 115 | print(vars(self.params)) 116 | json.dump(vars(self.params), open(fname, 'w'), indent=4) 117 | -------------------------------------------------------------------------------- /deepxml/libs/sampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pickle 3 | from functools import partial 4 | 5 | 6 | class BaseSampler(object): 7 | """Sampler with support for sampling from 8 | multinomial distribution 9 | 10 | Arguments: 11 | 
---------- 12 | size: int 13 | sample space 14 | num_samples: int 15 | #samples 16 | prob: np.ndarray or None, optional, default=None 17 | probability of each item 18 | replace: boolean, optional, default=False 19 | with or without replacement 20 | """ 21 | def __init__(self, size, num_samples, prob=None, replace=False): 22 | self.size = size 23 | self.num_samples = num_samples 24 | self.prob = prob 25 | self.replace = replace 26 | self.index = None 27 | self._construct() 28 | 29 | def _construct(self): 30 | """Create a partial function with given parameters 31 | Index should take one argument i.e. size during querying 32 | """ 33 | self.index = partial(np.random.randint, low=0, high=self.size) 34 | 35 | def _query(self): 36 | """Query for one sample 37 | """ 38 | return (self.index(size=self.num_samples), [1.0]*self.num_samples) 39 | 40 | def query(self, num_instances, *args, **kwargs): 41 | """Query shortlist for one or more samples 42 | """ 43 | if num_instances == 1: 44 | return self._query() 45 | else: 46 | out = [self._query() for _ in range(num_instances)] 47 | return out 48 | 49 | def save(self, fname): 50 | """ 51 | Save object 52 | """ 53 | state = self.__dict__ 54 | pickle.dump(state, open(fname, 'wb')) 55 | 56 | def load(self, fname): 57 | """ Load object 58 | """ 59 | self.__dict__ = pickle.load(open(fname, 'rb')) 60 | 61 | @property 62 | def data_init(self): 63 | return True if self.index is not None else False 64 | 65 | 66 | class NegativeSampler(BaseSampler): 67 | """Negative sampler with support for sampling from 68 | multinomial distribution 69 | 70 | Arguments: 71 | ---------- 72 | size: int 73 | sample space 74 | num_negatives: int 75 | #samples 76 | prob: np.ndarray or None, optional, default=None 77 | probability of each item 78 | replace: boolean, optional, default=False 79 | with or without replacement 80 | * replace = True is slow 81 | """ 82 | def __init__(self, size, num_negatives, prob=None, replace=False): 83 | super().__init__(size, num_negatives, prob, replace) 84 | 85 | def _construct(self): 86 | self.index = partial( 87 | np.random.default_rng().choice, a=self.size, 88 | replace=self.replace, p=self.prob) 89 | 90 | 91 | class Sampler(BaseSampler): 92 | """Sampler with support for sampling from 93 | multinomial distribution 94 | 95 | Arguments: 96 | ---------- 97 | size: int 98 | sample from this space 99 | num_samples: int 100 | #samples 101 | prob: np.ndarray or None, optional, default=None 102 | probability of each item 103 | replace: boolean, optional, default=False 104 | with or without replacement 105 | """ 106 | def __init__(self, size, num_samples, prob=None, replace=False): 107 | super().__init__(size, num_samples, prob, replace) 108 | 109 | def _construct(self): 110 | self.index = partial( 111 | np.random.default_rng().choice, replace=self.replace) 112 | 113 | def _query(self, ind): 114 | """Query for one sample 115 | """ 116 | prob = None 117 | if self.prob is not None: 118 | prob = self.prob[ind] 119 | return (self.index(a=ind, p=prob, size=self.num_samples), 120 | [1.0]*self.num_samples) 121 | 122 | def query(self, num_instances, ind, *args, **kwargs): 123 | """Query shortlist for one or more samples; 124 | Pick labels from given indices 125 | """ 126 | if num_instances == 1: 127 | return self._query(ind) 128 | else: 129 | out = [self._query(ind[i]) for i in range(num_instances)] 130 | return out 131 | -------------------------------------------------------------------------------- /deepxml/libs/shortlist.py:
-------------------------------------------------------------------------------- 1 | import pickle 2 | from xclib.utils.sparse import topk, csr_from_arrays 3 | from xclib.utils.shortlist import Shortlist 4 | from xclib.utils.shortlist import ShortlistCentroids 5 | from xclib.utils.shortlist import ShortlistInstances 6 | 7 | 8 | class ShortlistEnsemble(object): 9 | """Get nearest labels using KNN + Kcentroid 10 | * Give less weight to KNN (typically 0.1 or 0.075) 11 | * brute or HNSW algorithm for search 12 | Parameters 13 | ---------- 14 | method: str, optional, default='hnsw' 15 | brute or hnsw 16 | num_neighbours: int, optional, default=500 17 | number of labels to keep for each instance 18 | * will pad using pad_ind and pad_val in case labels 19 | are less than num_neighbours 20 | M: int, optional, default=100 21 | HNSW M (Usually 100) 22 | efC: dict, optional, default={'kcentroid': 300, 'knn': 50} 23 | construction parameter for kcentroid and knn 24 | * Usually 300 for kcentroid and 50 for knn 25 | efS: dict, optional, default={'kcentroid': 300, 'knn': 500} 26 | search parameter for kcentroid and knn 27 | * Usually 300 for kcentroid and 500 for knn 28 | num_threads: int, optional, default=24 29 | use multiple threads to cluster 30 | space: str, optional, default='cosine' 31 | metric to use while quering 32 | verbose: boolean, optional, default=True 33 | print progress 34 | num_clusters: int, optional, default=1 35 | cluster instances => multiple representatives for chosen labels 36 | pad_val: int, optional, default=-10000 37 | value for padding indices 38 | - Useful as documents may have different number of nearest labels 39 | gamma: float, optional, default=0.075 40 | weight for KNN. 41 | * final shortlist => gamma * knn + (1-gamma) * kcentroid 42 | """ 43 | def __init__(self, method='hnsw', num_neighbours={'ens': 500, 44 | 'kcentroid': 400, 'knn': 300}, 45 | M={'kcentroid': 100, 'knn': 50}, 46 | efC={'kcentroid': 300, 'knn': 50}, 47 | efS={'kcentroid': 400, 'knn': 100}, 48 | num_threads=24, space='cosine', verbose=True, 49 | num_clusters=1, pad_val=-10000, gamma=0.075): 50 | self.kcentroid = ShortlistCentroids( 51 | method=method, num_neighbours=efS['kcentroid'], 52 | M=M['kcentroid'], efC=efC['kcentroid'], efS=efS['kcentroid'], 53 | num_threads=num_threads, space=space, num_clusters=num_clusters, 54 | verbose=True) 55 | self.knn = ShortlistInstances( 56 | method=method, num_neighbours=num_neighbours['knn'], M=M['knn'], 57 | efC=efC['knn'], efS=efS['knn'], num_threads=num_threads, 58 | space=space, verbose=True) 59 | self.num_labels = None 60 | self.num_neighbours = num_neighbours['ens'] 61 | self.pad_val = pad_val 62 | self.pad_ind = -1 63 | self.gamma = gamma 64 | 65 | def fit(self, X, Y, *args, **kwargs): 66 | # Useful when number of neighbors are not same 67 | self.pad_ind = Y.shape[1] 68 | self.num_labels = Y.shape[1] 69 | self.kcentroid.fit(X, Y) 70 | self.knn.fit(X, Y) 71 | 72 | @property 73 | def model_size(self): 74 | return self.knn.model_size + self.kcentroid.model_size 75 | 76 | def merge(self, indices_kcentroid, indices_knn, sim_kcentroid, sim_knn): 77 | _shape = (len(indices_kcentroid), self.num_labels+1) 78 | short_knn = csr_from_arrays( 79 | indices_knn, sim_knn, _shape) 80 | short_kcentroid = csr_from_arrays( 81 | indices_kcentroid, sim_kcentroid, _shape) 82 | indices, sim = topk( 83 | (self.gamma*short_knn + (1-self.gamma)*short_kcentroid), 84 | k=self.num_neighbours, pad_ind=self.pad_ind, 85 | pad_val=self.pad_val, return_values=True) 86 | return indices, sim 87 | 
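    # --- Editor's note: illustrative sketch, not part of the original file ---
    # merge() above combines the two shortlists as a weighted sum of sparse
    # score matrices, score = gamma * knn + (1 - gamma) * kcentroid, followed
    # by a top-k selection padded with pad_ind/pad_val. For example, with the
    # default gamma=0.075, a label scored 0.5 by the instance-based knn index
    # and 0.9 by the label-centroid index receives
    #     0.075 * 0.5 + 0.925 * 0.9 = 0.87
    # so the kcentroid score dominates, which is the intent of the small
    # weight given to knn.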
88 |     def query(self, data):
89 |         indices_knn, sim_knn = self.knn.query(data)
90 |         indices_kcentroid, sim_kcentroid = self.kcentroid.query(data)
91 |         indices, similarities = self.merge(
92 |             indices_kcentroid, indices_knn, sim_kcentroid, sim_knn)
93 |         return indices, similarities
94 | 
95 |     def save(self, fname):
96 |         # Save metadata and the underlying ANN structures; useful when purging checkpoints
97 |         pickle.dump(
98 |             {'num_labels': self.num_labels,
99 |              'pad_ind': self.pad_ind}, open(fname+".metadata", 'wb'))
100 |         self.kcentroid.save(fname+'.kcentroid')
101 |         self.knn.save(fname+'.knn')
102 | 
103 |     def purge(self, fname):
104 |         # purge files from disk
105 |         self.knn.purge(fname)
106 |         self.kcentroid.purge(fname)
107 | 
108 |     def load(self, fname):
109 |         obj = pickle.load(
110 |             open(fname+".metadata", 'rb'))
111 |         self.num_labels = obj['num_labels']
112 |         self.pad_ind = obj['pad_ind']
113 |         self.kcentroid.load(fname+'.kcentroid')
114 |         self.knn.load(fname+'.knn')
115 | 
116 |     def reset(self):
117 |         self.kcentroid.reset()
118 |         self.knn.reset()
119 | 
120 | 
121 | class ParallelShortlist(object):
122 |     """Multiple graphs; supports parallel training
123 |     Assumes that all parameters are the same for each graph
124 |     Parameters
125 |     ----------
126 |     method: str
127 |         brute or hnsw
128 |     num_neighbours: int
129 |         number of neighbors
130 |     M: int
131 |         HNSW M (Usually 100)
132 |     efC: int
133 |         construction parameter (Usually 300)
134 |     efS: int
135 |         search parameter (Usually 300)
136 |     num_threads: int, optional, default=-1
137 |         use multiple threads
138 |     num_graphs: int, optional, default=2
139 |         #graphs to maintain
140 |     """
141 | 
142 |     def __init__(self, method, num_neighbours, M, efC, efS,
143 |                  num_threads=-1, num_graphs=2):
144 |         self.num_graphs = num_graphs
145 |         self.index = []
146 |         for _ in range(num_graphs):
147 |             self.index.append(
148 |                 Shortlist(method, num_neighbours, M, efC, efS, num_threads))
149 | 
150 |     def train(self, data):
151 |         # Sequential for now; parallel training is not supported yet
152 |         for idx in range(self.num_graphs):
153 |             self.index[idx].train(data[idx])
154 | 
155 |     def _query(self, idx, data):
156 |         return self.index[idx].query(data)
157 | 
158 |     def query(self, data, idx=-1):
159 |         # Sequential for now
160 |         # Parallelize with return values?
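        # One way the loop below could be parallelized (illustrative sketch
        # only; assumes the per-graph ANN queries are thread-safe):
        #   from concurrent.futures import ThreadPoolExecutor
        #   with ThreadPoolExecutor(max_workers=self.num_graphs) as ex:
        #       results = list(ex.map(
        #           lambda i: self._query(i, data), range(self.num_graphs)))
        #   indices, similarities = map(list, zip(*results))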
161 |         # Data is the same for every graph
162 |         if idx != -1:  # Query from a particular graph only
163 |             indices, similarities = self._query(idx, data)
164 |         else:
165 |             indices, similarities = [], []
166 |             for idx in range(self.num_graphs):
167 |                 _indices, _sims = self._query(idx, data)
168 |                 indices.append(_indices)
169 |                 similarities.append(_sims)
170 |         return indices, similarities
171 | 
172 |     def save(self, fname):
173 |         pickle.dump({'num_graphs': self.num_graphs},
174 |                     open(fname+".metadata", "wb"))
175 |         for idx in range(self.num_graphs):
176 |             self.index[idx].save(fname+".{}".format(idx))
177 | 
178 |     def load(self, fname):
179 |         self.num_graphs = pickle.load(
180 |             open(fname+".metadata", "rb"))['num_graphs']
181 |         for idx in range(self.num_graphs):
182 |             self.index[idx].load(fname+".{}".format(idx))
183 | 
184 |     def reset(self):
185 |         for idx in range(self.num_graphs):
186 |             self.index[idx].reset()
187 | 
--------------------------------------------------------------------------------
/deepxml/libs/shortlist_handler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from .dist_utils import Partitioner
3 | import os
4 | from .sampling import NegativeSampler
5 | from scipy.sparse import load_npz
6 | from xclib.utils import sparse as sp
7 | from xclib.utils.matrix import SMatrix
8 | 
9 | 
10 | def construct_handler(shortlist_type, num_instances, num_labels,
11 |                       model_dir='', mode='train', size_shortlist=-1,
12 |                       label_mapping=None, in_memory=True,
13 |                       shorty=None, fname=None, corruption=200,
14 |                       num_clf_partitions=1):
15 |     if shortlist_type == 'static':
16 |         return ShortlistHandlerStatic(
17 |             num_instances, num_labels, model_dir, num_clf_partitions, mode,
18 |             size_shortlist, in_memory, label_mapping, fname)
19 |     elif shortlist_type == 'hybrid':
20 |         return ShortlistHandlerHybrid(
21 |             num_instances, num_labels, model_dir, num_clf_partitions, mode,
22 |             size_shortlist, in_memory, label_mapping, corruption)
23 |     elif shortlist_type == 'dynamic':
24 |         return ShortlistHandlerDynamic(
25 |             num_labels, shorty, model_dir,
26 |             mode, num_clf_partitions, size_shortlist, label_mapping)
27 |     else:
28 |         raise NotImplementedError(
29 |             "Unknown shortlist method: {}!".format(shortlist_type))
30 | 
31 | 
32 | class ShortlistHandlerBase(object):
33 |     """Base class for ShortlistHandler
34 |     - support for partitioned classifier
35 | 
36 |     Arguments
37 |     ----------
38 |     num_labels: int
39 |         number of labels
40 |     shortlist:
41 |         shortlist object
42 |     model_dir: str, optional, default=''
43 |         save the data in model_dir
44 |     num_clf_partitions: int, optional, default=1
45 |         #classifier splits
46 |     mode: str, optional, default='train'
47 |         mode i.e.
train or test or val 48 | size_shortlist:int, optional, default=-1 49 | get shortlist of this size 50 | label_mapping: None or dict: optional, default=None 51 | map labels as per this mapping 52 | """ 53 | 54 | def __init__(self, num_labels, shortlist, model_dir='', 55 | num_clf_partitions=1, mode='train', size_shortlist=-1, 56 | label_mapping=None, max_pos=20): 57 | self.model_dir = model_dir 58 | self.num_clf_partitions = num_clf_partitions 59 | self.size_shortlist = size_shortlist 60 | self.mode = mode 61 | self.max_pos = max_pos 62 | self.num_labels = num_labels 63 | self.label_mapping = label_mapping 64 | # self._create_shortlist(shortlist) 65 | self._create_partitioner() 66 | self.label_padding_index = self.num_labels 67 | if self.num_clf_partitions > 1: 68 | self.label_padding_index = self.partitioner.get_padding_indices() 69 | 70 | def _create_shortlist(self, shortlist): 71 | """ 72 | Create structure to hold shortlist 73 | """ 74 | self.shortlist = shortlist 75 | 76 | def query(self, *args, **kwargs): 77 | return self.shortlist(*args, **kwargs) 78 | 79 | def _create_partitioner(self): 80 | """ 81 | Create partiotionar to for splitted classifier 82 | """ 83 | self.partitioner = None 84 | if self.num_clf_partitions > 1: 85 | if self.mode == 'train': 86 | self.partitioner = Partitioner( 87 | self.num_labels, self.num_clf_partitions, 88 | padding=False, contiguous=True) 89 | self.partitioner.save(os.path.join( 90 | self.model_dir, 'partitionar.pkl')) 91 | else: 92 | self.partitioner = Partitioner( 93 | self.num_labels, self.num_clf_partitions, 94 | padding=False, contiguous=True) 95 | self.partitioner.load(os.path.join( 96 | self.model_dir, 'partitionar.pkl')) 97 | 98 | def _adjust_shortlist(self, pos_labels, shortlist, sim): 99 | """ 100 | Adjust shortlist for a instance 101 | Training: Add positive labels to the shortlist 102 | Inference: Return shortlist with label mask 103 | """ 104 | if self.mode == 'train': 105 | _target = np.zeros(self.size_shortlist, dtype=np.float32) 106 | _sim = np.zeros(self.size_shortlist, dtype=np.float32) 107 | _shortlist = np.full( 108 | self.size_shortlist, fill_value=self.label_padding_index, 109 | dtype=np.int64) 110 | # TODO: Adjust sim as well 111 | if len(pos_labels) > self.max_pos: 112 | pos_labels = np.random.choice( 113 | pos_labels, size=self.max_pos, replace=False) 114 | neg_labels = shortlist[~np.isin(shortlist, pos_labels)] 115 | _target[:len(pos_labels)] = 1.0 116 | # #TODO not used during training; not perfect values 117 | _sim[:len(pos_labels)] = 1.0 118 | _short = np.concatenate([pos_labels, neg_labels]) 119 | temp = min(len(_short), self.size_shortlist) 120 | _shortlist[:temp] = _short[:temp] 121 | else: 122 | _target = np.zeros(self.size_shortlist, dtype=np.float32) 123 | _shortlist = np.full( 124 | self.size_shortlist, fill_value=self.label_padding_index, 125 | dtype=np.int64) 126 | _shortlist[:len(shortlist)] = shortlist 127 | _target[np.isin(shortlist, pos_labels)] = 1.0 128 | _sim = np.zeros(self.size_shortlist, dtype=np.float32) 129 | _sim[:len(shortlist)] = sim 130 | return _shortlist, _target, _sim 131 | 132 | def _get_sl_one(self, index, pos_labels): 133 | shortlist, sim = self.query(index) 134 | shortlist, target, sim = self._adjust_shortlist( 135 | pos_labels, shortlist, sim) 136 | mask = shortlist != self.label_padding_index 137 | return shortlist, target, sim, mask 138 | 139 | def _get_sl_partitioned(self, index, pos_labels): 140 | # Partition labels 141 | pos_labels = self.partitioner.split_indices(pos_labels) 142 | if 
self.shortlist.data_init: # Shortlist is initialized 143 | _shortlist, _sim = self.query(index) 144 | shortlist, target, sim, mask, rev_map = [], [], [], [], [] 145 | # Get shortlist for each classifier 146 | for idx in range(self.num_clf_partitions): 147 | __shortlist, __target, __sim, __mask = self._adjust_shortlist( 148 | pos_labels[idx], 149 | _shortlist[idx], 150 | _sim[idx]) 151 | shortlist.append(__shortlist) 152 | target.append(__target) 153 | sim.append(__sim) 154 | mask.append(__mask) 155 | rev_map.append( 156 | self.partitioner.map_to_original(__shortlist, idx)) 157 | rev_map = np.concatenate(rev_map) 158 | else: # Shortlist is un-initialized 159 | shortlist = [np.zeros(self.size_shortlist)]*self.num_clf_partitions 160 | target = [np.zeros(self.size_shortlist)]*self.num_clf_partitions 161 | sim = [np.zeros(self.size_shortlist)]*self.num_clf_partitions 162 | mask = [np.zeros(self.size_shortlist)]*self.num_clf_partitions 163 | rev_map = np.zeros(self.size_shortlist*self.num_clf_partitions) 164 | return shortlist, target, sim, mask, rev_map 165 | 166 | def get_shortlist(self, index, pos_labels=None): 167 | """ 168 | Get data with shortlist for given data index 169 | """ 170 | if self.num_clf_partitions > 1: 171 | return self._get_sl_partitioned(index, pos_labels) 172 | else: 173 | return self._get_sl_one(index, pos_labels) 174 | 175 | def get_partition_indices(self, index): 176 | return self.partitioner.get_indices(index) 177 | 178 | 179 | class ShortlistHandlerStatic(ShortlistHandlerBase): 180 | """ShortlistHandler with static shortlist 181 | - save/load/update/process shortlist 182 | - support for partitioned classifier 183 | 184 | Arguments 185 | ---------- 186 | num_labels: int 187 | number of labels 188 | model_dir: str, optional, default='' 189 | save the data in model_dir 190 | num_clf_partitions: int, optional, default='' 191 | #classifier splits 192 | mode: str: optional, default='' 193 | mode i.e. 
train or test or val 194 | size_shortlist:int, optional, default=-1 195 | get shortlist of this size 196 | in_memory: bool: optional, default=True 197 | Keep the shortlist in memory or on-disk 198 | label_mapping: None or dict: optional, default=None 199 | map labels as per this mapping 200 | """ 201 | 202 | def __init__(self, num_instances, num_labels, model_dir='', 203 | num_clf_partitions=1, mode='train', size_shortlist=-1, 204 | in_memory=True, label_mapping=None, fname=None): 205 | super().__init__(num_labels, None, model_dir, num_clf_partitions, 206 | mode, size_shortlist, label_mapping) 207 | self.in_memory = in_memory 208 | self._create_shortlist(num_instances, num_labels, size_shortlist) 209 | if fname is not None: 210 | self.from_pretrained(fname) 211 | 212 | def from_pretrained(self, fname): 213 | """ 214 | Load label shortlist and similarity for each instance 215 | """ 216 | shortlist = load_npz(fname) 217 | _ind, _sim = sp.topk(shortlist, 218 | self.size_shortlist, self.num_labels, 219 | -1000, return_values=True) 220 | self.update_shortlist(_ind, _sim) 221 | 222 | def query(self, index): 223 | ind, sim = self.shortlist[index] 224 | return ind, sim 225 | 226 | def _create_shortlist(self, num_instances, num_labels, k): 227 | """ 228 | Create structure to hold shortlist 229 | """ 230 | _type = 'memory' if self.in_memory else 'memmap' 231 | if self.num_clf_partitions > 1: 232 | raise NotImplementedError() 233 | else: 234 | self.shortlist = SMatrix(num_instances, num_labels, k) 235 | 236 | def update_shortlist(self, ind, sim, fname='tmp'): 237 | """ 238 | Update label shortlist for each instance 239 | """ 240 | self.shortlist.update(ind, sim) 241 | del sim, ind 242 | 243 | def save_shortlist(self, fname): 244 | """ 245 | Save label shortlist and similarity for each instance 246 | """ 247 | raise NotImplementedError() 248 | 249 | def load_shortlist(self, fname): 250 | """ 251 | Load label shortlist and similarity for each instance 252 | """ 253 | raise NotImplementedError() 254 | 255 | 256 | class ShortlistHandlerDynamic(ShortlistHandlerBase): 257 | """ShortlistHandler with dynamic shortlist 258 | 259 | Arguments 260 | ---------- 261 | num_labels: int 262 | number of labels 263 | shortlist: 264 | shortlist object like negative sampler 265 | model_dir: str, optional, default='' 266 | save the data in model_dir 267 | mode: str: optional, default='' 268 | mode i.e. 
train or test or val 269 | size_shortlist:int, optional, default=-1 270 | get shortlist of this size 271 | label_mapping: None or dict: optional, default=None 272 | map labels as per this mapping 273 | """ 274 | 275 | def __init__(self, num_labels, shortlist, model_dir='', 276 | num_clf_partitions=1, mode='train', 277 | size_shortlist=-1, label_mapping=None): 278 | super().__init__( 279 | num_labels, shortlist, model_dir, num_clf_partitions, 280 | mode, size_shortlist, label_mapping) 281 | self._create_shortlist(shortlist) 282 | 283 | def query(self, num_instances=1, ind=None): 284 | return self.shortlist.query( 285 | num_instances=num_instances, ind=ind) 286 | 287 | 288 | class ShortlistHandlerHybrid(ShortlistHandlerBase): 289 | """ShortlistHandler with hybrid shortlist 290 | - save/load/update/process shortlist 291 | - support for partitioned classifier 292 | 293 | Arguments 294 | ---------- 295 | num_labels: int 296 | number of labels 297 | model_dir: str, optional, default='' 298 | save the data in model_dir 299 | num_clf_partitions: int, optional, default='' 300 | #classifier splits 301 | mode: str: optional, default='' 302 | mode i.e. train or test or val 303 | size_shortlist:int, optional, default=-1 304 | get shortlist of this size 305 | in_memory: bool: optional, default=True 306 | Keep the shortlist in memory or on-disk 307 | label_mapping: None or dict: optional, default=None 308 | map labels as per this mapping 309 | _corruption: int, optional, default=None 310 | add these many random labels 311 | """ 312 | 313 | def __init__(self, num_instances, num_labels, model_dir='', 314 | num_clf_partitions=1, mode='train', size_shortlist=-1, 315 | in_memory=True, label_mapping=None, _corruption=200): 316 | super().__init__(num_labels, None, model_dir, num_clf_partitions, 317 | mode, size_shortlist, label_mapping) 318 | self.in_memory = in_memory 319 | self._create_shortlist(num_instances, num_labels, size_shortlist) 320 | self.shortlist_dynamic = NegativeSampler(num_labels, _corruption+20) 321 | self.size_shortlist = size_shortlist+_corruption # Both 322 | 323 | def query(self, index): 324 | ind, sim = self.shortlist[index] 325 | _ind, _sim = self.shortlist_dynamic.query(1) 326 | ind = np.concatenate([ind, _ind]) 327 | sim = np.concatenate([sim, _sim]) 328 | return ind, sim 329 | 330 | def _create_shortlist(self, num_instances, num_labels, k): 331 | """ 332 | Create structure to hold shortlist 333 | """ 334 | _type = 'memory' if self.in_memory else 'memmap' 335 | if self.num_clf_partitions > 1: 336 | raise NotImplementedError() 337 | else: 338 | self.shortlist = SMatrix(num_instances, num_labels, k) 339 | 340 | def update_shortlist(self, ind, sim, fname='tmp'): 341 | """ 342 | Update label shortlist for each instance 343 | """ 344 | self.shortlist.update(ind, sim) 345 | del sim, ind 346 | 347 | def save_shortlist(self, fname): 348 | """ 349 | Save label shortlist and similarity for each instance 350 | """ 351 | raise NotImplementedError() 352 | 353 | def load_shortlist(self, fname): 354 | """ 355 | Load label shortlist and similarity for each instance 356 | """ 357 | raise NotImplementedError() 358 | -------------------------------------------------------------------------------- /deepxml/libs/tracking.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tracking object; Maintain history of loss; accuracy etc. 
3 | """ 4 | 5 | import pickle 6 | 7 | 8 | class Tracking(object): 9 | def __init__(self): 10 | self.checkpoint_history = 3 11 | self.mean_train_loss = [] 12 | self.mean_val_loss = [] 13 | self.saved_models = [] 14 | self.val_precision = [] 15 | self.val_ndcg = [] 16 | self.train_time = 0 17 | self.validation_time = 0 18 | self.shortlist_time = 0 19 | self.saved_checkpoints = [] 20 | self.last_saved_epoch = -1 21 | self.last_epoch = 0 22 | 23 | def save(self, fname): 24 | pickle.dump(self.__dict__, open(fname, 'wb')) 25 | 26 | def load(self, fname): 27 | self.__dict__ = pickle.load(open(fname, 'rb')) 28 | -------------------------------------------------------------------------------- /deepxml/libs/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | from scipy.sparse import save_npz 5 | from xclib.utils.sparse import _map_cols 6 | 7 | 8 | def save_predictions(preds, result_dir, valid_labels, num_samples, 9 | num_labels, get_fnames=['knn', 'clf', 'combined'], 10 | prefix='predictions'): 11 | if isinstance(preds, dict): 12 | for _fname, _pred in preds.items(): 13 | if _fname in get_fnames: 14 | if valid_labels is not None: 15 | predicted_labels = _map_cols( 16 | _pred, valid_labels, 17 | shape=(num_samples, num_labels)) 18 | else: 19 | predicted_labels = _pred 20 | save_npz(os.path.join( 21 | result_dir, '{}_{}.npz'.format(prefix, _fname)), 22 | predicted_labels, compressed=False) 23 | else: 24 | if valid_labels is not None: 25 | predicted_labels = _map_cols( 26 | preds, valid_labels, 27 | shape=(num_samples, num_labels)) 28 | else: 29 | predicted_labels = preds 30 | save_npz(os.path.join(result_dir, '{}.npz'.format(prefix)), 31 | predicted_labels, compressed=False) 32 | 33 | 34 | def append_padding_classifier_one(classifier, num_labels, 35 | key_w='classifier.weight', 36 | key_b='classifier.bias'): 37 | _num_labels, dims = classifier[key_w].size() 38 | if _num_labels != num_labels: 39 | status = "Appended padding classifier." 40 | _device = classifier[key_w].device 41 | classifier[key_w] = torch.cat( 42 | [classifier[key_w], torch.zeros(1, dims).to(_device)], 0) 43 | if key_b in classifier: 44 | classifier[key_b] = torch.cat( 45 | [classifier[key_b], -1e5*torch.ones(1, 1).to(_device)], 0) 46 | else: 47 | status = "Shapes are fine, Not padding again." 
48 | return status 49 | 50 | 51 | def append_padding_classifier(net, num_labels): 52 | if isinstance(num_labels, list): 53 | status = [] 54 | for idx, item in enumerate(num_labels): 55 | status.append(append_padding_classifier_one( 56 | net, item, 'classifier.classifier.{}.weight'.format( 57 | idx), 'classifier.classifier.{}.bias'.format(idx))) 58 | print("Padding not implemented for distributed classifier for now!") 59 | else: 60 | return append_padding_classifier_one(net, num_labels) 61 | 62 | 63 | def get_header(fname): 64 | with open(fname, 'r') as fp: 65 | line = fp.readline() 66 | return list(map(int, line.split(" "))) 67 | 68 | 69 | def get_data_stats(fname, key): 70 | def get(fname, key): 71 | with open(fname, 'r') as fp: 72 | val = json.load(fp)[key] 73 | return val 74 | if isinstance(key, tuple): 75 | out = [] 76 | for _key in key: 77 | out.append(get(fname, _key)) 78 | return tuple(out) 79 | else: 80 | return get(fname, key) 81 | 82 | 83 | def save_parameters(fname, params): 84 | json.dump({'num_labels': params.num_labels, 85 | 'vocabulary_dims': params.vocabulary_dims, 86 | 'use_shortlist': params.use_shortlist, 87 | 'ann_method': params.ann_method, 88 | 'num_nbrs': params.num_nbrs, 89 | 'arch': params.arch, 90 | 'embedding_dims': params.embedding_dims, 91 | 'num_clf_partitions': params.num_clf_partitions, 92 | 'label_padding_index': params.label_padding_index, 93 | 'keep_invalid': params.keep_invalid}, 94 | open(fname, 'w'), 95 | sort_keys=True, 96 | indent=4) 97 | 98 | 99 | def load_parameters(fname, params): 100 | temp = json.load(open(fname, 'r')) 101 | params.num_labels = temp['num_labels'] 102 | params.vocabulary_dims = temp['vocabulary_dims'] 103 | params.num_nbrs = temp['num_nbrs'] 104 | params.arch = temp['arch'] 105 | params.num_clf_partitions = temp['num_clf_partitions'] 106 | params.label_padding_index = temp['label_padding_index'] 107 | params.ann_method = temp['ann_method'] 108 | params.embedding_dims = temp['embedding_dims'] 109 | params.keep_invalid = temp['keep_invalid'] 110 | -------------------------------------------------------------------------------- /deepxml/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /deepxml/models/astec.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import models.embedding_layer as embedding_layer 3 | 4 | 5 | class Astec(nn.Module): 6 | """ 7 | Encode a document using the feature representaion as per Astec 8 | 9 | Arguments: 10 | ---------- 11 | num_embeddings: int 12 | vocalubary size 13 | embedding_dim: int, optional (default=300) 14 | dimension for embeddings 15 | dropout: float, optional (default=0.5) 16 | drop probability 17 | padding_idx: int, optional (default=0) 18 | index for ; embedding is not updated 19 | Values other than 0 are not yet tested 20 | reduction: str or None, optional (default=None) 21 | * None: don't reduce 22 | * sum: sum over tokens 23 | * mean: mean over tokens 24 | sparse: boolean, optional (default=False) 25 | sparse or dense gradients 26 | * the optimizer will infer from this parameters 27 | freeze_embeddings: boolean, optional (default=False) 28 | * freeze the gradient of token embeddings 29 | device: str, optional (default="cuda:0") 30 | Keep embeddings on this device 31 | """ 32 | def __init__(self, vocabulary_dims, embedding_dims=300, 33 | dropout=0.5, padding_idx=0, reduction='sum', 34 | 
sparse=True, freeze=False, device="cuda:0"): 35 | super(Astec, self).__init__() 36 | self.vocabulary_dims = vocabulary_dims + 1 37 | self.embedding_dims = embedding_dims 38 | self.padding_idx = padding_idx 39 | self.device = device 40 | self.sparse = sparse 41 | self.reduction = reduction 42 | self.embeddings = self._construct_embedding() 43 | self.relu = nn.ReLU() 44 | self.dropout = nn.Dropout(dropout) 45 | self.freeze = freeze 46 | if self.freeze: 47 | for params in self.embeddings.parameters(): 48 | params.requires_grad = False 49 | 50 | def _construct_embedding(self): 51 | return embedding_layer.Embedding( 52 | num_embeddings=self.vocabulary_dims, 53 | embedding_dim=self.embedding_dims, 54 | padding_idx=self.padding_idx, 55 | scale_grad_by_freq=False, 56 | device=self.device, 57 | reduction=self.reduction, 58 | sparse=self.sparse) 59 | 60 | def encoder(self, x, x_ind): 61 | if x_ind is None: # Assume embedding is pre-computed 62 | return x 63 | else: 64 | return self.embeddings(x_ind, x) 65 | 66 | def forward(self, x): 67 | """ 68 | Arguments: 69 | ---------- 70 | x: (torch.Tensor or None, torch.LongTensor) 71 | token weights and indices 72 | weights can be None 73 | 74 | Returns: 75 | -------- 76 | embed: torch.Tensor 77 | transformed document representation 78 | Dimension depends on reduction 79 | """ 80 | return self.dropout(self.relu(self.encoder(*x))) 81 | 82 | def to(self): 83 | super().to(self.device) 84 | 85 | def initialize(self, x): 86 | self.embeddings.from_pretrained(x) 87 | 88 | def initialize_token_embeddings(self, x): 89 | return self.initialize(x) 90 | 91 | def get_token_embeddings(self): 92 | return self.embeddings.get_weights() 93 | 94 | @property 95 | def representation_dims(self): 96 | return self.embedding_dims 97 | -------------------------------------------------------------------------------- /deepxml/models/embedding_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.parameter import Parameter 3 | import torch.nn.functional as F 4 | 5 | 6 | class Embedding(torch.nn.Module): 7 | """ 8 | General way to handle embeddings 9 | 10 | * Support for sequential models 11 | * Memory efficient way to compute weighted EmbeddingBag 12 | 13 | Arguments: 14 | ---------- 15 | num_embeddings: int 16 | vocalubary size 17 | embedding_dim: int 18 | dimension for embeddings 19 | padding_idx: 0 or None, optional (default=None) 20 | index for ; embedding is not updated 21 | max_norm: None or float, optional (default=None) 22 | maintain norm of embeddings 23 | norm_type: int, optional (default=2) 24 | norm for max_norm 25 | scale_grad_by_freq: boolean, optional (default=False) 26 | Scale gradients by token frequency 27 | sparse: boolean, optional (default=False) 28 | sparse or dense gradients 29 | * the optimizer will infer from this parameters 30 | reduction: str or None, optional (default=None) 31 | * None: don't reduce 32 | * sum: sum over tokens 33 | * mean: mean over tokens 34 | pretrained_weights: torch.Tensor or None, optional (default=None) 35 | Initialize with these weights 36 | * first token is treated as a padding index 37 | * dim=1 should be one less than the num_embeddings 38 | device: str, optional (default="cuda:0") 39 | Keep embeddings on this device 40 | """ 41 | 42 | def __init__(self, num_embeddings, embedding_dim, padding_idx=None, 43 | max_norm=None, norm_type=2, scale_grad_by_freq=False, 44 | sparse=False, reduction=True, pretrained_weights=None, 45 | device="cuda:0"): 46 | 
super(Embedding, self).__init__() 47 | self.num_embeddings = num_embeddings 48 | self.embedding_dim = embedding_dim 49 | self.padding_idx = padding_idx 50 | self.max_norm = max_norm 51 | self.norm_type = norm_type 52 | self.scale_grad_by_freq = scale_grad_by_freq 53 | self.weight = Parameter(torch.Tensor(num_embeddings, embedding_dim)) 54 | self.sparse = sparse 55 | self.reduce = self._construct_reduce(reduction) 56 | self.reduction = reduction 57 | self.device = torch.device(device) 58 | self.reset_parameters() 59 | if pretrained_weights is not None: 60 | self.from_pretrained(pretrained_weights) 61 | 62 | def _construct_reduce(self, reduction): 63 | if reduction is None: 64 | return self._reduce 65 | elif reduction == 'sum': 66 | return self._reduce_sum 67 | elif reduction == 'mean': 68 | return self._reduce_mean 69 | else: 70 | return NotImplementedError(f"Unknown reduction: {reduction}") 71 | 72 | def reset_parameters(self): 73 | """ 74 | Reset weights 75 | """ 76 | torch.nn.init.xavier_uniform_( 77 | self.weight.data, gain=torch.nn.init.calculate_gain('relu')) 78 | if self.padding_idx is not None: 79 | self.weight.data[self.padding_idx].fill_(0) 80 | 81 | def to(self): 82 | super().to(self.device) 83 | 84 | def _reduce_sum(self, x, w): 85 | if w is None: 86 | return torch.sum(x, dim=1) 87 | else: 88 | return torch.sum(x * w.unsqueeze(2), dim=1) 89 | 90 | def _reduce_mean(self, x, w): 91 | if w is None: 92 | return torch.mean(x, dim=1) 93 | else: 94 | return torch.mean(x * w.unsqueeze(2), dim=1) 95 | 96 | def _reduce(self, x, *args): 97 | return x 98 | 99 | def forward(self, x, w=None): 100 | """ 101 | Forward pass for embedding layer 102 | 103 | Arguments: 104 | --------- 105 | x: torch.LongTensor 106 | indices of tokens in a batch 107 | (batch_size, max_features_in_a_batch) 108 | w: torch.Tensor or None, optional (default=None) 109 | weights of tokens in a batch 110 | (batch_size, max_features_in_a_batch) 111 | 112 | Returns: 113 | -------- 114 | out: torch.Tensor 115 | embedding for each sample 116 | Shape: (batch_size, seq_len, embedding_dims), if reduction is None 117 | Shape: (batch_size, embedding_dims), otherwise 118 | """ 119 | x = F.embedding( 120 | x, self.weight, 121 | self.padding_idx, self.max_norm, self.norm_type, 122 | self.scale_grad_by_freq, self.sparse) 123 | return self.reduce(x, w) 124 | 125 | def from_pretrained(self, embeddings): 126 | # first index is treated as padding index 127 | assert embeddings.shape[0] == self.num_embeddings-1, \ 128 | "Shapes doesn't match for pre-trained embeddings" 129 | self.weight.data[1:, :] = torch.from_numpy(embeddings) 130 | 131 | def get_weights(self): 132 | return self.weight.detach().cpu().numpy()[1:, :] 133 | 134 | def __repr__(self): 135 | s = '{name}({num_embeddings}, {embedding_dim}, {device}' 136 | s += ', reduction={reduction}' 137 | if self.padding_idx is not None: 138 | s += ', padding_idx={padding_idx}' 139 | if self.max_norm is not None: 140 | s += ', max_norm={max_norm}' 141 | if self.norm_type != 2: 142 | s += ', norm_type={norm_type}' 143 | if self.scale_grad_by_freq is not False: 144 | s += ', scale_grad_by_freq={scale_grad_by_freq}' 145 | if self.sparse is not False: 146 | s += ', sparse=True' 147 | s += ')' 148 | return s.format(name=self.__class__.__name__, **self.__dict__) 149 | -------------------------------------------------------------------------------- /deepxml/models/linear_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as 
nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from torch.nn.parameter import Parameter 6 | import math 7 | 8 | __author__ = 'KD' 9 | 10 | 11 | class Linear(nn.Module): 12 | """Linear layer 13 | Parameters: 14 | ----------- 15 | input_size: int 16 | input size of transformation 17 | output_size: int 18 | output size of transformation 19 | bias: boolean, default=True 20 | whether to use bias or not 21 | device: str, default="cuda:0" 22 | keep on this device 23 | """ 24 | 25 | def __init__(self, input_size, output_size, 26 | bias=True, device="cuda:0"): 27 | super(Linear, self).__init__() 28 | self.device = device # Useful in case of multiple GPUs 29 | self.input_size = input_size 30 | self.output_size = output_size 31 | self.weight = Parameter( 32 | torch.Tensor(self.output_size, self.input_size)) 33 | if bias: 34 | self.bias = Parameter(torch.Tensor(self.output_size, 1)) 35 | else: 36 | self.register_parameter('bias', None) 37 | self.reset_parameters() 38 | 39 | def forward(self, input): 40 | if self.bias is not None: 41 | return F.linear( 42 | input.to(self.device), self.weight, self.bias.view(-1)) 43 | else: 44 | return F.linear( 45 | input.to(self.device), self.weight) 46 | 47 | def to(self): 48 | """Transfer to device 49 | """ 50 | super().to(self.device) 51 | 52 | def reset_parameters(self): 53 | """Initialize vectors 54 | """ 55 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 56 | if self.bias is not None: 57 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) 58 | bound = 1 / math.sqrt(fan_in) 59 | torch.nn.init.uniform_(self.bias, -bound, bound) 60 | # stdv = 1. / math.sqrt(self.weight.size(1)) 61 | # self.weight.data.uniform_(-stdv, stdv) 62 | # if self.bias is not None: 63 | # self.bias.data.uniform_(-stdv, stdv) 64 | 65 | def get_weights(self): 66 | """Get weights as numpy array 67 | Bias is appended in the end 68 | """ 69 | _wts = self.weight.detach().cpu().numpy() 70 | if self.bias is not None: 71 | _bias = self.bias.detach().cpu().numpy() 72 | _wts = np.hstack([_wts, _bias]) 73 | return _wts 74 | 75 | def __repr__(self): 76 | s = '{name}({input_size}, {output_size}, {device}' 77 | if self.bias is not None: 78 | s += ', bias=True' 79 | s += ')' 80 | return s.format(name=self.__class__.__name__, **self.__dict__) 81 | 82 | @property 83 | def sparse(self): 84 | return False 85 | 86 | 87 | class SparseLinear(Linear): 88 | """Sparse Linear linear with sparse gradients 89 | Parameters: 90 | ----------- 91 | input_size: int 92 | input size of transformation 93 | output_size: int 94 | output size of transformation 95 | padding_idx: int 96 | index for dummy label; embedding is not updated 97 | bias: boolean, default=True 98 | whether to use bias or not 99 | device: str, default="cuda:0" 100 | keep on this device 101 | """ 102 | 103 | def __init__(self, input_size, output_size, padding_idx=None, 104 | bias=True, device="cuda:0"): 105 | self.padding_idx = padding_idx 106 | super(SparseLinear, self).__init__( 107 | input_size=input_size, 108 | output_size=output_size, 109 | bias=bias, 110 | device=device) 111 | 112 | def forward(self, embed, shortlist): 113 | """Forward pass for Linear sparse layer 114 | Parameters: 115 | ---------- 116 | embed: torch.Tensor 117 | input to the layer 118 | shortlist: torch.LongTensor 119 | evaluate these labels only 120 | 121 | Returns 122 | ------- 123 | out: torch.Tensor 124 | logits for each label in provided shortlist 125 | """ 126 | embed = embed.to(self.device) 127 | shortlist = 
shortlist.to(self.device) 128 | short_weights = F.embedding(shortlist, 129 | self.weight, 130 | sparse=self.sparse, 131 | padding_idx=self.padding_idx) 132 | out = torch.matmul(embed.unsqueeze(1), short_weights.permute(0, 2, 1)) 133 | if self.bias is not None: 134 | short_bias = F.embedding(shortlist, 135 | self.bias, 136 | sparse=self.sparse, 137 | padding_idx=self.padding_idx) 138 | out = out + short_bias.permute(0, 2, 1) 139 | return out.squeeze() 140 | 141 | def reset_parameters(self): 142 | """Initialize weights vectors 143 | """ 144 | super().reset_parameters() 145 | if self.padding_idx is not None: 146 | self.weight.data[self.padding_idx].fill_(0) 147 | 148 | def __repr__(self): 149 | s = '{name}({input_size}, {output_size}, {device}' 150 | if self.bias is not None: 151 | s += ', bias=True' 152 | if self.padding_idx is not None: 153 | s += ', padding_idx={padding_idx}' 154 | s += ', sparse=True)' 155 | return s.format(name=self.__class__.__name__, **self.__dict__) 156 | 157 | def get_weights(self): 158 | """Get weights as numpy array 159 | Bias is appended in the end 160 | """ 161 | _wts = self.weight.detach().cpu().numpy() 162 | if self.padding_idx is not None: 163 | _wts = _wts[:-1, :] 164 | if (self.bias is not None): 165 | _bias = self.bias.detach().cpu().numpy() 166 | if self.padding_idx is not None: 167 | _bias = _bias[:-1, :] 168 | _wts = np.hstack([_wts, _bias]) 169 | return _wts 170 | 171 | @property 172 | def sparse(self): 173 | return True 174 | 175 | 176 | class ParallelLinear(nn.Module): 177 | """Distributed Linear layer with support for multiple devices 178 | Parameters: 179 | ----------- 180 | input_size: int 181 | input size of transformation 182 | output_size: int 183 | output size of transformation 184 | bias: boolean, default=True 185 | whether to use bias or not 186 | num_partitions: int, default=2 187 | partition classifier in these many partitions 188 | device: list or None, default=None 189 | devices for each partition; keep "cuda:0" for everyone if None 190 | """ 191 | 192 | def __init__(self, input_size, output_size, bias=True, 193 | num_partitions=2, devices=None): 194 | super(ParallelLinear, self).__init__() 195 | self.input_size = input_size 196 | self.output_size = output_size 197 | self.devices = devices 198 | self.bias = bias 199 | if devices is None: # Keep everything on cuda:0 200 | self.devices = ["cuda:{}".format(idx) for idx in num_partitions] 201 | self.num_partitions = num_partitions 202 | self.classifier = self._construct() 203 | 204 | def _construct(self): 205 | self._output_sizes = [item.size for item in np.array_split( 206 | np.arange(self.output_size), self.num_partitions)] 207 | clf = nn.ModuleList() 208 | 209 | # Input size is same for everyone 210 | for out in zip(self._output_sizes, self.bias, self.devices): 211 | clf.append(Linear(self.input_size, *out)) 212 | return clf 213 | 214 | def forward(self, embed): 215 | """Forward pass 216 | Arguments: 217 | ----------- 218 | embed: torch.Tensor 219 | input to the layer 220 | 221 | Returns: 222 | -------- 223 | out: list 224 | logits for each partition 225 | """ 226 | out = [] # Sequential for now 227 | for idx in range(self.num_partitions): 228 | out.append(self.classifier[idx](embed)) 229 | return out 230 | 231 | def to(self): 232 | """ Transfer to device 233 | """ 234 | for item in self.classifier: 235 | item.to() 236 | 237 | def get_weights(self): 238 | """Get weights as numpy array 239 | Bias is appended in the end 240 | """ 241 | out = [item.get_weights() for item in self.classifier] 
242 | return np.vstack(out) 243 | 244 | 245 | class ParallelSparseLinear(ParallelLinear): 246 | """Distributed Linear layer with support for multiple devices 247 | Parameters: 248 | ----------- 249 | input_size: int 250 | input size of transformation 251 | output_size: int 252 | output size of transformation 253 | padding_idx: int or None, default=None 254 | padding index in classifier 255 | bias: boolean, default=True 256 | whether to use bias or not 257 | num_partitions: int, default=2 258 | partition classifier in these many partitions 259 | device: list or None, default=None 260 | devices for each partition; keep "cuda:0" for everyone if None 261 | """ 262 | 263 | def __init__(self, input_size, output_size, padding_idx=None, 264 | bias=True, num_partitions=2, devices=None): 265 | self.padding_idx = padding_idx 266 | super(ParallelSparseLinear, self).__init__( 267 | input_size=input_size, 268 | output_size=output_size, 269 | bias=bias, 270 | num_partitions=num_partitions, 271 | devices=devices) 272 | 273 | def _construct(self): 274 | self._output_sizes = [item.size for item in np.array_split( 275 | np.arange(self.output_size), self.num_partitions)] 276 | clf = nn.ModuleList() 277 | for out in zip(self._output_sizes, self.padding_idx, self.bias, self.devices): 278 | clf.append(SparseLinear(self.input_size, *out)) 279 | return clf 280 | 281 | def forward(self, embed, shortlist): 282 | """Forward pass 283 | Arguments: 284 | ----------- 285 | embed: torch.Tensor 286 | input to the layer 287 | shortlist: [torch.LongTensor] 288 | Shortlist for each partition 289 | 290 | Returns: 291 | -------- 292 | out: list 293 | logits for each partition 294 | """ 295 | out = [] 296 | for idx in range(self.num_partitions): 297 | out.append(self.classifier[idx](embed, shortlist[idx])) 298 | return out 299 | -------------------------------------------------------------------------------- /deepxml/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | __author__ = 'KD' 6 | 7 | 8 | class MLP(nn.Module): 9 | """ 10 | A multi-layer perceptron with flexibility for non-liearity 11 | * no non-linearity after last layer 12 | * support for 2D or 3D inputs 13 | 14 | Parameters: 15 | ----------- 16 | input_size: int 17 | input size of embeddings 18 | hidden_size: int or list of ints or str (comma separated) 19 | e.g., 512: a single hidden layer with 512 neurons 20 | "512": a single hidden layer with 512 neurons 21 | "512,300": 512 -> nnl -> 300 22 | [512, 300]: 512 -> nnl -> 300 23 | dimensionality of layers in MLP 24 | nnl: str, optional, default='relu' 25 | which non-linearity to use 26 | device: str, default="cuda:0" 27 | keep on this device 28 | """ 29 | def __init__(self, input_size, hidden_size, nnl='relu', device="cuda:0"): 30 | super(MLP, self).__init__() 31 | hidden_size = self.parse_hidden_size(hidden_size) 32 | assert len(hidden_size) >= 1, "Should contain atleast 1 hidden layer" 33 | hidden_size = [input_size] + hidden_size 34 | self.device = torch.device(device) 35 | layers = [] 36 | for i, (i_s, o_s) in enumerate(zip(hidden_size[:-1], hidden_size[1:])): 37 | layers.append(nn.Linear(i_s, o_s, bias=True)) 38 | if i < len(hidden_size) - 2: 39 | layers.append(self._get_nnl(nnl)) 40 | self.transform = torch.nn.Sequential(*layers) 41 | 42 | def parse_hidden_size(self, hidden_size): 43 | if isinstance(hidden_size, int): 44 | return [hidden_size] 45 | elif isinstance(hidden_size, str): 46 | _hidden_size = [] 47 | for 
item in hidden_size.split(","): 48 | _hidden_size.append(int(item)) 49 | return _hidden_size 50 | elif isinstance(hidden_size, list): 51 | return hidden_size 52 | else: 53 | raise NotImplementedError("hidden_size must be a int, str or list") 54 | 55 | def _get_nnl(self, nnl): 56 | if nnl == 'sigmoid': 57 | return torch.nn.Sigmoid() 58 | elif nnl == 'relu': 59 | return torch.nn.ReLU() 60 | elif nnl == 'gelu': 61 | return torch.nn.GELU() 62 | elif nnl == 'tanh': 63 | return torch.nn.Tanh() 64 | else: 65 | raise NotImplementedError(f"{nnl} not implemented!") 66 | 67 | def forward(self, x): 68 | return self.transform(x) 69 | 70 | def to(self): 71 | """Transfer to device 72 | """ 73 | super().to(self.device) 74 | 75 | @property 76 | def sparse(self): 77 | return False 78 | -------------------------------------------------------------------------------- /deepxml/models/network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | import os 5 | import models.transform_layer as transform_layer 6 | import models.linear_layer as linear_layer 7 | 8 | 9 | __author__ = 'KD' 10 | 11 | 12 | def _to_device(x, device): 13 | if x is None: 14 | return None 15 | elif isinstance(x, (tuple, list)): 16 | out = [] 17 | for item in x: 18 | out.append(_to_device(item, device)) 19 | return out 20 | else: 21 | return x.to(device) 22 | 23 | 24 | class DeepXMLBase(nn.Module): 25 | """DeepXMLBase: Base class for DeepXML architecture 26 | 27 | * Identity op as classifier by default 28 | (derived class should implement it's own classifier) 29 | * embedding and classifier shall automatically transfer 30 | the vector to the appropriate device 31 | 32 | Arguments: 33 | ---------- 34 | vocabulary_dims: int 35 | number of tokens in the vocabulary 36 | embedding_dims: int 37 | size of word/token representations 38 | trans_config: list of strings 39 | configuration of the transformation layer 40 | padding_idx: int, default=0 41 | padding index in words embedding layer 42 | """ 43 | 44 | def __init__(self, config, device="cuda:0"): 45 | super(DeepXMLBase, self).__init__() 46 | self.transform = self._construct_transform(config) 47 | self.classifier = self._construct_classifier() 48 | self.device = torch.device(device) 49 | 50 | def _construct_classifier(self): 51 | return nn.Identity() 52 | 53 | def _construct_transform(self, trans_config): 54 | return transform_layer.Transform( 55 | transform_layer.get_functions(trans_config)) 56 | 57 | @property 58 | def representation_dims(self): 59 | return self._repr_dims 60 | 61 | @representation_dims.setter 62 | def representation_dims(self, dims): 63 | self._repr_dims = dims 64 | 65 | def encode(self, x): 66 | """Forward pass 67 | * Assumes features are dense if x_ind is None 68 | 69 | Arguments: 70 | ----------- 71 | x: tuple 72 | torch.FloatTensor or None 73 | (sparse features) contains weights of features as per x_ind or 74 | (dense features) contains the dense representation of a point 75 | torch.LongTensor or None 76 | contains indices of features (sparse or seqential features) 77 | 78 | Returns 79 | ------- 80 | out: logits for each label 81 | """ 82 | return self.transform( 83 | _to_device(x, self.device)) 84 | 85 | def forward(self, batch_data, *args): 86 | """Forward pass 87 | * Assumes features are dense if X_w is None 88 | * By default classifier is identity op 89 | 90 | Arguments: 91 | ----------- 92 | batch_data: dict 93 | * 'X': torch.FloatTensor 94 | feature weights for given 
indices or dense rep. 95 | * 'X_ind': torch.LongTensor 96 | feature indices (LongTensor) or None 97 | 98 | Returns 99 | ------- 100 | out: logits for each label 101 | """ 102 | return self.classifier( 103 | self.encode(batch_data['X'], batch_data['X_ind'])) 104 | 105 | def initialize(self, x): 106 | """Initialize embeddings from existing ones 107 | Parameters: 108 | ----------- 109 | word_embeddings: numpy array 110 | existing embeddings 111 | """ 112 | self.transform.initialize(x) 113 | 114 | def to(self): 115 | """Send layers to respective devices 116 | """ 117 | self.transform.to() 118 | self.classifier.to() 119 | 120 | def purge(self, fname): 121 | if os.path.isfile(fname): 122 | os.remove(fname) 123 | 124 | @property 125 | def num_params(self, ignore_fixed=False): 126 | if ignore_fixed: 127 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 128 | else: 129 | return sum(p.numel() for p in self.parameters()) 130 | 131 | @property 132 | def model_size(self): # Assumptions: 32bit floats 133 | return self.num_params * 4 / math.pow(2, 20) 134 | 135 | def __repr__(self): 136 | return f"{self.embeddings}\n(Transform): {self.transform}" 137 | 138 | 139 | class DeepXMLf(DeepXMLBase): 140 | """DeepXMLf: Network for DeepXML's architecture 141 | with fully-connected o/p layer (a.k.a 1-vs.-all in literature) 142 | 143 | Allows additional transform layer to transform features from the 144 | base class. e.g. base class can handle intermediate rep. and transform 145 | could be used to the intermediate rep. from base class 146 | """ 147 | 148 | def __init__(self, params): 149 | self.num_labels = params.num_labels 150 | self.num_clf_partitions = params.num_clf_partitions 151 | transform_config_dict = transform_layer.fetch_json( 152 | params.arch, params) 153 | trans_config_coarse = transform_config_dict['transform_coarse'] 154 | self.representation_dims = int( 155 | transform_config_dict['representation_dims']) 156 | self._bias = params.bias 157 | super(DeepXMLf, self).__init__(trans_config_coarse) 158 | if params.freeze_intermediate: 159 | print("Freezing intermediate model parameters!") 160 | for params in self.transform.parameters(): 161 | params.requires_grad = False 162 | trans_config_fine = transform_config_dict['transform_fine'] 163 | self.transform_fine = self._construct_transform( 164 | trans_config_fine) 165 | 166 | def encode_fine(self, x): 167 | """Forward pass (assumes input is coarse computation) 168 | 169 | Arguments: 170 | ----------- 171 | x: torch.FloatTensor 172 | (sparse features) contains weights of features as per x_ind or 173 | (dense features) contains the dense representation of a point 174 | 175 | Returns 176 | ------- 177 | out: torch.FloatTensor 178 | encoded x with fine encoder 179 | """ 180 | return self.transform_fine(_to_device(x, self.device)) 181 | 182 | def encode(self, x, x_ind=None, bypass_fine=False): 183 | """Forward pass 184 | * Assumes features are dense if x_ind is None 185 | 186 | Arguments: 187 | ----------- 188 | x: torch.FloatTensor 189 | (sparse features) contains weights of features as per x_ind or 190 | (dense features) contains the dense representation of a point 191 | x_ind: torch.LongTensor or None, optional (default=None) 192 | contains indices of features (sparse features) 193 | bypass_fine: boolean, optional (default=False) 194 | Return coarse features or not 195 | 196 | Returns 197 | ------- 198 | out: logits for each label 199 | """ 200 | encoding = super().encode((x, x_ind)) 201 | return encoding if bypass_fine else 
self.transform_fine(encoding) 202 | 203 | def forward(self, batch_data, bypass_coarse=False): 204 | """Forward pass 205 | * Assumes features are dense if X_w is None 206 | * By default classifier is identity op 207 | 208 | Arguments: 209 | ----------- 210 | batch_data: dict 211 | * 'X': torch.FloatTensor 212 | feature weights for given indices or dense rep. 213 | * 'X_ind': torch.LongTensor 214 | feature indices (LongTensor) or None 215 | 216 | Returns 217 | ------- 218 | out: logits for each label 219 | """ 220 | if bypass_coarse: 221 | return self.classifier( 222 | self.encode_fine(batch_data['X'])) 223 | else: 224 | return self.classifier( 225 | self.encode(batch_data['X'], batch_data['X_ind'])) 226 | 227 | def _construct_classifier(self): 228 | if self.num_clf_partitions > 1: # Run the distributed version 229 | _bias = [self._bias for _ in range(self.num_clf_partitions)] 230 | _clf_devices = ["cuda:{}".format( 231 | idx) for idx in range(self.num_clf_partitions)] 232 | return linear_layer.ParallelLinear( 233 | input_size=self.representation_dims, 234 | output_size=self.num_labels, 235 | bias=_bias, 236 | num_partitions=self.num_clf_partitions, 237 | devices=_clf_devices) 238 | else: 239 | return linear_layer.Linear( 240 | input_size=self.representation_dims, 241 | output_size=self.num_labels, # last one is padding index 242 | bias=self._bias 243 | ) 244 | 245 | def get_token_embeddings(self): 246 | return self.transform.get_token_embeddings() 247 | 248 | def save_intermediate_model(self, fname): 249 | torch.save(self.transform.state_dict(), fname) 250 | 251 | def load_intermediate_model(self, fname): 252 | self.transform.load_state_dict(torch.load(fname)) 253 | 254 | def to(self): 255 | """Send layers to respective devices 256 | """ 257 | self.transform_fine.to() 258 | super().to() 259 | 260 | def initialize_classifier(self, weight, bias=None): 261 | """Initialize classifier from existing weights 262 | 263 | Arguments: 264 | ----------- 265 | weight: numpy.ndarray 266 | bias: numpy.ndarray or None, optional (default=None) 267 | """ 268 | self.classifier.weight.data.copy_(torch.from_numpy(weight)) 269 | if bias is not None: 270 | self.classifier.bias.data.copy_( 271 | torch.from_numpy(bias).view(-1, 1)) 272 | 273 | def get_clf_weights(self): 274 | """Get classifier weights 275 | """ 276 | return self.classifier.get_weights() 277 | 278 | def __repr__(self): 279 | s = f"{self.transform}\n" 280 | s += f"(Transform fine): {self.transform_fine}" 281 | s += f"\n(Classifier): {self.classifier}\n" 282 | return s 283 | 284 | 285 | class DeepXMLs(DeepXMLBase): 286 | """DeepXMLt: DeepXML architecture to be trained with 287 | a label shortlist 288 | * Allows additional transform layer for features 289 | """ 290 | 291 | def __init__(self, params): 292 | self.num_labels = params.num_labels 293 | self.num_clf_partitions = params.num_clf_partitions 294 | self.label_padding_index = params.label_padding_index 295 | transform_config_dict = transform_layer.fetch_json( 296 | params.arch, params) 297 | trans_config_coarse = transform_config_dict['transform_coarse'] 298 | self.representation_dims = int( 299 | transform_config_dict['representation_dims']) 300 | self._bias = params.bias 301 | super(DeepXMLs, self).__init__(trans_config_coarse) 302 | if params.freeze_intermediate: 303 | print("Freezing intermediate model parameters!") 304 | for params in self.transform.parameters(): 305 | params.requires_grad = False 306 | trans_config_fine = transform_config_dict['transform_fine'] 307 | self.transform_fine = 
self._construct_transform( 308 | trans_config_fine) 309 | 310 | def save_intermediate_model(self, fname): 311 | torch.save(self.transform.state_dict(), fname) 312 | 313 | def load_intermediate_model(self, fname): 314 | self.transform.load_state_dict(torch.load(fname)) 315 | 316 | def encode_fine(self, x): 317 | """Forward pass (assumes input is coarse computation) 318 | 319 | Arguments: 320 | ----------- 321 | x: torch.FloatTensor 322 | (sparse features) contains weights of features as per x_ind or 323 | (dense features) contains the dense representation of a point 324 | 325 | Returns 326 | ------- 327 | out: torch.FloatTensor 328 | encoded x with fine encoder 329 | """ 330 | return self.transform_fine(_to_device(x, self.device)) 331 | 332 | def encode(self, x, x_ind=None, bypass_fine=False): 333 | """Forward pass 334 | * Assumes features are dense if x_ind is None 335 | 336 | Arguments: 337 | ----------- 338 | x: torch.FloatTensor 339 | (sparse features) contains weights of features as per x_ind or 340 | (dense features) contains the dense representation of a point 341 | x_ind: torch.LongTensor or None, optional (default=None) 342 | contains indices of features (sparse features) 343 | bypass_fine: boolean, optional (default=False) 344 | Return coarse features or not 345 | 346 | Returns 347 | ------- 348 | out: logits for each label 349 | """ 350 | encoding = super().encode((x, x_ind)) 351 | return encoding if bypass_fine else self.transform_fine(encoding) 352 | 353 | def forward(self, batch_data, bypass_coarse=False): 354 | """Forward pass 355 | * Assumes features are dense if X_w is None 356 | * By default classifier is identity op 357 | 358 | Arguments: 359 | ----------- 360 | batch_data: dict 361 | * 'X': torch.FloatTensor 362 | feature weights for given indices or dense rep. 
363 | * 'X_ind': torch.LongTensor 364 | feature indices (LongTensor) or None 365 | 366 | Returns 367 | ------- 368 | out: logits for each label 369 | """ 370 | if bypass_coarse: 371 | return self.classifier( 372 | self.encode_fine(batch_data['X']), batch_data['Y_s']) 373 | else: 374 | return self.classifier( 375 | self.encode(batch_data['X'], batch_data['X_ind']), 376 | batch_data['Y_s']) 377 | 378 | def _construct_classifier(self): 379 | offset = 0 380 | if self.label_padding_index: 381 | offset = self.num_clf_partitions 382 | if self.num_clf_partitions > 1: # Run the distributed version 383 | # TODO: Label padding index 384 | # last one is padding index for each partition 385 | _num_labels = self.num_labels + offset 386 | _padding_idx = [None for _ in range(self.num_clf_partitions)] 387 | _bias = [self._bias for _ in range(self.num_clf_partitions)] 388 | _clf_devices = ["cuda:{}".format( 389 | idx) for idx in range(self.num_clf_partitions)] 390 | return linear_layer.ParallelSparseLinear( 391 | input_size=self.representation_dims, 392 | output_size=_num_labels, 393 | bias=_bias, 394 | padding_idx=_padding_idx, 395 | num_partitions=self.num_clf_partitions, 396 | devices=_clf_devices) 397 | else: 398 | # last one is padding index 399 | return linear_layer.SparseLinear( 400 | input_size=self.representation_dims, 401 | output_size=self.num_labels + offset, 402 | padding_idx=self.label_padding_index, 403 | bias=self._bias) 404 | 405 | def to(self): 406 | """Send layers to respective devices 407 | """ 408 | self.transform_fine.to() 409 | super().to() 410 | 411 | def initialize_classifier(self, weight, bias=None): 412 | """Initialize classifier from existing weights 413 | 414 | Arguments: 415 | ----------- 416 | weight: numpy.ndarray 417 | bias: numpy.ndarray or None, optional (default=None) 418 | """ 419 | self.classifier.weight.data.copy_(torch.from_numpy(weight)) 420 | if bias is not None: 421 | self.classifier.bias.data.copy_( 422 | torch.from_numpy(bias).view(-1, 1)) 423 | 424 | def get_clf_weights(self): 425 | """Get classifier weights 426 | """ 427 | return self.classifier.get_weights() 428 | 429 | def __repr__(self): 430 | s = f"{self.transform}\n" 431 | s += f"(Transform fine): {self.transform_fine}" 432 | s += f"\n(Classifier): {self.classifier}\n" 433 | return s 434 | -------------------------------------------------------------------------------- /deepxml/models/residual_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | __author__ = 'KD' 7 | 8 | 9 | class Residual(nn.Module): 10 | """Implementation of a Residual block 11 | Parameters: 12 | ---------- 13 | input_size: int 14 | input dimension 15 | output_size: int 16 | output dimension 17 | dropout: float 18 | dropout probability 19 | init: str, default='eye' 20 | initialization strategy 21 | """ 22 | 23 | def __init__(self, input_size, output_size, dropout, init='eye'): 24 | super(Residual, self).__init__() 25 | self.input_size = input_size 26 | self.output_size = output_size 27 | self.init = init 28 | self.dropout = dropout 29 | self.padding_size = self.output_size - self.input_size 30 | self.hidden_layer = nn.Sequential( 31 | nn.utils.spectral_norm(nn.Linear(self.input_size, self.output_size)), 32 | nn.ReLU(), 33 | nn.Dropout(self.dropout)) 34 | self.initialize(self.init) 35 | 36 | def forward(self, embed): 37 | """Forward pass for Residual 38 | Parameters: 39 | ---------- 40 | embed: torch.Tensor 41 | 
dense document embedding 42 | 43 | Returns 44 | ------- 45 | out: torch.Tensor 46 | dense document embeddings transformed via residual block 47 | """ 48 | temp = F.pad(embed, (0, self.padding_size), 'constant', 0) 49 | embed = self.hidden_layer(embed) + temp 50 | return embed 51 | 52 | def initialize(self, init_type): 53 | """Initialize units 54 | Parameters: 55 | ----------- 56 | init_type: str 57 | Initialize hidden layer with 'random' or 'eye' 58 | """ 59 | if init_type == 'random': 60 | nn.init.xavier_uniform_( 61 | self.hidden_layer[0].weight, 62 | gain=nn.init.calculate_gain('relu')) 63 | nn.init.constant_(self.hidden_layer[0].bias, 0.0) 64 | else: 65 | nn.init.eye_(self.hidden_layer[0].weight) 66 | nn.init.constant_(self.hidden_layer[0].bias, 0.0) 67 | -------------------------------------------------------------------------------- /deepxml/models/transform_layer.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch.nn as nn 3 | import models.residual_layer as residual_layer 4 | import models.astec as astec 5 | import json 6 | import models.mlp as mlp 7 | 8 | 9 | class _Identity(nn.Module): 10 | def __init__(self, *args, **kwargs): 11 | super(_Identity, self).__init__() 12 | 13 | def forward(self, x): 14 | x, _ = x 15 | return x 16 | 17 | def initialize(self, *args, **kwargs): 18 | pass 19 | 20 | 21 | class Identity(nn.Module): 22 | def __init__(self, *args, **kwargs): 23 | super(Identity, self).__init__() 24 | 25 | def forward(self, x): 26 | return x 27 | 28 | def initialize(self, *args, **kwargs): 29 | pass 30 | 31 | 32 | elements = { 33 | 'dropout': nn.Dropout, 34 | 'batchnorm1d': nn.BatchNorm1d, 35 | 'linear': nn.Linear, 36 | 'relu': nn.ReLU, 37 | 'residual': residual_layer.Residual, 38 | 'identity': Identity, 39 | '_identity': _Identity, 40 | 'astec': astec.Astec, 41 | 'mlp': mlp.MLP 42 | } 43 | 44 | 45 | class Transform(nn.Module): 46 | """ 47 | Transform document representation! 
48 | transform_string: string for sequential pipeline 49 | eg relu#,dropout#p:0.1,residual#input_size:300-output_size:300 50 | params: dictionary like object for default params 51 | eg {emb_size:300} 52 | """ 53 | 54 | def __init__(self, modules, device="cuda:0"): 55 | super(Transform, self).__init__() 56 | self.device = device 57 | if len(modules) == 1: 58 | self.transform = modules[0] 59 | else: 60 | self.transform = nn.Sequential(*modules) 61 | 62 | def forward(self, x): 63 | """ 64 | Forward pass for transform layer 65 | Args: 66 | x: torch.Tensor: document representation 67 | Returns: 68 | x: torch.Tensor: transformed document representation 69 | """ 70 | return self.transform(x) 71 | 72 | def _initialize(self, x): 73 | """Initialize parameters from existing ones 74 | Typically for word embeddings 75 | """ 76 | if isinstance(self.transform, nn.Sequential): 77 | self.transform[0].initialize(x) 78 | else: 79 | self.transform.initialize(x) 80 | 81 | def initialize(self, x): 82 | # Currently implemented for: 83 | # * initializing first module of nn.Sequential 84 | # * initializing module 85 | self._initialize(x) 86 | 87 | def to(self): 88 | super().to(self.device) 89 | 90 | def get_token_embeddings(self): 91 | return self.transform.get_token_embeddings() 92 | 93 | @property 94 | def sparse(self): 95 | try: 96 | _sparse = self.transform.sparse 97 | except AttributeError: 98 | _sparse = False 99 | return _sparse 100 | 101 | 102 | def resolve_schema_args(jfile, ARGS): 103 | arguments = re.findall(r"#ARGS\.(.+?);", jfile) 104 | for arg in arguments: 105 | replace = '#ARGS.%s;' % (arg) 106 | to = str(ARGS.__dict__[arg]) 107 | # Python True and False to json true & false 108 | if to == 'True' or to == 'False': 109 | to = to.lower() 110 | if jfile.find('\"#ARGS.%s;\"' % (arg)) != -1: 111 | replace = '\"#ARGS.%s;\"' % (arg) 112 | if isinstance(ARGS.__dict__[arg], str): 113 | to = str("\""+ARGS.__dict__[arg]+"\"") 114 | jfile = jfile.replace(replace, to) 115 | return jfile 116 | 117 | 118 | def fetch_json(file, ARGS): 119 | with open(file, encoding='utf-8') as f: 120 | file = ''.join(f.readlines()) 121 | schema = resolve_schema_args(file, ARGS) 122 | return json.loads(schema) 123 | 124 | 125 | def get_functions(obj, params=None): 126 | return list(map(lambda x: elements[x](**obj[x]), obj['order'])) 127 | -------------------------------------------------------------------------------- /deepxml/run_scripts/Astec.json: -------------------------------------------------------------------------------- 1 | { 2 | "representation_dims": "#ARGS.embedding_dims;", 3 | "transform_coarse": { 4 | "order": ["astec"], 5 | "astec": { 6 | "vocabulary_dims": "#ARGS.vocabulary_dims;", 7 | "embedding_dims": "#ARGS.embedding_dims;", 8 | "freeze": "#ARGS.freeze_intermediate;", 9 | "dropout": 0.5 10 | } 11 | }, 12 | "transform_fine": { 13 | "order": ["residual"], 14 | "residual": { 15 | "input_size": "#ARGS.embedding_dims;", 16 | "output_size": "#ARGS.embedding_dims;", 17 | "dropout": 0.5, 18 | "init": "eye" 19 | } 20 | } 21 | } -------------------------------------------------------------------------------- /deepxml/run_scripts/Identity.json: -------------------------------------------------------------------------------- 1 | { 2 | "representation_dims": "#ARGS.embedding_dims;", 3 | "transform_coarse": { 4 | "order": ["_identity"], 5 | "_identity": {} 6 | }, 7 | "transform_fine": { 8 | "order": ["batchnorm1d"], 9 | "batchnorm1d": { 10 | "num_features": 300 11 | } 12 | } 13 | } 
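The schemas above (e.g. `Astec.json`, `Identity.json`) are plain JSON with `#ARGS.<name>;` placeholders: `resolve_schema_args` substitutes each placeholder with the corresponding command-line argument, and `get_functions` instantiates every entry listed under `order` from the `elements` registry in `transform_layer.py`. A minimal sketch of that flow (not part of the repository; it assumes the script is run from the `deepxml` directory and that the `Namespace` values stand in for the real command-line arguments):

```python
from argparse import Namespace

# fetch_json, get_functions and Transform are defined in models/transform_layer.py above
from models.transform_layer import fetch_json, get_functions, Transform

# Hypothetical arguments; every "#ARGS.embedding_dims;" token in the JSON
# is replaced with this value before json.loads() is called.
args = Namespace(embedding_dims=300)

schema = fetch_json("run_scripts/Identity.json", args)

# Each name in schema['transform_fine']['order'] is looked up in `elements`
# and constructed with its keyword arguments taken from the JSON.
modules = get_functions(schema['transform_fine'])
transform_fine = Transform(modules, device="cpu")
print(transform_fine)  # wraps a BatchNorm1d(300, ...) module
```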
-------------------------------------------------------------------------------- /deepxml/run_scripts/RNN.json: -------------------------------------------------------------------------------- 1 | { 2 | "transform_coarse": { 3 | "order": ["astecpp", "rnn", "dropout"], 4 | "astecpp": { 5 | "vocabulary_dims": "#ARGS.vocabulary_dims;", 6 | "embedding_dims": "#ARGS.embedding_dims;", 7 | "dropout": 0.2, 8 | "reduction": null 9 | }, 10 | "rnn": { 11 | "cell_type": "GRU", 12 | "input_size": "#ARGS.embedding_dims;", 13 | "hidden_size": 256, 14 | "dropout": 0.2, 15 | "num_layers": 1, 16 | "bidirectional": true 17 | }, 18 | "dropout": { 19 | "p": 0.3 20 | } 21 | }, 22 | "transform_fine": { 23 | "order": ["identity"], 24 | "identity": {} 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /deepxml/run_scripts/run_datasets.sh: -------------------------------------------------------------------------------- 1 | #./run_main.sh 0 DeepXML WikiSeeAlsoTitles-350K 10 2 | #./run_main.sh 0 DeepXML WikiTitles-500K 10 3 | ./run_main.sh 0 DeepXML EURLex-4K 10 4 | -------------------------------------------------------------------------------- /deepxml/run_scripts/run_main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # $1 GPU Device ID 3 | # $2 Model Type (DeepXML/DeepXML-OVA etc.) 4 | # $3 Dataset 5 | # $4 version 6 | # $5 seed 7 | # eg. ./run_main.sh 0 DeepXML EURLex-4K 0 22 8 | # eg. ./run_main.sh 0 DeepXML-fr EURLex-4K 0 22 9 | 10 | export CUDA_VISIBLE_DEVICES=$1 11 | model_type=$2 12 | dataset=$3 13 | version=$4 14 | seed=$5 15 | 16 | work_dir=$(cd ../../../../ && pwd) 17 | 18 | current_working_dir=$(pwd) 19 | python3 ../runner.py "${model_type}" "${work_dir}" ${version} "$(dirname "$current_working_dir")/configs/${model_type}/${dataset}.json" "${seed}" 20 | -------------------------------------------------------------------------------- /deepxml/tools/convert_format.pl: -------------------------------------------------------------------------------- 1 | my $inpfile; 2 | open($inpfile,"<",$ARGV[0]); 3 | 4 | my $ftfile; 5 | open($ftfile,">",$ARGV[1]); 6 | 7 | my $lblfile; 8 | open($lblfile,">",$ARGV[2]); 9 | 10 | my $ctr = 0; 11 | while(<$inpfile>) 12 | { 13 | chomp($_); 14 | 15 | if($ctr==0) 16 | { 17 | my @items = split(" ",$_); 18 | $num_inst = $items[0]; 19 | $num_ft = $items[1]; 20 | $num_lbl = $items[2]; 21 | 22 | print $ftfile "$num_inst $num_ft\n"; 23 | print $lblfile "$num_inst $num_lbl\n"; 24 | } 25 | else 26 | { 27 | my @items = split(" ",$_,2); 28 | 29 | if($_ =~ /^ .*/) 30 | { 31 | print $lblfile "\n"; 32 | print $ftfile $items[0]."\n"; 33 | } 34 | else 35 | { 36 | my @lbls = split(",",$items[0]); 37 | print $lblfile join(" ",map {"$_:1"} @lbls)."\n"; 38 | print $ftfile $items[1]."\n"; 39 | } 40 | } 41 | 42 | $ctr++; 43 | } 44 | 45 | close($inpfile); 46 | close($ftfile); 47 | close($lblfile); 48 | -------------------------------------------------------------------------------- /deepxml/tools/evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import xclib.evaluation.xc_metrics as xc_metrics 3 | import xclib.data.data_utils as data_utils 4 | from scipy.sparse import load_npz, save_npz 5 | import numpy as np 6 | import os 7 | from xclib.utils.sparse import sigmoid, normalize, retain_topk 8 | 9 | 10 | def get_filter_map(fname): 11 | if fname is not None: 12 | mapping = np.loadtxt(fname).astype(np.int) 13 | if mapping.size != 0: 14 | 
return mapping 15 | return None 16 | 17 | 18 | def filter_predictions(pred, mapping): 19 | if mapping is not None and len(mapping) > 0: 20 | print("Filtering labels.") 21 | pred[mapping[:, 0], mapping[:, 1]] = 0 22 | pred.eliminate_zeros() 23 | return pred 24 | 25 | 26 | def main(tst_label_fname, trn_label_fname, filter_fname, pred_fname, 27 | A, B, betas, top_k, save): 28 | true_labels = data_utils.read_sparse_file(tst_label_fname) 29 | trn_labels = data_utils.read_sparse_file(trn_label_fname) 30 | inv_propen = xc_metrics.compute_inv_propesity(trn_labels, A, B) 31 | mapping = get_filter_map(filter_fname) 32 | acc = xc_metrics.Metrics(true_labels, inv_psp=inv_propen) 33 | root = os.path.dirname(pred_fname) 34 | ans = "" 35 | if isinstance(betas, list) and betas[0] != -1: 36 | knn = filter_predictions( 37 | load_npz(pred_fname+'_knn.npz'), mapping) 38 | clf = filter_predictions( 39 | load_npz(pred_fname+'_clf.npz'), mapping) 40 | args = acc.eval(clf, 5) 41 | ans = f"classifier\n{xc_metrics.format(*args)}" 42 | args = acc.eval(knn, 5) 43 | ans = ans + f"\nshortlist\n{xc_metrics.format(*args)}" 44 | clf = retain_topk(clf, k=top_k) 45 | knn = retain_topk(knn, k=top_k) 46 | clf = normalize(sigmoid(clf), norm='max') 47 | knn = normalize(sigmoid(knn), norm='max') 48 | for beta in betas: 49 | predicted_labels = beta*clf + (1-beta)*knn 50 | args = acc.eval(predicted_labels, 5) 51 | ans = ans + f"\nbeta: {beta:.2f}\n{xc_metrics.format(*args)}" 52 | if save: 53 | fname = os.path.join(root, f"score_{beta:.2f}.npz") 54 | save_npz(fname, retain_topk(predicted_labels, k=top_k), 55 | compressed=False) 56 | else: 57 | predicted_labels = filter_predictions( 58 | sigmoid(load_npz(pred_fname+'.npz')), mapping) 59 | args = acc.eval(predicted_labels, 5) 60 | ans = xc_metrics.format(*args) 61 | if save: 62 | print("Saving predictions..") 63 | fname = os.path.join(root, "score.npz") 64 | save_npz(fname, retain_topk(predicted_labels, k=top_k), 65 | compressed=False) 66 | line = "-"*30 67 | print(f"\n{line}\n{ans}\n{line}") 68 | return ans 69 | 70 | 71 | if __name__ == '__main__': 72 | trn_label_file = sys.argv[1] 73 | targets_file = sys.argv[2] 74 | filter_map = sys.argv[3] 75 | pred_fname = sys.argv[4] 76 | A = float(sys.argv[5]) 77 | B = float(sys.argv[6]) 78 | save = int(sys.argv[7]) 79 | top_k = int(sys.argv[8]) 80 | betas = list(map(float, sys.argv[9:])) 81 | main(targets_file, trn_label_file, filter_map, pred_fname, A, B, betas, top_k, save) 82 | -------------------------------------------------------------------------------- /deepxml/tools/evaluate_ensemble.py: -------------------------------------------------------------------------------- 1 | from scipy.sparse import load_npz 2 | from functools import reduce 3 | import sys 4 | import xclib.evaluation.xc_metrics as xc_metrics 5 | import xclib.data.data_utils as data_utils 6 | from scipy.sparse import load_npz, save_npz 7 | import numpy as np 8 | import os 9 | 10 | 11 | def read_files(fnames): 12 | output = [] 13 | for fname in fnames: 14 | output.append(load_npz(fname)) 15 | return output 16 | 17 | 18 | def merge(predictions): 19 | return reduce(lambda a, b: a+b, predictions) 20 | 21 | 22 | def main(tst_label_fname, trn_label_fname, pred_fname, 23 | A, B, save, *args, **kwargs): 24 | true_labels = data_utils.read_sparse_file(tst_label_fname) 25 | trn_labels = data_utils.read_sparse_file(trn_label_fname) 26 | inv_propen = xc_metrics.compute_inv_propesity(trn_labels, A, B) 27 | acc = xc_metrics.Metrics(true_labels, inv_psp=inv_propen) 28 | root = 
os.path.dirname(pred_fname[-1]) 29 | predicted_labels = read_files(pred_fname) 30 | ens_predicted_labels = merge(predicted_labels) 31 | ans = "" 32 | for idx, pred in enumerate(predicted_labels): 33 | args = acc.eval(pred, 5) 34 | ans = ans + f"learner: {idx}\n{xc_metrics.format(*args)}\n" 35 | args = acc.eval(ens_predicted_labels, 5) 36 | ans = ans + f"Ensemble\n{xc_metrics.format(*args)}" 37 | if save: 38 | print("Saving predictions..") 39 | fname = os.path.join(root, "score.npz") 40 | save_npz(fname, ens_predicted_labels, compressed=False) 41 | line = "-"*30 42 | print(f"\n{line}\n{ans}\n{line}") 43 | return ans 44 | 45 | 46 | if __name__ == '__main__': 47 | trn_label_fname = sys.argv[1] 48 | tst_label_fname = sys.argv[2] 49 | pred_fname = sys.argv[3].rstrip(",").split(",") 50 | A = float(sys.argv[4]) 51 | B = float(sys.argv[5]) 52 | save = int(sys.argv[6]) 53 | main(tst_label_fname, trn_label_fname, pred_fname, A, B, save) 54 | -------------------------------------------------------------------------------- /deepxml/tools/surrogate_mapping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from xclib.utils.sparse import binarize, normalize, compute_centroid 3 | import functools 4 | import xclib.data.data_utils as data_utils 5 | from xclib.utils.graph import RandomWalk 6 | from xclib.utils.clustering import cluster_balance 7 | from xclib.utils.clustering import b_kmeans_sparse, b_kmeans_dense 8 | import os 9 | import json 10 | 11 | 12 | def compute_correlation(Y, walk_to=50, p_reset=0.2, k=10): 13 | rw = RandomWalk(Y) 14 | return rw.simulate(walk_to=walk_to, p_reset=p_reset, k=k) 15 | 16 | 17 | class SurrogateMapping(object): 18 | """ 19 | Generate mapping of labels for surrogate task 20 | 21 | Arguments: 22 | ---------- 23 | method: int, optional (default: 0) 24 | - 0 none (use extreme task) 25 | - 1 cluster labels & treat clusters as new labels 26 | - 2 pick topk labels based on given label frequency 27 | - 3 pick topk labels 28 | 29 | threshold: int, optional (default: 65536) 30 | - method 0: none 31 | - method 1: number of clusters 32 | - method 2: label frequency; pick labels more with 33 | frequency more than given value 34 | - method 3: #labels to pick 35 | """ 36 | def __init__(self, method=0, threshold=65536, feature_type='sparse'): 37 | self.feature_type = feature_type 38 | self.method = method 39 | self.threshold = threshold 40 | 41 | def map_on_cluster(self, features, labels): 42 | label_centroids = compute_centroid(features, labels) 43 | cooc = normalize(compute_correlation(labels), norm="l1") 44 | if self.feature_type == 'sparse': 45 | freq = labels.getnnz(axis = 0) 46 | if freq.max() > 5000: 47 | print("Correlation matrix is too dense. 
Skipping..") 48 | else: 49 | label_centroids = cooc.dot(label_centroids) 50 | splitter=functools.partial(b_kmeans_sparse) 51 | elif self.feature_type == 'dense': 52 | label_centroids = cooc.dot(label_centroids) 53 | splitter=functools.partial(b_kmeans_dense) 54 | else: 55 | raise NotImplementedError("Unknown feature type!") 56 | _, self.mapping = cluster_balance( 57 | X=label_centroids, 58 | clusters=[np.asarray(np.arange(labels.shape[1]), dtype=np.int64)], 59 | num_clusters=self.threshold, 60 | splitter=splitter) 61 | self.num_surrogate_labels = self.threshold 62 | 63 | def map_on_frequency(self, labels): 64 | raise NotImplementedError("") 65 | 66 | def map_on_topk(self, labels): 67 | raise NotImplementedError("") 68 | 69 | def remove_documents_wo_features(self, features, labels): 70 | if isinstance(features, np.ndarray): 71 | features = np.power(features, 2) 72 | else: 73 | features = features.power(2) 74 | freq = np.array(features.sum(axis=1)).ravel() 75 | indices = np.where(freq > 0)[0] 76 | features = features[indices] 77 | labels = labels[indices] 78 | return features, labels 79 | 80 | def map_none(self): 81 | self.num_surrogate_labels = len(self.valid_labels) 82 | self.mapping = self.valid_labels 83 | 84 | def gen_mapping(self, features, labels): 85 | # Assumes invalid labels are already removed 86 | if self.method == 0: 87 | self.map_none() 88 | elif self.method == 1: 89 | self.map_on_cluster(features, labels) 90 | elif self.method == 2: 91 | self.map_on_frequency(labels) 92 | elif self.method == 3: 93 | self.map_on_topk(labels) 94 | else: 95 | pass 96 | 97 | def get_valid_labels(self, labels): 98 | freq = np.array(labels.sum(axis=0)).ravel() 99 | ind = np.where(freq > 0)[0] 100 | return labels[:, ind], ind 101 | 102 | def fit(self, features, labels): 103 | self.num_labels = labels.shape[1] 104 | # Remove documents w/o any feature 105 | # these may impact the count, if not removed 106 | features, labels = self.remove_documents_wo_features(features, labels) 107 | # keep only valid labels; main code will also remove invalid labels 108 | labels, self.valid_labels = self.get_valid_labels(labels) 109 | self.gen_mapping(features, labels) 110 | 111 | 112 | def run(feat_fname, lbl_fname, feature_type, method, threshold, seed, tmp_dir): 113 | np.random.seed(seed) 114 | if feature_type == 'dense': 115 | features = data_utils.read_gen_dense(feat_fname) 116 | elif feature_type == 'sparse': 117 | features = data_utils.read_gen_sparse(feat_fname) 118 | else: 119 | raise NotImplementedError() 120 | labels = data_utils.read_sparse_file(lbl_fname) 121 | assert features.shape[0] == labels.shape[0], \ 122 | "Number of instances must be same in features and labels" 123 | num_features = features.shape[1] 124 | stats_obj = {} 125 | stats_obj['threshold'] = threshold 126 | stats_obj['method'] = method 127 | 128 | sd = SurrogateMapping( 129 | method=method, threshold=threshold, feature_type=feature_type) 130 | sd.fit(features, labels) 131 | stats_obj['surrogate'] = "{},{},{}".format( 132 | num_features, sd.num_surrogate_labels, sd.num_surrogate_labels) 133 | stats_obj['extreme'] = "{},{},{}".format( 134 | num_features, sd.num_labels, len(sd.valid_labels)) 135 | 136 | json.dump(stats_obj, open( 137 | os.path.join(tmp_dir, "data_stats.json"), 'w'), indent=4) 138 | 139 | np.savetxt(os.path.join(tmp_dir, "valid_labels.txt"), 140 | sd.valid_labels, fmt='%d') 141 | np.savetxt(os.path.join(tmp_dir, "surrogate_mapping.txt"), 142 | sd.mapping, fmt='%d') 143 |
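For reference, a standalone usage sketch of the surrogate-mapping step implemented above (not part of the repository; the dataset paths and cluster count are hypothetical, and the call is assumed to be run from the `deepxml` directory with `xclib` installed):

```python
# Builds the surrogate label space by clustering label centroids (method=1)
# and writes data_stats.json, valid_labels.txt and surrogate_mapping.txt
# into tmp_dir, as done by run() above.
from tools.surrogate_mapping import run

run(
    feat_fname="data/EURLex-4K/trn_X_Xf.txt",  # hypothetical sparse BoW features
    lbl_fname="data/EURLex-4K/trn_X_Y.txt",    # hypothetical sparse label matrix
    feature_type="sparse",
    method=1,         # 1: cluster labels; clusters become surrogate labels
    threshold=8192,   # number of clusters, i.e. surrogate labels
    seed=22,
    tmp_dir="results/EURLex-4K/tmp",
)
```

The `method` and `threshold` semantics follow the `SurrogateMapping` docstring above.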
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.8.0 2 | xclib==0.96 3 | cudatoolkit==11.1.1 4 | scikit-learn==0.24.1 5 | scipy==1.6.1 6 | cython==0.29.22 7 | numba==0.51.1 8 | numexpr==2.7.3 9 | numpy==1.19.5 10 | numpy-base==1.19.2 11 | numpydoc==1.1.0 12 | nmslib==2.0.6 13 | matplotlib==3.3.4 14 | matplotlib-base==3.3.4 15 | setuptools==52.0.0 16 | tqdm==4.59.0 --------------------------------------------------------------------------------