├── docs ├── assets │ ├── images │ │ ├── favicon.ico │ │ ├── carousel1.jpg │ │ ├── carousel2.jpg │ │ ├── carousel3.jpg │ │ ├── carousel4.jpg │ │ ├── hit_logo.png │ │ ├── hit_loss.png │ │ ├── hit_logo+title.png │ │ ├── hit_illustration.png │ │ └── neurips.svg │ └── pdfs │ │ └── HiT_neurips2024_poster.pdf ├── static │ ├── js │ │ ├── index.js │ │ ├── bulma-slider.min.js │ │ └── bulma-slider.js │ └── css │ │ ├── index.css │ │ ├── bulma-carousel.min.css │ │ └── bulma-slider.min.css └── changelog.md ├── requirements.txt ├── scripts ├── training │ ├── sft │ │ ├── config_sft.yaml │ │ └── training_sft.py │ ├── static_embed │ │ ├── config_static.yaml │ │ └── training_static.py │ └── hit │ │ ├── config_hit.yaml │ │ └── training_hit.py └── evaluation │ ├── sbert │ ├── config_sbert.yaml │ └── eval_sbert.py │ └── hit │ ├── config_hit.yaml │ └── eval_hit.py ├── .github ├── dependabot.yml └── workflows │ └── python-publish.yml ├── tests ├── __init__.py ├── test_loading_hit.py ├── test_training_hit.py ├── test_metrics.py └── test_loading_dataset.py ├── src └── hierarchy_transformers │ ├── __init__.py │ ├── datasets │ ├── __init__.py │ ├── load.py │ └── construct.py │ ├── models │ ├── hierarchy_transformer │ │ ├── __init__.py │ │ ├── hit_trainer.py │ │ ├── hyperbolic.py │ │ └── hit.py │ ├── static_embed │ │ ├── __init__.py │ │ ├── poincare_embed.py │ │ └── poincare_trainer.py │ └── __init__.py │ ├── evaluation │ ├── __init__.py │ ├── metrics.py │ ├── sbert_eval.py │ ├── static_embed_eval.py │ └── hit_eval.py │ ├── losses │ ├── __init__.py │ ├── poincare_embed_loss.py │ ├── hyper_cone_loss.py │ └── hit_loss.py │ ├── utils.py │ └── plot.py ├── .gitignore ├── pyproject.toml ├── README.md └── LICENSE /docs/assets/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KRR-Oxford/HierarchyTransformers/HEAD/docs/assets/images/favicon.ico 
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sentence_transformers[train]>=3.4.0 2 | deeponto>=0.9.2 3 | geoopt>=0.5.0 4 | scipy==1.13.1 5 | seaborn 6 | -------------------------------------------------------------------------------- /docs/assets/images/carousel1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KRR-Oxford/HierarchyTransformers/HEAD/docs/assets/images/carousel1.jpg -------------------------------------------------------------------------------- /docs/assets/images/carousel2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KRR-Oxford/HierarchyTransformers/HEAD/docs/assets/images/carousel2.jpg -------------------------------------------------------------------------------- /docs/assets/images/carousel3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KRR-Oxford/HierarchyTransformers/HEAD/docs/assets/images/carousel3.jpg -------------------------------------------------------------------------------- /docs/assets/images/carousel4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KRR-Oxford/HierarchyTransformers/HEAD/docs/assets/images/carousel4.jpg -------------------------------------------------------------------------------- /docs/assets/images/hit_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KRR-Oxford/HierarchyTransformers/HEAD/docs/assets/images/hit_logo.png -------------------------------------------------------------------------------- /docs/assets/images/hit_loss.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/KRR-Oxford/HierarchyTransformers/HEAD/docs/assets/images/hit_loss.png -------------------------------------------------------------------------------- /docs/assets/images/hit_logo+title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KRR-Oxford/HierarchyTransformers/HEAD/docs/assets/images/hit_logo+title.png -------------------------------------------------------------------------------- /docs/assets/images/hit_illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KRR-Oxford/HierarchyTransformers/HEAD/docs/assets/images/hit_illustration.png -------------------------------------------------------------------------------- /docs/assets/pdfs/HiT_neurips2024_poster.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KRR-Oxford/HierarchyTransformers/HEAD/docs/assets/pdfs/HiT_neurips2024_poster.pdf -------------------------------------------------------------------------------- /scripts/training/sft/config_sft.yaml: -------------------------------------------------------------------------------- 1 | # dataset from Hugging Face 2 | dataset_path: "Hierarchy-Transformers/WordNetNoun" 3 | dataset_name: "MixedHop-RandomNegatives" 4 | 5 | # pre-trained model from Hugging Face 6 | model_path: "sentence-transformers/all-MiniLM-L12-v2" 7 | 8 | # training config 9 | num_train_epochs: 3 10 | train_batch_size: 256 11 | eval_batch_size: 512 12 | learning_rate: 1e-5 13 | -------------------------------------------------------------------------------- /scripts/training/static_embed/config_static.yaml: -------------------------------------------------------------------------------- 1 | # local dataset downloaded from Zenodo 2 | dataset_path: "data/wordnet-multi" 3 | negative_type: "random" 4 | 5 | # training config (on hyperbolic distance 
// Opt out of the Video.js "help improve" prompt.
window.HELP_IMPROVE_VIDEOJS = false;


// Once the DOM is ready, attach the Bulma carousel and slider widgets.
$(document).ready(function() {
    // Carousel behaviour: one slide visible, advance one slide at a time,
    // wrap around endlessly and auto-advance every 5000 ms.
    var options = {
        slidesToScroll: 1,
        slidesToShow: 1,
        loop: true,
        infinite: true,
        autoplay: true,
        autoplaySpeed: 5000,
    };

    // Initialize all div with carousel class. The return value was never
    // read, so it is no longer stored in an unused local.
    bulmaCarousel.attach('.carousel', options);

    // Attach bulma-slider using its default selector.
    bulmaSlider.attach();

});
3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /scripts/evaluation/sbert/config_sbert.yaml: -------------------------------------------------------------------------------- 1 | # Mixed-hop Prediction on WordNet's random negatives 2 | dataset_path: "Hierarchy-Transformers/WordNetNoun" 3 | dataset_name: "MixedHop-RandomNegatives" 4 | model_path: "sentence-transformers/all-MiniLM-L12-v2" 5 | 6 | # Mixed-hop Prediction on WordNet's hard negatives 7 | # dataset_path: "Hierarchy-Transformers/WordNetNoun" 8 | # dataset_name: "MixedHop-HardNegatives" 9 | # model_path: "sentence-transformers/all-MiniLM-L12-v2" 10 | 11 | # Mixed-hop Prediction evaluation on DOID's random negatives 12 | # dataset_path: "Hierarchy-Transformers/DOID" 13 | # dataset_name: "MixedHop-RandomNegatives" 14 | # model_path: "sentence-transformers/all-MiniLM-L12-v2" 15 | 16 | eval_batch_size: 512 -------------------------------------------------------------------------------- /src/hierarchy_transformers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import annotations 15 | 16 | from .models import HierarchyTransformer 17 | 18 | __all__ = [ 19 | "HierarchyTransformer", 20 | ] 21 | -------------------------------------------------------------------------------- /src/hierarchy_transformers/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from .load import load_hf_dataset, load_zenodo_dataset 17 | 18 | __all__ = [ 19 | "load_hf_dataset", 20 | "load_zenodo_dataset", 21 | ] 22 | -------------------------------------------------------------------------------- /src/hierarchy_transformers/models/hierarchy_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from .hit import HierarchyTransformer 17 | from .hit_trainer import HierarchyTransformerTrainer 18 | 19 | __all__ = [ 20 | "HierarchyTransformer", 21 | "HierarchyTransformerTrainer", 22 | ] 23 | -------------------------------------------------------------------------------- /src/hierarchy_transformers/models/static_embed/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import annotations 15 | 16 | from .poincare_embed import PoincareStaticEmbedding 17 | from .poincare_trainer import PoincareStaticEmbeddingTrainer 18 | 19 | __all__ = [ 20 | "PoincareStaticEmbedding", 21 | "PoincareStaticEmbeddingTrainer", 22 | ] 23 | -------------------------------------------------------------------------------- /scripts/evaluation/hit/config_hit.yaml: -------------------------------------------------------------------------------- 1 | # Mixed-hop Prediction on WordNet's random negatives 2 | dataset_path: "Hierarchy-Transformers/WordNetNoun" 3 | dataset_name: "MixedHop-RandomNegatives" 4 | model_path: "Hierarchy-Transformers/HiT-MiniLM-L12-WordNetNoun" 5 | revision: "v1-random-negatives" # choose the random negative version 6 | 7 | # Mixed-hop Prediction on WordNet's hard negatives 8 | # dataset_path: "Hierarchy-Transformers/WordNetNoun" 9 | # dataset_name: "MixedHop-HardNegatives" 10 | # model_path: "Hierarchy-Transformers/HiT-MiniLM-L12-WordNetNoun" 11 | # revision: "v1-hard-negatives" # choose the hard negative version 12 | 13 | # Transfer Mixed-hop Prediction evaluation on DOID's random negatives 14 | # dataset_path: "Hierarchy-Transformers/DOID" 15 | # dataset_name: "MixedHop-RandomNegatives" 16 | # model_path: "Hierarchy-Transformers/HiT-MiniLM-L12-WordNetNoun" 17 | # revision: "v1-random-negatives" # choose the random negative version 18 | 19 | eval_batch_size: 512 -------------------------------------------------------------------------------- /src/hierarchy_transformers/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | from .hit_eval import HierarchyTransformerEvaluator 17 | from .sbert_eval import SentenceTransformerEvaluator 18 | from .static_embed_eval import PoincareStaticEmbeddingEvaluator 19 | 20 | __all__ = [ 21 | "HierarchyTransformerEvaluator", 22 | "SentenceTransformerEvaluator", 23 | "PoincareStaticEmbeddingEvaluator", 24 | ] 25 | -------------------------------------------------------------------------------- /src/hierarchy_transformers/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import annotations 15 | 16 | from .hierarchy_transformer import HierarchyTransformer, HierarchyTransformerTrainer, hyperbolic 17 | from .static_embed import PoincareStaticEmbedding, PoincareStaticEmbeddingTrainer 18 | 19 | __all__ = [ 20 | "HierarchyTransformer", 21 | "HierarchyTransformerTrainer", 22 | "hyperbolic", 23 | "PoincareStaticEmbedding", 24 | "PoincareStaticEmbeddingTrainer", 25 | ] 26 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v3 25 | - name: Set up Python 26 | uses: actions/setup-python@v3 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build 33 | - name: Build package 34 | run: python -m build 35 | - name: Publish package 36 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 37 | with: 38 | user: __token__ 39 | password: ${{ secrets.PYPI_API_TOKEN }} 40 | -------------------------------------------------------------------------------- /src/hierarchy_transformers/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import annotations 15 | 16 | from .hit_loss import HierarchyTransformerLoss, HyperbolicCentripetalLoss, HyperbolicClusteringLoss 17 | from .hyper_cone_loss import ( 18 | HyperbolicEntailmentConeLoss, 19 | HyperbolicEntailmentConeStaticLoss, 20 | HyperbolicEntailmentConeTripletLoss, 21 | ) 22 | from .poincare_embed_loss import PoincareEmbeddingStaticLoss 23 | 24 | __all__ = [ 25 | "HierarchyTransformerLoss", 26 | "HyperbolicCentripetalLoss", 27 | "HyperbolicClusteringLoss", 28 | "HyperbolicEntailmentConeLoss", 29 | "HyperbolicEntailmentConeStaticLoss", 30 | "HyperbolicEntailmentConeTripletLoss", 31 | "PoincareEmbeddingStaticLoss", 32 | ] 33 | -------------------------------------------------------------------------------- /docs/changelog.md: -------------------------------------------------------------------------------- 1 | # Changelog :newspaper: 2 | 3 | 9 | 10 | ## v0.1.1 (2024-12-13) 11 | 12 | ### Added 13 | 14 | - [X] [`feature`] Add `HierarchyTransformerTrainer` that extends `SentenceTransformerTrainer` with step-wise (per batch) loss logging. 15 | 16 | ### Changed 17 | 18 | - [X] [`chore`] Refactor all the code with ruff linter. 19 | 20 | ## v0.1.0 (2024-12-11) 21 | 22 | Significant development to align with `sentence-transformers>=3.4.0.dev0`. 23 | 24 | ### Added 25 | 26 | - [X] [`feature`] Add pytest modules for testing. 27 | - [X] [`docs`] Set up [project page](https://krr-oxford.github.io/HierarchyTransformers/). 28 | - [X] [`feature`] Upload HiT datasets on [HuggingFace](https://huggingface.co/Hierarchy-Transformers). 29 | - [X] [`feature`] Re-organise models by setting `v1-random-negatives` and `v1-hard-negatives` revisions on [HuggingFace](https://huggingface.co/Hierarchy-Transformers). 30 | 31 | ### Changed 32 | 33 | - [X] [`chore`] Rewrite and reorganise `hierarchy_transformers.models`, `hierarchy_transformers.losses`, and `hierarchy_transformers.evaluation` to align with `sentence-transformers>=3.4.0.dev`. 
34 | - [X] [`chore`] Rewrite dataset processing and loading functions and reorganise everything into `hierarchy_transformers.datasets`. 35 | 36 | ### Removed 37 | 38 | - [X] [`chore`] Remove `hierarchy_transformers.models.utils`. 39 | 40 | ## v0.0.3 (2024-05-09) 41 | 42 | Initial release (should work with `sentence-transformers<3.0.0` ) and bug fix. 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | #lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | nosetests.xml 43 | coverage.xml 44 | *.cover 45 | 46 | # Translations 47 | *.mo 48 | *.pot 49 | 50 | # Django stuff: 51 | *.log 52 | 53 | # Sphinx documentation 54 | docs/_build/ 55 | 56 | # PyBuilder 57 | target/ 58 | 59 | # DotEnv configuration 60 | .env 61 | 62 | # Database 63 | *.db 64 | *.rdb 65 | 66 | # Pycharm 67 | .idea 68 | 69 | # VS Code 70 | .vscode/ 71 | 72 | # Spyder 73 | .spyproject/ 74 | 75 | # Jupyter NB Checkpoints 76 | .ipynb_checkpoints/ 77 | 78 | # exclude data from source control by default 79 | /data/ 80 | data 81 | results 82 | 83 | # Mac OS-specific storage files 84 | .DS_Store 85 | 86 | # vim 87 | *.swp 88 | *.swo 89 | 90 | # Mypy cache 91 | .mypy_cache/ 92 | 93 | # pytest 94 | .pytest* 95 | testing/ 96 | 97 | # data scripts 98 | data_scripts/*.ipynb 99 | 100 | # logmap redundancy 101 | logmap-log.out 102 | 103 | # experiments 104 | experiments/ 105 | run.mondo.sh 106 | run.umls.sh 107 | 108 | print_reqs.py 109 | 110 | # mkdocs site 111 | site 112 | 113 | # notebook for learning 114 | *.ipynb 115 | 116 | temp* -------------------------------------------------------------------------------- /src/hierarchy_transformers/models/hierarchy_transformer/hit_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class HierarchyTransformerTrainer(SentenceTransformerTrainer):
    r"""Trainer for `HierarchyTransformer` that logs the individual loss terms of every batch.

    It differs from `SentenceTransformerTrainer` only in `compute_loss`: the value produced by the parent
    implementation is indexed with the keys `"cluster_loss"`, `"centri_loss"` and `"loss"`, all three are
    reported through `self.log`, and only the `"loss"` entry is handed back to the caller.
    """

    def compute_loss(
        self,
        model: HierarchyTransformer,
        inputs: dict[str, torch.Tensor | Any],
        return_outputs: bool = False,
        num_items_in_batch=None,
    ) -> torch.Tensor | tuple[torch.Tensor, dict[str, Any]]:
        """Compute the batch loss, log its components, and return the combined loss.

        Args:
            model: The model being trained; forwarded unchanged to the parent implementation.
            inputs: The batch features; forwarded unchanged to the parent implementation.
            return_outputs: When true, the parent result is unpacked as a `(loss, outputs)` pair and the
                outputs are returned alongside the combined loss.
            num_items_in_batch: Forwarded unchanged to the parent implementation.

        Returns:
            The `"loss"` entry of the loss mapping, or a `(loss, outputs)` tuple when `return_outputs` is true.
        """
        result = super().compute_loss(
            model=model,
            inputs=inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )
        if return_outputs:
            loss_terms, outputs = result
        else:
            loss_terms, outputs = result, None

        # (name used in the log, key read from the loss mapping); each value is rounded to 4 decimals.
        logged_entries = (
            ("cluster_loss", "cluster_loss"),
            ("centri_loss", "centri_loss"),
            ("combined_loss", "loss"),
        )
        self.log({log_name: round(loss_terms[term_key].item(), 4) for log_name, term_key in logged_entries})

        combined_loss = loss_terms["loss"]
        if return_outputs:
            return combined_loss, outputs
        return combined_loss
def are_models_equal(model1: torch.nn.Module, model2: torch.nn.Module, tolerance: float = 1e-6) -> bool:
    """
    Compares two PyTorch models to check if they are the same by comparing each parameter and buffer.

    Parameters and buffers are paired in iteration order. The models are considered different as soon as
    they expose a different number of parameters (or buffers), a pair has different shapes, or a pair
    differs by more than `tolerance`.

    Args:
        model1 (torch.nn.Module): The first model to compare.
        model2 (torch.nn.Module): The second model to compare.
        tolerance (float): The absolute tolerance (`atol`) for floating-point comparison (default is 1e-6).

    Returns:
        bool: True if the models are the same, False otherwise.
    """
    tensor_groups = (
        # Model parameters
        (list(model1.parameters()), list(model2.parameters())),
        # Model buffers (e.g., running mean/variance in batch norm layers)
        (list(model1.buffers()), list(model2.buffers())),
    )
    for tensors1, tensors2 in tensor_groups:
        # `zip` stops at the shorter sequence, so a model with extra tensors would otherwise
        # be reported as equal to one that merely shares its leading tensors.
        if len(tensors1) != len(tensors2):
            return False
        for tensor1, tensor2 in zip(tensors1, tensors2):
            # Check shapes first: `torch.allclose` broadcasts compatible shapes and raises on
            # incompatible ones, neither of which is a meaningful equality answer.
            if tensor1.shape != tensor2.shape:
                return False
            if not torch.allclose(tensor1, tensor2, atol=tolerance):
                return False

    return True
def entity_norm_plot(hierarchy: Taxonomy, model: SentenceTransformer):
    """Encode every entity name of `hierarchy` and plot a histogram of the embeddings' `dist0` values.

    Args:
        hierarchy: Taxonomy whose nodes each carry a `"name"` attribute.
        model: Sentence-transformer model used to encode the entity names.

    Returns:
        A tuple `(entity_embeds, entity_norms, plot)`: the encoded embeddings, the result of
        `manifold.dist0` on them, and the seaborn histogram drawn from those values.
    """
    entity_names = [hierarchy.get_node_attributes(e)["name"] for e in hierarchy.nodes]
    # Pass batch size and progress flag by keyword: in `sentence_transformers>=3.4.0` the second and
    # third positional parameters of `encode` are `prompt_name` and `prompt`, so positional
    # `1024, True` would be taken as a prompt rather than as batch size / progress bar.
    entity_embeds = model.encode(entity_names, batch_size=1024, show_progress_bar=True, convert_to_tensor=True)
    # Curvature is the reciprocal of the first module's word embedding dimension.
    manifold = PoincareBall(c=1 / model._first_module().get_word_embedding_dimension())
    entity_norms = manifold.dist0(entity_embeds)
    return (
        entity_embeds,
        entity_norms,
        sns.histplot(entity_norms.cpu().numpy(), bins=10, kde=True, kde_kws={"bw_adjust": 2}),
    )
class PoincareEmbeddingStaticLoss(torch.nn.Module):
    """Poincare embedding loss.

    Essentially, this loss is expected to achieve:

    $$d(child, parent) < d(child, negative)$$

    Inputs are presented in `(subject, *objects)` where the first `object` is positive and the rest are negative.

    This is designed for the static embedding implementation.

    Args:
        manifold: Manifold whose `dist` method measures the distance between a subject and its objects.
    """

    def __init__(self, manifold: PoincareBall):
        super().__init__()
        self.manifold = manifold
        self.cross_entropy = torch.nn.CrossEntropyLoss()

    def forward(self, subject: torch.Tensor, objects: torch.Tensor):
        """Return the cross-entropy loss that favours the first object of every row.

        Args:
            subject: Subject embeddings; must be compatible with `objects` under `manifold.dist`.
            objects: Object embeddings, positive object first.
                NOTE(review): presumably `(batch, num_objects, dim)` with `subject` as `(batch, 1, dim)` -
                confirm against the caller.

        Returns:
            The cross-entropy between the negated distances (used as logits) and class index 0.
        """
        # first object is always the correct one
        pred_dists = self.manifold.dist(subject, objects)
        # Build the all-zero labels directly on the target device instead of creating them on the
        # CPU from a Python list and moving them afterwards.
        correct_object_indices = torch.zeros(len(pred_dists), dtype=torch.long, device=pred_dists.device)
        # Smaller distance must mean larger logit, hence the negation.
        return self.cross_entropy(-pred_dists, correct_object_indices)

    @property
    def citation(self) -> str:
        """BibTeX entry for the Poincare embeddings paper."""
        # Raw string: in a normal string literal `\'` is an escape for `'`, which silently dropped
        # the backslash and turned the accent macro `{\'e}` into `{'e}`.
        return format_citation(
            r"""
            @article{nickel2017poincare,
              title={Poincar{\'e} embeddings for learning hierarchical representations},
              author={Nickel, Maximillian and Kiela, Douwe},
              journal={Advances in neural information processing systems},
              volume={30},
              year={2017}
            }
        """
        )
5 | readme = "README.md" 6 | authors = [ 7 | {name = "Yuan He", email = "yuan.he@cs.ox.ac.uk"} 8 | ] 9 | maintainers = [ 10 | { name = "Yuan He", email = "yuan.he@cs.ox.ac.uk" } 11 | ] 12 | license = {text = "Apache License 2.0"} 13 | classifiers = [ 14 | "Programming Language :: Python :: 3.9", 15 | "Programming Language :: Python :: 3.10", 16 | "Programming Language :: Python :: 3.11", 17 | "Programming Language :: Python :: 3.12", 18 | "License :: OSI Approved :: Apache Software License", 19 | "Intended Audience :: Developers", 20 | "Intended Audience :: Science/Research", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | ] 23 | requires-python = ">=3.9" 24 | keywords = [ 25 | "Language Models", 26 | "Transformer Encoders", 27 | "Hierarchy Encoders", 28 | "Hierarchy Embedding", 29 | "Hyperbolic Embedding" 30 | ] 31 | dependencies = [ 32 | "sentence_transformers[train]>=3.4.0", 33 | "deeponto>=0.9.2", 34 | "geoopt>=0.5.0", 35 | "scipy==1.13.1", 36 | "seaborn" 37 | ] 38 | 39 | [project.urls] 40 | Homepage = "https://krr-oxford.github.io/HierarchyTransformers/" 41 | Repository = "https://github.com/KRR-Oxford/HierarchyTransformers" 42 | 43 | [project.optional-dependencies] 44 | dev = [ 45 | "pytest", 46 | "pytest-cov", 47 | "pytest-env" 48 | ] 49 | 50 | [build-system] 51 | requires = ["setuptools"] 52 | build-backend = "setuptools.build_meta" 53 | 54 | [tool.setuptools.packages.find] 55 | where = ["src"] 56 | 57 | [tool.setuptools.package-data] 58 | "*" = ["*.jar", "*.yaml", "lib/*.jar"] 59 | 60 | [tool.setuptools] 61 | package-dir = {"" = "src"} 62 | include-package-data = true 63 | 64 | [tool.ruff] 65 | line-length = 119 66 | fix = true 67 | 68 | [tool.ruff.lint] 69 | select = ["E", "F", "W", "I", "UP"] 70 | # Skip `E731` (do not assign a lambda expression, use a def) 71 | ignore = [ 72 | # LineTooLong 73 | "E501", 74 | # DoNotAssignLambda 75 | "E731" 76 | ] 77 | 78 | [tool.ruff.lint.isort] 79 | known-third-party = ["datasets"] 80 | 
required-imports = ["from __future__ import annotations"] 81 | 82 | [tool.pytest.ini_options] 83 | testpaths = ["tests"] 84 | addopts = "--strict-markers -m 'not slow'" 85 | markers = [ 86 | "slow: marks tests as slow" 87 | ] 88 | 89 | [tool.pytest_env] 90 | MODEL_PATHS="Hierarchy-Transformers/HiT-MiniLM-L12-WordNetNoun,Hierarchy-Transformers/HiT-MiniLM-L12-SnomedCT" 91 | DATASET_PATHS="Hierarchy-Transformers/WordNetNoun,Hierarchy-Transformers/SnomedCT" 92 | -------------------------------------------------------------------------------- /tests/test_loading_hit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
@pytest.fixture(params=os.getenv("MODEL_PATHS", "").split(","))
def model_path(request):
    """Provide one model path per comma-separated entry of the `MODEL_PATHS` environment variable."""
    raw_entry = request.param
    if raw_entry:
        # whitespace around an entry is not part of the path
        return raw_entry.strip()
    # an empty entry means there is nothing to load for this parameter
    pytest.fail("No valid model names found in the MODEL_PATHS environment variable")
def test_training(model_path, dataset_path):
    """Smoke-test a single training epoch of a HiT model on a small slice of triplets."""
    # All checkpoints and logs go to a throw-away directory
    with tempfile.TemporaryDirectory() as output_dir:
        # Pick one of the triplet configurations at random
        hops = random.choice(["MultiHop", "MixedHop"])
        negatives = random.choice(["HardNegatives", "RandomNegatives"])
        splits = load_dataset(dataset_path, f"{hops}-{negatives}-Triplets")

        # Keep only a few samples so the run stays short
        train_samples = splits["train"].select(range(64))
        val_samples = splits["val"].select(range(32))

        hit_model = HierarchyTransformer.from_pretrained(model_path)
        loss_fn = HierarchyTransformerLoss(model=hit_model)

        training_args = SentenceTransformerTrainingArguments(
            output_dir=output_dir,
            num_train_epochs=1,
            learning_rate=1e-5,
            per_device_train_batch_size=16,
            per_device_eval_batch_size=16,
            warmup_ratio=0.05,
            eval_strategy="epoch",
            save_strategy="epoch",
        )

        # Both the training and the validation loss are computed on triplets
        HierarchyTransformerTrainer(
            model=hit_model,
            args=training_args,
            train_dataset=train_samples,
            eval_dataset=val_samples,
            loss=loss_fn,
        ).train()
def reflect_about_subspace(point: torch.Tensor, normal: torch.Tensor) -> torch.Tensor:
    """Compute the (Euclidean) reflection of a point about a sub-space (a hyper-plane through origin).

    This is a helper function for computing hyperbolic subspace projection of a point.

    The reduction runs over the last dimension, so besides a single 1-D point and normal, batched inputs
    (e.g. `(batch_size, dim)` points with a shared `(dim,)` normal or with per-point normals) are supported.

    Args:
        point (torch.Tensor): The input point(s); the last dimension is the coordinate dimension.
        normal (torch.Tensor): The normal vector(s) of the plane (through origin), broadcastable to `point`.

    Returns:
        The reflected point(s), with the broadcast shape of `point` and `normal`.

    Raises:
        ValueError: If a normal vector is all zeros.
    """

    # A zero normal does not define a hyper-plane and would lead to a division by zero below
    if torch.any(torch.all(normal == 0, dim=-1)):
        raise ValueError("Norm vector cannot be zero.")

    # Dot product and squared magnitude along the coordinate dimension (kept for broadcasting)
    dot_product = torch.sum(point * normal, dim=-1, keepdim=True)
    normal_squared = torch.sum(normal * normal, dim=-1, keepdim=True)

    # Reflect without explicitly normalising the normal vector
    reflection = point - 2 * (dot_product / normal_squared) * normal

    return reflection
@click.command()
@click.option("-c", "--config_file", type=click.Path(exists=True))
@click.option("-o", "--output_path", type=click.Path(exists=True))
def main(config_file: str, output_path: str):
    # NOTE: deliberately no docstring here, as click would expose it as the command's help text.

    # Settings are read from the config file given on the command line
    config = CfgNode(load_file(config_file))

    # The "-Pairs" configuration of the dataset and the pre-trained sBERT model to probe
    pair_dataset = load_hf_dataset(config.dataset_path, config.dataset_name + "-Pairs")
    model = SentenceTransformer(model_name_or_path=config.model_path)

    def make_evaluator(split: str):
        # Validation and test evaluators differ only in the split they read from
        examples = pair_dataset[split]
        return SentenceTransformerEvaluator(
            child_entities=examples["child"],
            parent_entities=examples["parent"],
            labels=examples["label"],
            batch_size=config.eval_batch_size,
            truth_label=1,
        )

    # Validation run, used for hyperparameter selection
    validator = make_evaluator("val")
    validator(model=model, output_path=output_path, epoch="validation")

    # Take the threshold of the validation row with the highest F1 ...
    scores = validator.results
    top_row = scores.loc[scores["f1"].idxmax()]
    chosen_threshold = float(top_row["threshold"])

    # ... and apply it on the test split
    tester = make_evaluator("test")
    tester(model=model, output_path=output_path, best_threshold=chosen_threshold)
@click.command()
@click.option("-c", "--config_file", type=click.Path(exists=True))
@click.option("-o", "--output_path", type=click.Path(exists=True))
def main(config_file: str, output_path: str):
    # NOTE: deliberately no docstring here, as click would expose it as the command's help text.

    # Settings are read from the config file given on the command line
    config = CfgNode(load_file(config_file))

    # The "-Pairs" configuration of the dataset and the HiT model (at the configured revision)
    pair_dataset = load_hf_dataset(config.dataset_path, config.dataset_name + "-Pairs")
    model = HierarchyTransformer.from_pretrained(model_name_or_path=config.model_path, revision=config.revision)

    def make_evaluator(split: str):
        # Validation and test evaluators differ only in the split they read from
        examples = pair_dataset[split]
        return HierarchyTransformerEvaluator(
            child_entities=examples["child"],
            parent_entities=examples["parent"],
            labels=examples["label"],
            batch_size=config.eval_batch_size,
            truth_label=1,
        )

    # Validation run, used for hyperparameter selection
    validator = make_evaluator("val")
    validator(model=model, output_path=output_path, epoch="validation")

    # Take the centripetal weight and threshold of the validation row with the highest F1 ...
    scores = validator.results
    top_row = scores.loc[scores["f1"].idxmax()]
    chosen_centri_weight = float(top_row["centri_weight"])
    chosen_threshold = float(top_row["threshold"])

    # ... and apply them on the test split
    tester = make_evaluator("test")
    tester(
        model=model,
        output_path=output_path,
        best_centri_weight=chosen_centri_weight,
        best_threshold=chosen_threshold,
    )
def test_reflect_about_subspace(sample_point, normal_vector):
    """Check type, shape and values of a Euclidean reflection, and the zero-normal error."""
    reflected = reflect_about_subspace(sample_point, normal_vector)

    # The result is a tensor shaped like the input point
    assert isinstance(reflected, torch.Tensor), "Reflection should be a torch.Tensor"
    assert reflected.shape == sample_point.shape, "Reflection should have the same shape as the input point"

    # Reflecting about the plane with normal (0, 1) flips the second coordinate
    target = torch.tensor([0.3, -0.3], dtype=torch.float32)
    assert torch.allclose(reflected, target, atol=1e-6), "Reflection values do not match expected values"

    # A zero normal vector must be rejected
    zero_normal = torch.zeros_like(normal_vector)
    with pytest.raises(ValueError):
        reflect_about_subspace(sample_point, zero_normal)
torch.Tensor" 69 | 70 | # Check the shape of the projection 71 | assert projection.shape == sample_point.shape, "Projection should have the same shape as the input point" 72 | 73 | # Check that the projection lies within the Poincare Ball (norm less than 1) 74 | norm = torch.norm(projection) 75 | assert norm < 1.0, "Projected point should lie within the Poincare Ball" 76 | 77 | # Verify the projection calculation 78 | # Note: You may need to adjust this expected value based on the specifics of the PoincareBall manifold projection 79 | expected_projection = torch.tensor([0.27321523, 0.0000], dtype=torch.float32) 80 | assert torch.allclose(projection, expected_projection, atol=1e-6), "Projection values do not match expected values" 81 | -------------------------------------------------------------------------------- /src/hierarchy_transformers/models/static_embed/poincare_embed.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class PoincareStaticEmbedding(torch.nn.Module):
    r"""Static (look-up table) hyperbolic embedding model, shared by:

    - [1] Poincaré Embedding by [Nickel et al., NeurIPS 2017](https://arxiv.org/abs/1705.08039).
    - [2] Hyperbolic Entailment Cone by [Ganea et al., ICML 2018](https://arxiv.org/abs/1804.01882).

    Both live in a unit Poincaré ball. According to [2], the entailment cone loss is best applied as a
    post-training phase on top of a Poincaré embedding model trained as in [1].

    Attributes:
        entities (list): The list of input entity IDs (fixed).
        idx2ent (dict): Maps each index to its entity ID.
        ent2idx (dict): Maps each entity ID to its index.
        embed_dim (int): The embedding dimension of this model.
        manifold (geoopt.manifolds.PoincareBall): The hyperbolic manifold (Poincaré Ball) of this model.
        dist (Callable): Shortcut to `manifold.dist`.
        embed (torch.nn.Embedding): The static hyperbolic embeddings for entities.
    """

    def __init__(self, entity_ids: list, embed_dim: int, init_weights: float = 1e-3):
        super().__init__()

        # index <-> entity look-ups follow the order of the given entity list
        self.entities = entity_ids
        self.idx2ent = dict(enumerate(self.entities))
        self.ent2idx = {ent: idx for idx, ent in self.idx2ent.items()}
        self.embed_dim = embed_dim
        self.manifold = PoincareBall()
        self.dist = self.manifold.dist

        # one embedding row per entity, with norms capped at 1.0 (unit Poincaré ball)
        num_entities = len(self.idx2ent)
        self.embed = torch.nn.Embedding(
            num_embeddings=num_entities,
            embedding_dim=self.embed_dim,
            sparse=False,
            max_norm=1.0,
        )
        # start from small uniform values, then register the weights as a manifold parameter
        self.embed.weight.data.uniform_(-init_weights, init_weights)
        self.embed.weight = ManifoldParameter(self.embed.weight, manifold=self.manifold)
        logger.info(f"Init static hyperbolic embedding for {len(self.idx2ent)} entities.")

    def forward(self, inputs: torch.Tensor):
        """Look up the embeddings of a batch of entity index tuples.

        `inputs` holds entity indices whose dimension 1 is laid out as `(child, parent, negative_parents*)`.
        The first entry is returned as the subject (repeated to match the objects) and the remaining
        entries as the objects.
        """
        looked_up = self.embed(inputs)
        # slicing (rather than indexing) keeps dimension 1 in place
        objects = looked_up[:, 1:]
        subject = looked_up[:, :1].expand_as(objects)
        return subject, objects
class PoincareStaticEmbeddingTrainer:
    r"""Class for training the static hyperbolic embedding models:

    - [1] Poincaré Embedding by [Nickel et al., NeurIPS 2017](https://arxiv.org/abs/1705.08039).
    - [2] Hyperbolic Entailment Cone by [Ganea et al., ICML 2018](https://arxiv.org/abs/1804.01882).

    both of which lie in a unit Poincaré ball. According to [2], it is better to apply the entailment cone loss in the post-training phase of a Poincaré embedding model in [1].

    Attributes:
        model (PoincareStaticEmbedding): The model being trained.
        train_dataloader (torch.utils.data.DataLoader): Shuffled batches of the tensorised training data.
        loss (torch.nn.Module): The loss applied to the `(subject, objects)` output of the model.
        learning_rate (float): The initial learning rate given to the optimiser.
        optimizer (geoopt.optim.RiemannianAdam): The optimiser over the model parameters.
        current_epoch (int): The number of epochs completed so far.
        num_train_epochs (int): The total number of training epochs.
        num_epoch_steps (int): The number of batches per epoch.
        num_training_steps (int): The total number of optimisation steps.
        warmup_epochs (int): The number of epochs over which the learning rate warms up.
        scheduler: The linear schedule with warm-up attached to the optimiser.
    """

    def __init__(
        self,
        model: PoincareStaticEmbedding,
        train_dataset: list,
        loss: torch.nn.Module,
        num_train_epochs: int = 256,
        learning_rate: float = 0.01,
        train_batch_size: int = 200,
        warmup_epochs: int = 10,
    ):
        self.model = model
        self.train_dataloader = DataLoader(torch.tensor(train_dataset), shuffle=True, batch_size=train_batch_size)
        self.loss = loss
        self.learning_rate = learning_rate
        self.optimizer = RiemannianAdam(self.model.parameters(), lr=self.learning_rate)
        self.current_epoch = 0
        self.num_train_epochs = num_train_epochs
        self.num_epoch_steps = len(self.train_dataloader)
        self.num_training_steps = self.num_epoch_steps * self.num_train_epochs
        self.warmup_epochs = warmup_epochs
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            # the scheduler is stepped once per batch, so warm-up is expressed in steps
            num_warmup_steps=self.warmup_epochs * self.num_epoch_steps,
            num_training_steps=self.num_training_steps,
        )

    @property
    def lr(self):
        """The current learning rate of the first parameter group of the optimiser."""
        return self.optimizer.param_groups[0]["lr"]

    def training_step(self, batch, device):
        """Run one optimisation step on `batch` and return its loss.

        The batch is moved to `device`, passed through the model and the loss, and both the optimiser
        and the learning rate scheduler are advanced by one step.
        """
        batch = batch.to(device)
        self.optimizer.zero_grad(set_to_none=True)
        subject, objects = self.model(batch)
        loss = self.loss(subject, objects)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()
        return loss

    def train(self, device):
        """Train the model on `device` for `num_train_epochs` epochs, showing one progress bar per epoch."""
        self.model.to(device)
        for _ in range(self.num_train_epochs):
            # the context manager closes the bar of each epoch, even if a step raises
            with tqdm(
                range(self.num_epoch_steps), desc=f"Epoch {self.current_epoch + 1}", leave=True, unit="batch"
            ) as epoch_bar:
                for batch in self.train_dataloader:
                    loss = self.training_step(batch, device)
                    epoch_bar.set_postfix({"batch_loss": loss.item(), "lr": self.lr})
                    epoch_bar.update()
            self.current_epoch += 1
class HierarchyTransformer(SentenceTransformer):
    r"""Class for Hierarchy Transformer encoder (HiT), extending from [`SentenceTransformer`](https://www.sbert.net/).

    Attributes:
        embed_dim (int): The embedding dimension of this model.
        manifold (geoopt.manifolds.PoincareBall): The hyperbolic manifold (Poincaré Ball) of this model.
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        modules: Iterable[torch.nn.Module] | None = None,
        device: str | None = None,
        revision: str | None = None,
    ):
        super().__init__(model_name_or_path=model_name_or_path, modules=modules, device=device, revision=revision)
        # PoincareBall in geoopt will be wrongly classified as a sub-module
        # so we use a dictionary to store the manifold
        self._register_buffer = {"manifold": self.get_circum_poincareball(self.embed_dim)}

    @property
    def embed_dim(self):
        """The word embedding dimension reported by the first module of this model."""
        return self._first_module().get_word_embedding_dimension()

    @property
    def manifold(self):
        """The Poincaré ball created for this model at construction time."""
        return self._register_buffer["manifold"]

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: str,
        revision: str | None = None,
        pooling_mode: str | None = "mean",
        device: torch.device | None = None,
    ):
        """Load a pretrained model from HuggingFace hub or local repository.

        The model is first loaded as a `SentenceTransformer`, keeping its first two modules (transformer
        and pooling) and discarding the rest. If that fails, or the second module is not a `Pooling`
        layer, the model is loaded as a plain `Transformer` and a new pooling layer is added.

        Args:
            model_name_or_path (str): The model identifier or local path.
            revision (str | None): The model revision to load, if any.
            pooling_mode (str | None): The pooling mode used only when a new pooling layer has to be created.
            device (torch.device | None): The device to place the model on.

        Returns:
            An instance of this class built from the transformer and pooling modules.
        """
        try:
            # Load from sentence_transformers library
            pretrained_model = SentenceTransformer(model_name_or_path, device=device, revision=revision)
            transformer = pretrained_model._modules["0"]
            pooling = pretrained_model._modules["1"]
            # explicit check rather than `assert`, which would be stripped when running with `-O`
            if not isinstance(pooling, Pooling):
                raise TypeError(f"Expected a `Pooling` module but found `{type(pooling).__name__}`.")
            logger.info(
                f"Load `{model_name_or_path}` from `sentence-transformers` with existing pooling (discard the normalising layer if any)."
            )
        except Exception as err:
            # Deliberate best-effort fallback; keep the reason available for debugging
            logger.debug(f"Loading `{model_name_or_path}` from `sentence-transformers` failed: {err!r}")
            # Load from huggingface transformers library
            transformer = Transformer(model_name_or_path, max_seq_length=256, model_args={"revision": revision})
            pooling = Pooling(
                word_embedding_dimension=transformer.get_word_embedding_dimension(), pooling_mode=pooling_mode
            )
            logger.info(f"Load `{model_name_or_path}` from `huggingface-transformers` with '{pooling_mode}' pooling.")

        return cls(modules=[transformer, pooling], device=device)

    @staticmethod
    def get_circum_poincareball(embed_dim: int) -> PoincareBall:
        """Get a Poincaré Ball with a curvature adapted to a given embedding dimension so that it circumscribes the output embedding space of pre-trained language models."""
        manifold = get_circum_poincareball(embed_dim)
        # use the module logger (not the root logger) so the message follows this package's logging setup
        logger.info(f"Poincare ball curvature: {manifold.c}")
        return manifold
54 | .publication-awards { 55 | color: #ff3860; 56 | width: fit-content; 57 | font-weight: bolder; 58 | } 59 | 60 | .publication-authors a { 61 | color: hsl(204, 86%, 53%) !important; 62 | } 63 | 64 | .publication-authors a:hover { 65 | text-decoration: underline; 66 | } 67 | 68 | .author-block { 69 | display: inline-block; 70 | } 71 | 72 | .publication-video { 73 | position: relative; 74 | width: 100%; 75 | height: 0; 76 | padding-bottom: 56.25%; 77 | 78 | overflow: hidden; 79 | border-radius: 10px !important; 80 | } 81 | 82 | .publication-video iframe { 83 | position: absolute; 84 | top: 0; 85 | left: 0; 86 | width: 100%; 87 | height: 100%; 88 | } 89 | 90 | .results-carousel { 91 | overflow: hidden; 92 | } 93 | 94 | .results-carousel .item { 95 | margin: 5px; 96 | overflow: hidden; 97 | padding: 20px; 98 | font-size: 0; 99 | } 100 | 101 | .results-carousel video { 102 | margin: 0; 103 | } 104 | 105 | .slider-pagination .slider-page { 106 | background: #000000; 107 | } 108 | 109 | .eql-cntrb { 110 | font-size: smaller; 111 | } 112 | 113 | .huggingface-icon { 114 | display: inline-block; 115 | width: 16px; /* Adjust width as needed */ 116 | height: 16px; /* Adjust height as needed */ 117 | background-image: url("https://huggingface.co/favicon.ico"); 118 | background-size: contain; 119 | background-repeat: no-repeat; 120 | margin-right: 8px; /* Space between the icon and the text */ 121 | } 122 | 123 | .zenodo-icon { 124 | display: inline-block; 125 | width: 16px; /* Adjust width as needed */ 126 | height: 16px; /* Adjust height as needed */ 127 | background-image: url("https://zenodo.org/static/favicon.ico"); 128 | background-size: contain; 129 | background-repeat: no-repeat; 130 | margin-right: 8px; /* Space between the icon and the text */ 131 | } 132 | 133 | .neurips-icon { 134 | display: inline-block; 135 | width: 16px; /* Adjust width as needed */ 136 | height: 16px; /* Adjust height as needed */ 137 | background-image: url("../../assets/images/neurips.svg"); 
138 | background-size: contain; 139 | background-repeat: no-repeat; 140 | margin-right: 8px; /* Space between the icon and the text */ 141 | } 142 | 143 | /* Style for the outer container */ 144 | .box { 145 | margin: 20px; 146 | border-radius: 8px; 147 | background-color: #f6f8fa; /* Light gray background */ 148 | padding: 12px; /* Reduced padding for a tighter look */ 149 | overflow: auto; /* Add scroll if the content is too wide */ 150 | box-shadow: 0 2px 5px rgba(0, 0, 0, 0.15); /* Subtle shadow for depth */ 151 | } 152 | code { 153 | background-color: #f6f8fa !important; /* Light gray background */ 154 | color: #333 !important; /* Override the default theme text color */ 155 | } 156 | pre { 157 | margin: 0; /* Remove default margin */ 158 | padding: 12px; /* Add padding to the pre element instead */ 159 | font-family: "Fira Code", "Courier New", Courier, monospace; /* Improved font stack */ 160 | font-size: 14px; 161 | line-height: 1.5; /* Adjusted line height for better readability */ 162 | background-color: #f6f8fa !important; /* Light gray background */ 163 | border-radius: 8px; /* Match the border radius */ 164 | color: #333 !important; /* Override the default theme text color */ 165 | } 166 | /* GitHub-style inline code */ 167 | code.inline { 168 | font-family: "Fira Code", "Courier New", Courier, monospace; /* Improved font stack */ 169 | padding: 2px 4px; 170 | border-radius: 4px; 171 | font-size: 90%; /* Slightly smaller than regular text */ 172 | background-color: #f6f8fa !important; /* Light gray background */ 173 | } 174 | -------------------------------------------------------------------------------- /scripts/training/sft/training_sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
@click.command()
@click.option("-c", "--config_file", type=click.Path(exists=True))
def main(config_file: str):
    """Fine-tune a two-label sequence classifier on labelled (child, parent) pairs and evaluate it on the test split.

    Args:
        config_file (str): Path to the configuration file. The script reads `model_path`,
            `dataset_path`, `dataset_name`, `num_train_epochs`, `learning_rate`,
            `train_batch_size` and `eval_batch_size` from it.

    Outputs are written under `experiments/SFT-<model>-<dataset>-<dataset_name>`; the test
    metrics are saved to `eval/results.tsv` inside that directory.
    """
    # 0. set seed, load config, and format output dir
    set_seed(8888)
    config = CfgNode(load_file(config_file))
    model_path_suffix = config.model_path.split(os.path.sep)[-1]
    dataset_path_suffix = config.dataset_path.split(os.path.sep)[-1]
    output_dir = f"experiments/SFT-{model_path_suffix}-{dataset_path_suffix}-{config.dataset_name}"
    create_path(output_dir)
    try:
        shutil.copy2(config_file, os.path.join(output_dir, "config.yaml"))
    except Exception:
        # Keeping a copy of the config is best-effort; report the failure instead of hiding it.
        logger.warning(f"Could not copy `{config_file}` into `{output_dir}`.", exc_info=True)

    # 1. Load dataset and pre-trained model
    # NOTE: this script uses the "-Pairs" subset; the "child" and "parent" columns are tokenised together below.
    pair_dataset = load_hf_dataset(config.dataset_path, config.dataset_name + "-Pairs")
    model = AutoModelForSequenceClassification.from_pretrained(
        pretrained_model_name_or_path=config.model_path, num_labels=2
    )

    # 2. Tokenise dataset and setup collator
    tokenizer = AutoTokenizer.from_pretrained(config.model_path)

    def tok_func(example):
        """Tokenise a batch of (child, parent) text pairs, truncated to 256 tokens."""
        return tokenizer(example["child"], example["parent"], truncation=True, max_length=256)

    train_examples = pair_dataset["train"].map(tok_func, batched=True)
    val_examples = pair_dataset["val"].map(tok_func, batched=True)
    test_examples = pair_dataset["test"].map(tok_func, batched=True)
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # 3. Define the training arguments
    args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=int(config.num_train_epochs),
        learning_rate=float(config.learning_rate),
        per_device_train_batch_size=int(config.train_batch_size),
        per_device_eval_batch_size=int(config.eval_batch_size),
        # `eval_strategy` replaces the deprecated `evaluation_strategy` keyword and
        # matches the spelling used by the HiT training script.
        eval_strategy="steps",
        save_strategy="steps",
        eval_steps=500,
        save_steps=500,
        logging_steps=100,
        save_total_limit=2,
        load_best_model_at_end=True,
    )

    # 4. Create the trainer & start training
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_examples,
        eval_dataset=val_examples,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )
    trainer.train()

    # 5. Evaluate the model performance on the test dataset
    test_preds = trainer.predict(test_examples)
    # argmax over the two logits gives the predicted class id (0 or 1),
    # so thresholding at 0.5 below separates the two classes.
    test_scores = torch.tensor(test_preds.predictions).argmax(dim=1)
    test_labels = torch.tensor(test_preds.label_ids)
    test_results = pd.DataFrame(
        columns=["threshold", "precision", "recall", "f1", "accuracy", "accuracy_on_negatives"]
    )
    test_results.loc["testing"] = evaluate_by_threshold(scores=test_scores, labels=test_labels, threshold=0.5)
    logger.info(test_results.loc["testing"])
    create_path(os.path.join(output_dir, "eval"))
    test_results.to_csv(os.path.join(output_dir, "eval", "results.tsv"), sep="\t")
@pytest.fixture(params=os.getenv("DATASET_PATHS", "").split(","))
def dataset_path(request):
    """Yield one dataset path per comma-separated entry of the `DATASET_PATHS` environment variable."""
    # An unset/empty variable yields a single empty parameter, which is reported as a failure.
    if not request.param:
        pytest.fail("No valid dataset names found in the DATASET_PATHS environment variable")
    return request.param.strip()  # Strip any extra spaces


def test_dataset_loading(dataset_path):
    """Check that one randomly chosen subset of the dataset can be loaded."""
    # The subset name is `<hop>-<negatives>-<structure>`; only one random combination is tried per run.
    hop_type = random.choice(["MultiHop", "MixedHop"])
    neg_type = random.choice(["HardNegatives", "RandomNegatives"])
    struct_type = random.choice(["Triplets", "Pairs"])
    part = f"{hop_type}-{neg_type}-{struct_type}"
    try:
        # Attempt to load the selected dataset subset
        dataset = load_dataset(dataset_path, part)
    except Exception as e:
        pytest.fail(f"Dataset failed to load: {str(e)}")

    # Check that the dataset is not None
    assert dataset is not None, "Loaded dataset is None"
[(child, parent, neg) for neg in hard_negatives] 70 | # pair_examples["random_negatives"] += [(child, parent, 1)] + [(child, neg, 0) for neg in negative_parents] 71 | # pair_examples["hard_negatives"] += [(child, parent, 1)] + [(child, neg, 0) for neg in hard_negatives] 72 | # assert n_rand == n_hard 73 | # pd.DataFrame(triplet_examples["random_negatives"], columns=["child", "parent", "negative"]).to_parquet(f"{save_path}/{hop_type}-RandomNegatives-Triplets/{split}.parquet", index=False) 74 | # pd.DataFrame(triplet_examples["hard_negatives"], columns=["child", "parent", "negative"]).to_parquet(f"{save_path}/{hop_type}-HardNegatives-Triplets/{split}.parquet", index=False) 75 | # pd.DataFrame(pair_examples["random_negatives"], columns=["child", "parent", "label"]).to_parquet(f"{save_path}/{hop_type}-RandomNegatives-Pairs/{split}.parquet", index=False) 76 | # pd.DataFrame(pair_examples["hard_negatives"], columns=["child", "parent", "label"]).to_parquet(f"{save_path}/{hop_type}-HardNegatives-Pairs/{split}.parquet", index=False) 77 | 78 | # [Deprecated] Compare HF and local version 79 | # data_path = "/home/yuan/projects/HiT/data/wordnet-multi" 80 | # dataset, entity_lexicon = load_hierarchy_dataset(data_path) 81 | 82 | # for split in ["train", "val", "test"]: 83 | # for is_hard in [True, False]: 84 | # for is_triplet in [True, False]: 85 | # dataset_zenodo = prepare_hierarchy_examples(entity_lexicon, dataset[split], is_hard, is_triplet) 86 | # if is_triplet: 87 | # dataset_zenodo = [{'child': x.texts[0], 'parent': x.texts[1], 'negative': x.texts[2]} for x in dataset_zenodo] 88 | # else: 89 | # dataset_zenodo = [{'child': x.texts[0], 'parent': x.texts[1], 'label': x.label} for x in dataset_zenodo] 90 | # neg = "HardNegatives" if is_hard else "RandomNegatives" 91 | # struct = "Triplets" if is_triplet else "Pairs" 92 | # dataset_hf = load_dataset("Hierarchy-Transformers/WordNetNoun", f"MultiHop-{neg}-{struct}")[split] 93 | # for i in tqdm(range(len(dataset_zenodo)), 
desc=f"Check MultiHop-{neg}-{struct}"): 94 | # assert dataset_zenodo[i] == dataset_hf[i], (dataset_zenodo[i], dataset_hf[i]) 95 | -------------------------------------------------------------------------------- /docs/static/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround { 2 | from { 3 | -webkit-transform: rotate(0); 4 | transform: rotate(0); 5 | } 6 | to { 7 | -webkit-transform: rotate(359deg); 8 | transform: rotate(359deg); 9 | } 10 | } 11 | @keyframes spinAround { 12 | from { 13 | -webkit-transform: rotate(0); 14 | transform: rotate(0); 15 | } 16 | to { 17 | -webkit-transform: rotate(359deg); 18 | transform: rotate(359deg); 19 | } 20 | } 21 | .slider { 22 | position: relative; 23 | width: 100%; 24 | } 25 | .slider-container { 26 | display: flex; 27 | flex-wrap: nowrap; 28 | flex-direction: row; 29 | overflow: hidden; 30 | -webkit-transform: translate3d(0, 0, 0); 31 | transform: translate3d(0, 0, 0); 32 | min-height: 100%; 33 | } 34 | .slider-container.is-vertical { 35 | flex-direction: column; 36 | } 37 | .slider-container .slider-item { 38 | flex: none; 39 | } 40 | .slider-container .slider-item .image.is-covered img { 41 | -o-object-fit: cover; 42 | object-fit: cover; 43 | -o-object-position: center center; 44 | object-position: center center; 45 | height: 100%; 46 | width: 100%; 47 | } 48 | .slider-container .slider-item .video-container { 49 | height: 0; 50 | padding-bottom: 0; 51 | padding-top: 56.25%; 52 | margin: 0; 53 | position: relative; 54 | } 55 | .slider-container .slider-item .video-container.is-1by1, 56 | .slider-container .slider-item .video-container.is-square { 57 | padding-top: 100%; 58 | } 59 | .slider-container .slider-item .video-container.is-4by3 { 60 | padding-top: 75%; 61 | } 62 | .slider-container .slider-item .video-container.is-21by9 { 63 | padding-top: 42.857143%; 64 | } 65 | .slider-container .slider-item .video-container embed, 66 | 
.slider-container .slider-item .video-container iframe, 67 | .slider-container .slider-item .video-container object { 68 | position: absolute; 69 | top: 0; 70 | left: 0; 71 | width: 100% !important; 72 | height: 100% !important; 73 | } 74 | .slider-navigation-next, 75 | .slider-navigation-previous { 76 | display: flex; 77 | justify-content: center; 78 | align-items: center; 79 | position: absolute; 80 | width: 42px; 81 | height: 42px; 82 | background: #fff center center no-repeat; 83 | background-size: 20px 20px; 84 | border: 1px solid #fff; 85 | border-radius: 25091983px; 86 | box-shadow: 0 2px 5px #3232321a; 87 | top: 50%; 88 | margin-top: -20px; 89 | left: 0; 90 | cursor: pointer; 91 | transition: opacity 0.3s, -webkit-transform 0.3s; 92 | transition: transform 0.3s, opacity 0.3s; 93 | transition: transform 0.3s, opacity 0.3s, -webkit-transform 0.3s; 94 | } 95 | .slider-navigation-next:hover, 96 | .slider-navigation-previous:hover { 97 | -webkit-transform: scale(1.2); 98 | transform: scale(1.2); 99 | } 100 | .slider-navigation-next.is-hidden, 101 | .slider-navigation-previous.is-hidden { 102 | display: none; 103 | opacity: 0; 104 | } 105 | .slider-navigation-next svg, 106 | .slider-navigation-previous svg { 107 | width: 25%; 108 | } 109 | .slider-navigation-next { 110 | left: auto; 111 | right: 0; 112 | background: #fff center center no-repeat; 113 | background-size: 20px 20px; 114 | } 115 | .slider-pagination { 116 | display: none; 117 | justify-content: center; 118 | align-items: center; 119 | position: absolute; 120 | bottom: 0; 121 | left: 0; 122 | right: 0; 123 | padding: 0.5rem 1rem; 124 | text-align: center; 125 | } 126 | .slider-pagination .slider-page { 127 | background: #fff; 128 | width: 10px; 129 | height: 10px; 130 | border-radius: 25091983px; 131 | display: inline-block; 132 | margin: 0 3px; 133 | box-shadow: 0 2px 5px #3232321a; 134 | transition: -webkit-transform 0.3s; 135 | transition: transform 0.3s; 136 | transition: transform 0.3s, 
-webkit-transform 0.3s; 137 | cursor: pointer; 138 | } 139 | .slider-pagination .slider-page.is-active, 140 | .slider-pagination .slider-page:hover { 141 | -webkit-transform: scale(1.4); 142 | transform: scale(1.4); 143 | } 144 | @media screen and (min-width: 800px) { 145 | .slider-pagination { 146 | display: flex; 147 | } 148 | } 149 | .hero.has-carousel { 150 | position: relative; 151 | } 152 | .hero.has-carousel + .hero-body, 153 | .hero.has-carousel + .hero-footer, 154 | .hero.has-carousel + .hero-head { 155 | z-index: 10; 156 | overflow: hidden; 157 | } 158 | .hero.has-carousel .hero-carousel { 159 | position: absolute; 160 | top: 0; 161 | left: 0; 162 | bottom: 0; 163 | right: 0; 164 | height: auto; 165 | border: none; 166 | margin: auto; 167 | padding: 0; 168 | z-index: 0; 169 | } 170 | .hero.has-carousel .hero-carousel .slider { 171 | width: 100%; 172 | max-width: 100%; 173 | overflow: hidden; 174 | height: 100% !important; 175 | max-height: 100%; 176 | z-index: 0; 177 | } 178 | .hero.has-carousel .hero-carousel .slider .has-background { 179 | max-height: 100%; 180 | } 181 | .hero.has-carousel .hero-carousel .slider .has-background .is-background { 182 | -o-object-fit: cover; 183 | object-fit: cover; 184 | -o-object-position: center center; 185 | object-position: center center; 186 | height: 100%; 187 | width: 100%; 188 | } 189 | .hero.has-carousel .hero-body { 190 | margin: 0 3rem; 191 | z-index: 10; 192 | } 193 | -------------------------------------------------------------------------------- /scripts/training/hit/training_hit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This training script is hierarchy re-training of HiT models.""" 15 | from __future__ import annotations 16 | 17 | import logging 18 | import os 19 | import shutil 20 | import sys 21 | 22 | import click 23 | from deeponto.utils import create_path, load_file, set_seed 24 | from sentence_transformers.training_args import SentenceTransformerTrainingArguments 25 | from yacs.config import CfgNode 26 | 27 | from hierarchy_transformers.datasets import load_hf_dataset 28 | from hierarchy_transformers.evaluation import HierarchyTransformerEvaluator 29 | from hierarchy_transformers.losses import HierarchyTransformerLoss 30 | from hierarchy_transformers.models import HierarchyTransformer 31 | from hierarchy_transformers.models.hierarchy_transformer.hit_trainer import HierarchyTransformerTrainer 32 | 33 | logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stderr)]) 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | @click.command() 38 | @click.option("-c", "--config_file", type=click.Path(exists=True)) 39 | def main(config_file: str): 40 | # 0. 
set seed, load config, and format output dir 41 | set_seed(8888) 42 | config = CfgNode(load_file(config_file)) 43 | model_path_suffix = config.model_path.split(os.path.sep)[-1] 44 | dataset_path_suffix = config.dataset_path.split(os.path.sep)[-1] 45 | output_dir = f"experiments/HiT-{model_path_suffix}-{dataset_path_suffix}-{config.dataset_name}" 46 | create_path(output_dir) 47 | try: 48 | shutil.copy2(config_file, os.path.join(output_dir, "config.yaml")) 49 | except Exception: 50 | pass 51 | 52 | # 1. Load dataset and pre-trained model 53 | # NOTE: according to docs, it is very important to have column names ["child", "parent", "negative"] *in order* to match ["anchor", "positive", "negative"] 54 | triplet_dataset = load_hf_dataset(config.dataset_path, config.dataset_name + "-Triplets") 55 | pair_dataset = load_hf_dataset(config.dataset_path, config.dataset_name + "-Pairs") 56 | model = HierarchyTransformer.from_pretrained(model_name_or_path=config.model_path) 57 | 58 | # 2. set up the loss function 59 | hit_loss = HierarchyTransformerLoss( 60 | model=model, 61 | clustering_loss_weight=config.hit_loss.clustering_loss_weight, 62 | clustering_loss_margin=config.hit_loss.clustering_loss_margin, 63 | centripetal_loss_weight=config.hit_loss.centripetal_loss_weight, 64 | centripetal_loss_margin=config.hit_loss.centripetal_loss_margin, 65 | ) 66 | logger.info(f"HiT loss config: {hit_loss.get_config_dict()}") 67 | 68 | # 3. Define a validation evaluator for use during training. 69 | val_evaluator = HierarchyTransformerEvaluator( 70 | child_entities=pair_dataset["val"]["child"], 71 | parent_entities=pair_dataset["val"]["parent"], 72 | labels=pair_dataset["val"]["label"], 73 | batch_size=config.eval_batch_size, 74 | truth_label=1, 75 | ) 76 | 77 | # 4. 
Define the training arguments 78 | args = SentenceTransformerTrainingArguments( 79 | output_dir=output_dir, 80 | num_train_epochs=int(config.num_train_epochs), 81 | learning_rate=float(config.learning_rate), 82 | per_device_train_batch_size=int(config.train_batch_size), 83 | per_device_eval_batch_size=int(config.eval_batch_size), 84 | warmup_steps=500, 85 | eval_strategy="epoch", 86 | save_strategy="epoch", 87 | save_total_limit=2, 88 | logging_steps=100, 89 | metric_for_best_model="f1", # to override loss for model selection 90 | greater_is_better=True, # due to F1 score 91 | load_best_model_at_end=True, 92 | ) 93 | 94 | # 5. Create the trainer & start training 95 | trainer = HierarchyTransformerTrainer( 96 | model=model, 97 | args=args, 98 | train_dataset=triplet_dataset["train"], # train loss requires triplets 99 | eval_dataset=triplet_dataset["val"], # val loss requires triplets 100 | loss=hit_loss, 101 | evaluator=val_evaluator, # actual eval requires labelled pairs 102 | ) 103 | trainer.train() 104 | 105 | # 6. Evaluate the model performance on the test dataset 106 | val_results = val_evaluator.results 107 | best_val = val_results.loc[val_results["f1"].idxmax()] 108 | best_val_centri_weight = float(best_val["centri_weight"]) 109 | best_val_threshold = float(best_val["threshold"]) 110 | test_evaluator = HierarchyTransformerEvaluator( 111 | child_entities=pair_dataset["test"]["child"], 112 | parent_entities=pair_dataset["test"]["parent"], 113 | labels=pair_dataset["test"]["label"], 114 | batch_size=config.eval_batch_size, 115 | truth_label=1, 116 | ) 117 | test_evaluator( 118 | model=model, 119 | output_path=os.path.join(output_dir, "eval"), 120 | best_centri_weight=best_val_centri_weight, 121 | best_threshold=best_val_threshold, 122 | ) 123 | 124 | # 7. 
def load_hf_dataset(path: str, name: str | None = None, **config_kwargs):
    """Load a HiT dataset from Hugging Face.

    See available datasets on: https://huggingface.co/Hierarchy-Transformers

    Args:
        path (str): Dataset path on Hugging Face.
        name (Optional[str]): Name of a specific subset if any. Defaults to `None`.
        **config_kwargs: Extra keyword arguments forwarded to `datasets.load_dataset`.

    Returns:
        Whatever `datasets.load_dataset(path, name, **config_kwargs)` returns.
    """
    return load_dataset(path, name, **config_kwargs)


def load_zenodo_dataset(
    path: str,
    entity_lexicon_or_index: dict,
    negative_type: str = "random",
    example_type: str = "triplet",
):
    """Load a HiT dataset from a local version downloaded from Zenodo.

    It is recommended to use `load_hf_dataset` from this library or `load_dataset` from HuggingFace datasets if one doesn't require the original entity IDs.

    See available datasets on: https://doi.org/10.5281/zenodo.10511042

    Args:
        path (str): Path to a local dataset downloaded from Zenodo.
        entity_lexicon_or_index (dict): A dictionary to transform entity IDs to names required by language models or indices (one-hot encoding) required by the static hierarchy models.
        negative_type (str): Type of negative examples. Options are `['random', 'hard']`.
        example_type (str): Type of example structure. Options are `['triplet', 'pair', 'idx']`.

    Returns:
        The loaded dataset dictionary with every available split replaced by its transformed
        examples: a `Dataset` of flattened records for `'triplet'` and `'pair'`, or a plain list
        of index lists for `'idx'`.

    Raises:
        ValueError: If `negative_type` or `example_type` is not a supported option, or if
            `entity_lexicon_or_index` is `None`.
    """
    # Validate with explicit raises: `assert` statements are stripped when Python runs with `-O`.
    if negative_type not in ("random", "hard"):
        raise ValueError(f"Unknown negative type '{negative_type}'.")
    if example_type not in ("triplet", "pair", "idx"):
        raise ValueError(f"Unknown example type '{example_type}'.")
    if entity_lexicon_or_index is None:
        raise ValueError("The entity transformation dictionary is not found.")

    # check if train, val, test splits are all there
    datafiles = {}
    for split in ["train", "val", "test"]:
        split_path = os.path.join(path, f"{split}.jsonl")
        if os.path.isfile(split_path):
            datafiles[split] = split_path
        else:
            logger.info(f"No {split} split available.")

    # load the jsonl dataset altogether
    dataset = load_dataset("json", data_files=datafiles)

    transform = {
        "triplet": zenodo_example_to_triplets,
        "pair": zenodo_example_to_pairs,
        "idx": zenodo_example_to_idxs,
    }[example_type]

    # Snapshot the items first: the loop re-assigns entries of `dataset` while walking over it.
    for split, examples in list(dataset.items()):
        if example_type == "idx":
            # for static embedding model, inputs are not flattened
            dataset_split = [
                transform(example, negative_type, entity_lexicon_or_index)
                for example in tqdm(examples, desc=f"Map ({split})", leave=True)
            ]
        else:
            # for other models, inputs are flattened
            dataset_split = [
                transformed
                for example in tqdm(examples, desc=f"Map ({split})", leave=True)
                for transformed in transform(example, negative_type, entity_lexicon_or_index)
            ]
            dataset_split = Dataset.from_list(dataset_split)

        dataset[split] = dataset_split

    return dataset


def zenodo_example_to_triplets(example: dict, negative_type: str, entity_lexicon: dict):
    """Helper function to present Zenodo dataset examples into triplets of the form `(child, parent, negative)`.

    Returns:
        A list with one `{"child", "parent", "negative"}` dictionary per negative entity
        listed under `example[f"{negative_type}_negatives"]`.
    """
    child = entity_lexicon[example["child"]]["name"]
    parent = entity_lexicon[example["parent"]]["name"]
    negative_key = f"{negative_type}_negatives"
    negative_parents = [entity_lexicon[neg]["name"] for neg in example[negative_key]]
    return [{"child": child, "parent": parent, "negative": neg} for neg in negative_parents]


def zenodo_example_to_pairs(example: dict, negative_type: str, entity_lexicon: dict):
    """Helper function to present Zenodo dataset examples into labelled pairs of the form `(child, parent, label)`.

    Returns:
        A list starting with the positive pair (`label` 1) followed by one pair with `label` 0
        per negative entity listed under `example[f"{negative_type}_negatives"]`.
    """
    child = entity_lexicon[example["child"]]["name"]
    parent = entity_lexicon[example["parent"]]["name"]
    negative_key = f"{negative_type}_negatives"
    negative_parents = [entity_lexicon[neg]["name"] for neg in example[negative_key]]
    return [{"child": child, "parent": parent, "label": 1}] + [
        {"child": child, "parent": neg, "label": 0} for neg in negative_parents
    ]


def zenodo_example_to_idxs(example: dict, negative_type: str, entity_to_indices: dict):
    """Helper function to present Zenodo dataset examples into an entity index list of `(child_idx, parent_idx, *negative_idxs)`.

    Returns:
        A flat list `[child_idx, parent_idx, *negative_idxs]` looked up in `entity_to_indices`.
    """
    child = entity_to_indices[example["child"]]
    parent = entity_to_indices[example["parent"]]
    negative_key = f"{negative_type}_negatives"
    negative_parents = [entity_to_indices[neg] for neg in example[negative_key]]
    return [child, parent] + negative_parents
def f1_score(predictions: torch.Tensor, labels: torch.Tensor, truth_label: int = 1):
    """Pytorch tensor implementation of `f1_score` computation.

    Undefined ratios (0/0, e.g. when nothing is predicted as positive) are reported as `0.0`
    instead of `NaN`.

    Args:
        predictions (torch.Tensor): Predictions.
        labels (torch.Tensor): Reference labels.
        truth_label (int, optional): Specify which label represents the truth. Defaults to `1`.

    Returns:
        results (dict): result dictionary containing `precision`, `recall`, and `f1`.
    """
    tp = torch.sum((labels == truth_label) & (predictions == truth_label))  # correct and positive
    fp = torch.sum((labels != truth_label) & (predictions == truth_label))  # incorrect but positive
    fn = torch.sum((labels == truth_label) & (predictions != truth_label))  # correct but negative
    # A zero denominator implies a zero numerator here, so the only non-finite value is NaN (0/0).
    precision = torch.nan_to_num(tp / (tp + fp), nan=0.0)
    recall = torch.nan_to_num(tp / (tp + fn), nan=0.0)
    f1 = torch.nan_to_num(2 * (precision * recall) / (precision + recall), nan=0.0)
    return {"precision": precision.item(), "recall": recall.item(), "f1": f1.item()}  # recall is ACC+


def accurarcy(predictions: torch.Tensor, labels: torch.Tensor):
    """Pytorch tensor implementation of `accuracy` computation.

    Returns:
        results (dict): `{"accuracy": ...}`, the fraction of entries where `labels == predictions`.
    """
    acc = torch.sum(labels == predictions) / len(labels)
    return {"accuracy": acc.item()}


def accurarcy_on_negatives(predictions: torch.Tensor, labels: torch.Tensor, truth_label: int = 1):
    """Pytorch tensor implementation of `accuracy_on_negatives` computation.

    That is, it computes accuracy only w.r.t. negative samples (with `label != truth_label`).

    Returns:
        results (dict): `{"accuracy_on_negatives": ...}`.
    """
    neg_acc = torch.sum((labels == predictions) & (labels != truth_label)) / torch.sum(labels != truth_label)
    return {"accuracy_on_negatives": neg_acc.item()}


def evaluate_by_threshold(
    scores: torch.Tensor,
    labels: torch.Tensor,
    threshold: float,
    truth_label: int = 1,
    smaller_scores_better: bool = False,
):
    r"""Compute evaluation metrics (`Precision`, `Recall`, `F1`, `Accurarcy`, `Accurarcy-`) based on the threshold.

    Args:
        scores (torch.Tensor): Prediction scores.
        labels (torch.Tensor): Reference labels.
        threshold (float): Threshold that splits the positive and negative predictions.
        truth_label (int): Specify which label represents the truth. Defaults to `1`.
        smaller_scores_better (bool): Specify if smaller than threshold indicates positive or not. Defaults to `False`.

    Returns:
        results (dict): The `threshold` together with `precision`, `recall`, `f1`, `accuracy`
        and `accuracy_on_negatives`.
    """

    # thresholding: produces a boolean tensor that is compared against the labels below
    if smaller_scores_better:
        predictions = scores <= threshold
    else:
        predictions = scores > threshold
    # compute results
    results = {
        "threshold": threshold,
        **f1_score(predictions=predictions, labels=labels, truth_label=truth_label),
        **accurarcy(predictions=predictions, labels=labels),
        **accurarcy_on_negatives(predictions=predictions, labels=labels, truth_label=truth_label),
    }
    return results
110 | """ 111 | 112 | best_results = None 113 | 114 | # grid search start and end are confined by the prediction scores 115 | start = int(scores.min() * threshold_granularity) 116 | end = int(scores.max() * threshold_granularity) 117 | 118 | # grid search to update the best results 119 | for threshold in tqdm(range(start, end), desc="Thresholding"): 120 | threshold = threshold / threshold_granularity 121 | results = evaluate_by_threshold( 122 | scores=scores, 123 | labels=labels, 124 | threshold=threshold, 125 | truth_label=truth_label, 126 | smaller_scores_better=smaller_scores_better, 127 | ) 128 | if results[primary_metric] >= best_primary_metric_value: 129 | best_results = preformatted_best_results 130 | best_results.update(results) 131 | best_primary_metric_value = results[primary_metric] 132 | 133 | return best_results 134 | -------------------------------------------------------------------------------- /src/hierarchy_transformers/datasets/construct.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import annotations 15 | 16 | import json 17 | from pathlib import Path 18 | 19 | from deeponto.onto import Taxonomy, TaxonomyNegativeSampler 20 | from deeponto.utils import save_file 21 | from sklearn.model_selection import train_test_split 22 | from tqdm.auto import tqdm 23 | 24 | 25 | class HierarchyDatasetConstructor: 26 | """ 27 | Class for contructing hierarchy understanding datasets. 28 | """ 29 | def __init__(self, hierarchy: Taxonomy): 30 | self.hierarchy = hierarchy 31 | self.neg_sampler = TaxonomyNegativeSampler(self.hierarchy) 32 | 33 | def get_hard_negative(self, entity_id: str): 34 | """ 35 | Get a hard negative subsumer (sibling) for the input entity. 36 | """ 37 | parents = self.hierarchy.get_parents(entity_id) 38 | ancestors = self.hierarchy.get_parents(entity_id, True) 39 | siblings = [] 40 | for parent in parents: 41 | siblings += self.hierarchy.get_children(parent) 42 | hard_negatives = set(siblings) - set([entity_id]) - set(ancestors) 43 | return list(hard_negatives) 44 | 45 | def get_transitive_edges(self, edges: list): 46 | """ 47 | Get all indirect subsumptions by transitive reasoning. 48 | """ 49 | trans_edges = [] 50 | for child, _ in edges: 51 | trans_edges += [(child, parent) for parent in self.hierarchy.get_parents(child, True)] 52 | return list(set(trans_edges) - set(edges)) 53 | 54 | def save_entity_lexicon(self, output_dir: str): 55 | """ 56 | Save the entity lexicon. 57 | """ 58 | entity_lexicon = dict() 59 | for n in self.hierarchy.nodes: 60 | entity_lexicon[n] = self.hierarchy.get_node_attributes(n) 61 | save_file(entity_lexicon, f"{output_dir}/entity_lexicon.json") 62 | 63 | def save_dataset(self, dataset: list, output_file: str): 64 | """ 65 | Save the constructed dataset. 
66 | """ 67 | with open(f"{output_file}", "w+") as f: 68 | f.writelines("\n".join([json.dumps(sample) for sample in dataset])) 69 | 70 | def construct_example(self, child: str, parent: str, num_negative: int = 10): 71 | """ 72 | Construct negative examples given a positive `(child, parent)` pair. 73 | """ 74 | example = {"child": child, "parent": parent} 75 | example["random_negatives"] = self.neg_sampler.sample(child, num_negative) 76 | example["hard_negatives"] = (self.get_hard_negative(child) + example["random_negatives"])[:num_negative] 77 | return example 78 | 79 | def construct(self, output_dir: str, num_negative: int = 10, eval_size=0.1): 80 | """ 81 | Construct the multi-hop and mixed-hop datasets. 82 | """ 83 | Path(output_dir).mkdir(parents=True, exist_ok=True) 84 | 85 | base_edges = [(child, parent) for parent, child in self.hierarchy.edges] 86 | trans_edges = self.get_transitive_edges(base_edges) 87 | assert not set(trans_edges).intersection(set(base_edges)) 88 | 89 | base_examples = [] 90 | for child, parent in tqdm(base_edges, desc="base"): 91 | base_examples.append(self.construct_example(child, parent, num_negative)) 92 | 93 | trans_examples = [] 94 | for child, parent in tqdm(trans_edges, desc="trans"): 95 | trans_examples.append(self.construct_example(child, parent, num_negative)) 96 | 97 | _, trans_eval_examples = train_test_split(trans_examples, test_size=eval_size) 98 | trans_val_examples, trans_test_examples = train_test_split(trans_eval_examples, test_size=0.5) 99 | 100 | trans_task_name = "multi" 101 | Path(f"{output_dir}/{trans_task_name}").mkdir(parents=True, exist_ok=True) 102 | self.save_dataset(base_examples, f"{output_dir}/{trans_task_name}/train.jsonl") 103 | # self.save_dataset(trans_train_examples, f"{output_dir}/transitivity/trans_train.jsonl") 104 | self.save_dataset(trans_val_examples, f"{output_dir}/{trans_task_name}/val.jsonl") 105 | self.save_dataset(trans_test_examples, f"{output_dir}/{trans_task_name}/test.jsonl") 106 | 
self.save_entity_lexicon(f"{output_dir}/{trans_task_name}") 107 | 108 | base_train_examples, base_eval_examples = train_test_split(base_examples, test_size=eval_size) 109 | base_val_examples, base_test_examples = train_test_split(base_eval_examples, test_size=0.5) 110 | # base_train_edges = [(x["child"], x["parent"]) for x in base_train_examples] 111 | # trans_base_train_edges = self.get_transitive_edges(base_train_edges) 112 | # trans_base_train_examples = [] 113 | # for child, parent in tqdm(trans_base_train_edges, desc="trans on base_train"): 114 | # trans_base_train_examples.append(self.construct_example(child, parent, num_negative)) 115 | pred_task_name = "mixed" 116 | Path(f"{output_dir}/{pred_task_name}").mkdir(parents=True, exist_ok=True) 117 | self.save_dataset(base_train_examples, f"{output_dir}/{pred_task_name}/train.jsonl") 118 | # self.save_dataset(trans_base_train_examples, f"{output_dir}/induc/trans_base_train.jsonl") 119 | self.save_dataset(base_val_examples + trans_val_examples, f"{output_dir}/{pred_task_name}/val.jsonl") 120 | self.save_dataset(base_test_examples + trans_test_examples, f"{output_dir}/{pred_task_name}/test.jsonl") 121 | self.save_entity_lexicon(f"{output_dir}/{pred_task_name}") 122 | -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.min.js: -------------------------------------------------------------------------------- 1 | !function(t,e){"object"==typeof exports&&"object"==typeof module?module.exports=e():"function"==typeof define&&define.amd?define([],e):"object"==typeof exports?exports.bulmaSlider=e():t.bulmaSlider=e()}("undefined"!=typeof self?self:this,function(){return function(n){var r={};function i(t){if(r[t])return r[t].exports;var e=r[t]={i:t,l:!1,exports:{}};return n[t].call(e.exports,e,e.exports,i),e.l=!0,e.exports}return i.m=n,i.c=r,i.d=function(t,e,n){i.o(t,e)||Object.defineProperty(t,e,{configurable:!1,enumerable:!0,get:n})},i.n=function(t){var 
e=t&&t.__esModule?function(){return t.default}:function(){return t};return i.d(e,"a",e),e},i.o=function(t,e){return Object.prototype.hasOwnProperty.call(t,e)},i.p="",i(i.s=0)}([function(t,e,n){"use strict";Object.defineProperty(e,"__esModule",{value:!0}),n.d(e,"isString",function(){return l});var r=n(1),i=Object.assign||function(t){for(var e=1;e=l.length&&(s=!0)):s=!0),s&&(t.once&&(u[e]=null),t.callback(r))});-1!==u.indexOf(null);)u.splice(u.indexOf(null),1)}}]),e}();e.a=i}]).default}); -------------------------------------------------------------------------------- /scripts/training/static_embed/training_static.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import annotations 15 | 16 | import logging 17 | import os 18 | import shutil 19 | 20 | import click 21 | import torch 22 | from deeponto.utils import create_path, load_file, set_seed 23 | from yacs.config import CfgNode 24 | 25 | from hierarchy_transformers.datasets import load_zenodo_dataset 26 | from hierarchy_transformers.evaluation import PoincareStaticEmbeddingEvaluator 27 | from hierarchy_transformers.losses import HyperbolicEntailmentConeStaticLoss, PoincareEmbeddingStaticLoss 28 | from hierarchy_transformers.models import PoincareStaticEmbedding, PoincareStaticEmbeddingTrainer 29 | from hierarchy_transformers.utils import get_torch_device 30 | 31 | logger = logging.getLogger(__name__) 32 | logger.setLevel(logging.INFO) 33 | 34 | 35 | @click.command() 36 | @click.option("-c", "--config_file", type=click.Path(exists=True)) 37 | @click.option("-g", "--gpu_id", type=int, default=0) 38 | def main(config_file: str, gpu_id: int): 39 | # 0. set seed, load config, and format output dir 40 | set_seed(8888) 41 | config = CfgNode(load_file(config_file)) 42 | dataset_path_suffix = config.dataset_path.split(os.path.sep)[-1] 43 | output_dir = f"experiments/PoincareStatic-{dataset_path_suffix}-{config.negative_type}" 44 | create_path(output_dir) 45 | try: 46 | shutil.copy2(config_file, os.path.join(output_dir, "config.yaml")) 47 | except Exception: 48 | pass 49 | 50 | # 1. Load dataset and pre-trained model 51 | entity_lexicon = load_file(os.path.join(config.dataset_path, "entity_lexicon.json")) 52 | model = PoincareStaticEmbedding(list(entity_lexicon.keys()), embed_dim=config.embed_dim) 53 | print(model) 54 | dataset = load_zenodo_dataset( 55 | path=config.dataset_path, 56 | entity_lexicon_or_index=model.ent2idx, 57 | negative_type=config.negative_type, 58 | example_type="idx", 59 | ) 60 | 61 | # 2. set up the loss function 62 | poincare_embed_loss = PoincareEmbeddingStaticLoss(model.manifold) 63 | 64 | # 3. 
Create the trainer & start training 65 | logger.info("Train Poincare embedding on the hyperbolic distance loss...") 66 | device = get_torch_device(gpu_id) 67 | trainer = PoincareStaticEmbeddingTrainer( 68 | model=model, 69 | train_dataset=dataset["train"], 70 | loss=poincare_embed_loss, 71 | num_train_epochs=int(config.num_train_epochs), 72 | learning_rate=float(config.learning_rate), 73 | train_batch_size=int(config.train_batch_size), 74 | warmup_epochs=int(config.warmup_epochs), 75 | ) 76 | trainer.train(device=device) 77 | torch.save(trainer.model, os.path.join(output_dir, "poincare_static.pt")) 78 | 79 | # 4. Evaluate the model performance on validation and test datasets 80 | create_path(os.path.join(output_dir, "eval_poincare")) 81 | val_evaluator = PoincareStaticEmbeddingEvaluator( 82 | eval_examples=dataset["val"], batch_size=config.eval_batch_size, truth_label=1 83 | ) 84 | val_evaluator( 85 | model=trainer.model, 86 | loss=trainer.loss, 87 | device=device, 88 | epoch="validation", 89 | output_path=os.path.join(output_dir, "eval_poincare"), 90 | ) 91 | val_results = val_evaluator.results 92 | best_val = val_results.loc[val_results["f1"].idxmax()] 93 | best_val_threshold = float(best_val["threshold"]) 94 | test_evaluator = PoincareStaticEmbeddingEvaluator( 95 | eval_examples=dataset["test"], batch_size=config.eval_batch_size, truth_label=1 96 | ) 97 | test_evaluator( 98 | model=trainer.model, 99 | loss=trainer.loss, 100 | device=device, 101 | output_path=os.path.join(output_dir, "eval_poincare"), 102 | best_threshold=best_val_threshold, 103 | ) 104 | 105 | # 5. 
Create the trainer & start post-training 106 | if int(config.num_post_train_epochs) > 0: 107 | logger.info("Post-train Poincare embedding on the hyperbolic entailment cone loss...") 108 | # set-up the cone loss for post-training 109 | hyperbolic_cone_loss = HyperbolicEntailmentConeStaticLoss(model.manifold) 110 | post_trainer = PoincareStaticEmbeddingTrainer( 111 | model=trainer.model, # continue to train 112 | train_dataset=dataset["train"], 113 | loss=hyperbolic_cone_loss, 114 | num_train_epochs=int(config.num_post_train_epochs), 115 | learning_rate=float(config.learning_rate), 116 | train_batch_size=int(config.train_batch_size), 117 | warmup_epochs=int(config.warmup_epochs), 118 | ) 119 | post_trainer.train(device=device) 120 | torch.save(post_trainer.model, os.path.join(output_dir, "hypercone_static.pt")) 121 | 122 | # 6. Evaluate the post-trained model performance on validation and test datasets 123 | create_path(os.path.join(output_dir, "eval_hypercone")) 124 | val_evaluator = PoincareStaticEmbeddingEvaluator( 125 | eval_examples=dataset["val"], batch_size=config.eval_batch_size, truth_label=1 126 | ) 127 | val_evaluator( 128 | model=post_trainer.model, 129 | loss=post_trainer.loss, 130 | device=device, 131 | epoch="validation", 132 | output_path=os.path.join(output_dir, "eval_hypercone"), 133 | ) 134 | val_results = val_evaluator.results 135 | best_val = val_results.loc[val_results["f1"].idxmax()] 136 | best_val_threshold = float(best_val["threshold"]) 137 | test_evaluator = PoincareStaticEmbeddingEvaluator( 138 | eval_examples=dataset["test"], batch_size=config.eval_batch_size, truth_label=1 139 | ) 140 | test_evaluator( 141 | model=post_trainer.model, 142 | loss=post_trainer.loss, 143 | device=device, 144 | output_path=os.path.join(output_dir, "eval_hypercone"), 145 | best_threshold=best_val_threshold, 146 | ) 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- 
/src/hierarchy_transformers/losses/hyper_cone_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from geoopt.manifolds import PoincareBall 19 | 20 | from hierarchy_transformers.utils import format_citation 21 | 22 | 23 | class HyperbolicEntailmentConeLoss(torch.nn.Module): 24 | """Hyperbolic loss that construct entailment cones for entities. 25 | 26 | Essentially, this loss is expected to achieve: 27 | $$ 28 | angle(child, parent_cone_axis) < angle(parent_cone) 29 | $$ 30 | 31 | Inputs are labelled pairs presented in `(rep_anchor, rep_other, label)`. 32 | """ 33 | 34 | def __init__(self, manifold: PoincareBall, min_euclidean_norm: float = 0.1, margin: float = 0.1, eps: float = 1e-5): 35 | super().__init__() 36 | self.manifold = manifold 37 | assert self.manifold.c == 1.0, f"Entailment cone loss is not defined for curvature: {manifold.c}." 
38 | self.min_euclidean_norm = min_euclidean_norm 39 | self.margin = margin 40 | self.eps = eps 41 | 42 | def get_config_dict(self): 43 | config = {"distance_metric": "PoincareBall(c=1.0).cone_angle", "margin": self.margin} 44 | return config 45 | 46 | def half_cone_aperture(self, cone_tip: torch.Tensor): 47 | """Angle between the axis [0, x] (line through 0 and x) and the boundary of the cone at x, 48 | where x is the cone tip. 49 | """ 50 | # cone tip means the point x is the tip of the hyperbolic cone 51 | # norm_tip = cone_tip.norm(dim=-1).clamp(min=self.min_euclidean_norm) # to prevent undefined aperture 52 | sq_norm_tip = cone_tip.pow(2).sum(dim=-1).clamp(min=self.min_euclidean_norm + self.eps, max=1 - self.eps) 53 | return torch.arcsin(self.min_euclidean_norm * (1 - sq_norm_tip) / torch.sqrt(sq_norm_tip)).clamp( 54 | min=-1 + self.eps, max=1 - self.eps 55 | ) 56 | 57 | def cone_angle_at_u(self, cone_tip: torch.Tensor, u: torch.Tensor): 58 | """Angle between the axis [0, x] and the line [x, u]. This angle should be smaller than the 59 | half cone aperture at x for real children. 
60 | """ 61 | # parent point is treated as the cone tip 62 | norm_tip = cone_tip.norm(2, dim=-1) 63 | norm_child = u.norm(2, dim=-1) 64 | dot_prod = (cone_tip * u).sum(dim=-1) 65 | edist = (cone_tip - u).norm(2, dim=-1) # euclidean distance 66 | numerator = dot_prod * (1 + norm_tip**2) - norm_tip**2 * (1 + norm_child**2) 67 | denominator = norm_tip * edist * torch.sqrt(1 + (norm_child**2) * (norm_tip**2) - 2 * dot_prod) 68 | 69 | angle = torch.arccos((numerator / denominator.clamp(min=self.eps)).clamp(min=-1 + self.eps, max=1 - self.eps)) 70 | # Debugging step 71 | if torch.isnan(angle).any(): 72 | print("Numerator:", numerator) 73 | print("Denominator:", denominator) 74 | print("Angle calculation resulted in NaNs") 75 | 76 | return angle 77 | 78 | def energy(self, cone_tip: torch.Tensor, u: torch.Tensor): 79 | """Enery function defined as: max(0, cone_angle(u) - half_cone_aperture) given a cone tip.""" 80 | return F.relu(self.cone_angle_at_u(cone_tip, u) - self.half_cone_aperture(cone_tip)) 81 | 82 | def forward(self, rep_anchor: torch.Tensor, rep_other: torch.Tensor, labels: torch.Tensor): 83 | # anchors are children 84 | energies = self.energy(cone_tip=rep_other, u=rep_anchor) 85 | cone_loss = labels.float() * energies + (1 - labels).float() * F.relu(self.margin - energies) 86 | return cone_loss.mean() 87 | 88 | @property 89 | def citation(self) -> str: 90 | return format_citation( 91 | """ 92 | @inproceedings{ganea2018hyperbolic, 93 | title={Hyperbolic entailment cones for learning hierarchical embeddings}, 94 | author={Ganea, Octavian and B{\'e}cigneul, Gary and Hofmann, Thomas}, 95 | booktitle={International conference on machine learning}, 96 | pages={1646--1655}, 97 | year={2018}, 98 | organization={PMLR} 99 | } 100 | """ 101 | ) 102 | 103 | 104 | class HyperbolicEntailmentConeTripletLoss(HyperbolicEntailmentConeLoss): 105 | """Hyperbolic loss that construct entailment cones for entities. 
106 | 107 | Essentially, this loss is expected to achieve: 108 | $$ 109 | angle(child, parent_cone_axis) < angle(parent_cone) 110 | $$ 111 | 112 | Inputs are triplets presented in `(rep_anchor, rep_positive, rep_negative)`. 113 | """ 114 | 115 | def __init__(self, manifold: PoincareBall, min_euclidean_norm: float = 0.1, margin: float = 0.1, eps: float = 1e-5): 116 | super().__init__(manifold, min_euclidean_norm, margin, eps) 117 | 118 | def forward(self, rep_anchor: torch.Tensor, rep_positive: torch.Tensor, rep_negative: torch.Tensor): 119 | # anchors are children 120 | energies_positive = self.energy(cone_tip=rep_positive, u=rep_anchor) 121 | energies_negative = self.energy(cone_tip=rep_negative, u=rep_anchor) 122 | cone_triplet_loss = F.relu(energies_positive - energies_negative + self.margin) 123 | return cone_triplet_loss.mean() 124 | 125 | 126 | class HyperbolicEntailmentConeStaticLoss(HyperbolicEntailmentConeLoss): 127 | """Hyperbolic loss that construct entailment cones for entities. 128 | 129 | Essentially, this loss is expected to achieve: 130 | $$ 131 | angle(child, parent_cone_axis) < angle(parent_cone) 132 | $$ 133 | 134 | Inputs are presented in `(subject, *objects)` where the first `object` is positive and the rest are negative. 135 | 136 | This is designed for the static embedding implementation. 
137 | """ 138 | 139 | def __init__(self, manifold: PoincareBall, min_euclidean_norm: float = 0.1, margin: float = 0.1, eps: float = 1e-5): 140 | super().__init__(manifold, min_euclidean_norm, margin, eps) 141 | 142 | def forward(self, subject: torch.Tensor, objects: torch.Tensor): 143 | # the first object is positive 144 | energy = self.energy(objects, subject) 145 | return (energy[:, 0].sum() + F.relu(self.margin - energy[:, 1:]).sum()) / torch.numel(energy) 146 | 147 | -------------------------------------------------------------------------------- /src/hierarchy_transformers/evaluation/sbert_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | import logging 17 | import os.path 18 | import warnings 19 | from string import Template 20 | 21 | import pandas as pd 22 | import torch 23 | from sentence_transformers import SentenceTransformer 24 | from sentence_transformers.evaluation import SentenceEvaluator 25 | 26 | from .metrics import evaluate_by_threshold, grid_search 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class SentenceTransformerEvaluator(SentenceEvaluator): 32 | """Evaluating sBERT models for predicting entity hierarchical relationships. 
33 | 34 | The main evaluation metrics are `Precision`, `Recall`, and `F-score`, with overall accuracy (`ACC`) and accuracy on negatives (`ACC-`) additionally reported. The results are written in a `.csv`. If a result file already exists, then values are appended. 35 | 36 | The labels need to be `0` for unrelated pairs and `1` for related pairs. 37 | 38 | Args: 39 | child_entities (list[str]): List of child entity names. 40 | parent_entities (list[str]): List of parent entity names. 41 | labels (list[int]): List of reference labels. 42 | batch_size (int): Evaluation batch size. 43 | truth_label (int, optional): Specify which label represents the truth. Defaults to `1`. 44 | template (str, optional): The probing template. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | child_entities: list[str], 50 | parent_entities: list[str], 51 | labels: list[int], 52 | batch_size: int, 53 | truth_label: int = 1, 54 | template: str = "${child} is a ${parent}.", 55 | ): 56 | super().__init__() 57 | # set primary metric for model selection 58 | self.primary_metric = "f1" 59 | # input evaluation examples 60 | self.child_entities = child_entities 61 | self.parent_entities = parent_entities 62 | self.labels = labels 63 | # eval batch size 64 | self.batch_size = batch_size 65 | # truth reference label 66 | self.truth_label = truth_label 67 | # result file 68 | self.results = pd.DataFrame( 69 | columns=["threshold", "precision", "recall", "f1", "accuracy", "accuracy_on_negatives"] 70 | ) 71 | # template for probing 72 | self.template = Template(template) 73 | 74 | def inference(self, model: SentenceTransformer): 75 | """The probing method of the pre-trained sBERT model. 
It output scores that indicate hierarchical relationships between entities.""" 76 | sentences = [] 77 | masked_sentences = [] 78 | for child, parent in zip(self.child_entities, self.parent_entities): 79 | sentences.append(self.template.substitute(child=child, parent=parent)) 80 | masked_sentences.append(self.template.substitute(child=child, parent=model.tokenizer.mask_token)) 81 | 82 | sentence_embeds = model.encode(sentences=sentences, convert_to_tensor=True, show_progress_bar=True) 83 | masked_embeds = model.encode(sentences=masked_sentences, convert_to_tensor=True, show_progress_bar=True) 84 | 85 | # use the cosine similarity between masked and 86 | return torch.cosine_similarity(masked_embeds, sentence_embeds) 87 | 88 | def __call__( 89 | self, 90 | model: SentenceTransformer, 91 | output_path: str | None = None, 92 | epoch: int = -1, 93 | steps: int = -1, 94 | best_threshold: float | None = None, 95 | ) -> dict[str, float]: 96 | """Compute the evaluation metrics for the given model. 97 | 98 | Args: 99 | model (HierarchyTransformer): The model to evaluate. 100 | output_path (str, optional): Path to save the evaluation results `.csv` file. Defaults to `None`. 101 | epoch (int, optional): The epoch number. Defaults to `-1`. 102 | steps (int, optional): The number of steps. Defaults to `-1`. 103 | best_centri_weight (float, optional): The best centripetal score weight searched on a validation set (used for testing). Defaults to `None`. 104 | best_threshold (float, optional): The best overall threshold searched on a validation set (used for testing). Defaults to `None`. 105 | 106 | Returns: 107 | Dict[str, float]: A dictionary containing the evaluation metrics. 
108 | """ 109 | 110 | if best_threshold: 111 | # Testing with pre-defined hyperparameters 112 | logger.info(f"Evaluate on given hyperparemeters `best_threshold={best_threshold}`.") 113 | 114 | # Compute the scores 115 | scores = self.inference(model=model) 116 | 117 | # Compute the evaluation metrics 118 | best_results = evaluate_by_threshold( 119 | scores=scores, 120 | labels=torch.tensor(self.labels).to(scores.device), 121 | threshold=best_threshold, 122 | truth_label=self.truth_label, 123 | smaller_scores_better=False, 124 | ) 125 | 126 | # log the results 127 | if os.path.exists(os.path.join(output_path, "results.tsv")): 128 | self.results = pd.read_csv(os.path.join(output_path, "results.tsv"), sep="\t", index_col=0) 129 | else: 130 | warnings.warn("No previous `results.tsv` detected.") 131 | self.results.loc["testing"] = best_results 132 | else: 133 | # Validation with no pre-defined hyerparameters 134 | logger.info("Evaluate with grid search on hyperparameters `best_threshold` (overall threshold)") 135 | best_f1 = -1.0 136 | best_results = None 137 | 138 | # Compute the scores 139 | scores = self.inference(model=model) 140 | 141 | # Perform grid search on hyperparameters 142 | cur_best_results = grid_search( 143 | scores=scores, 144 | labels=torch.tensor(self.labels).to(scores.device), 145 | threshold_granularity=1000, 146 | truth_label=self.truth_label, 147 | smaller_scores_better=False, 148 | primary_metric="f1", 149 | best_primary_metric_value=best_f1, 150 | preformatted_best_results={}, 151 | ) 152 | if cur_best_results: 153 | best_results = cur_best_results 154 | best_f1 = best_results["f1"] 155 | 156 | idx = f"epoch={epoch}" if epoch != "validation" else epoch 157 | self.results.loc[idx] = best_results 158 | 159 | self.results.to_csv(os.path.join(output_path, "results.tsv"), sep="\t") 160 | 161 | logger.info(f"Eval results: {best_results}") 162 | 163 | return best_results 164 | 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 16 | 17 | 18 |

19 | 20 | license 21 | 22 |

23 | 24 |

25 | 26 | license 27 | 28 | 29 | docs 30 | 31 | 32 | pypi 33 | 34 |

35 | 36 |

37 |

38 | Project | 39 | HuggingFace | 40 | arXiv | 41 | Zenodo 42 |

43 |

44 | 45 |

46 |

Embedding hierarchies with language models.

47 |

48 | 49 | **News** ([changelog](docs/changelog.md)) :newspaper: 50 | 51 | - [X] Refactor code and add customised HiT trainer (**v0.1.1**). 52 | - [X] Significant development to align with `sentence-transformers>=3.4.0.dev0` (**v0.1.0**). 53 | - [X] Project page is now available ([click](https://krr-oxford.github.io/HierarchyTransformers/)). 54 | - [X] Initial release (should work with `sentence-transformers<3.0.0`) and bug fix. (**v0.0.3**) 55 | 56 | 57 | ## About 58 | 59 | Hierarchy Transformer (HiT) is a framework that enables transformer encoder-based language models (LMs) to learn hierarchical structures in hyperbolic space. The main idea is to construct a Poincaré ball that directly circumscribes the output embedding space of LMs, leveraging the exponential expansion of hyperbolic space to organise entity embeddings hierarchically. In addition to presenting this framework (see code on [GitHub](https://github.com/KRR-Oxford/HierarchyTransformers)), we are committed to training and releasing HiT models across various hierarchies. The models and datasets will be accessible on [HuggingFace](https://huggingface.co/Hierarchy-Transformers/). 60 | 61 | ## Installation 62 | 63 | ### Main Dependencies 64 | 65 | This repository follows a similar layout to the [`sentence-transformers`](https://www.sbert.net/index.html) library. The main model directly extends the sentence transformer architecture. We also utilise [`deeponto`](https://krr-oxford.github.io/DeepOnto/) for extracting hierarchies from source data and constructing datasets from hierarchies, and [`geoopt`](https://geoopt.readthedocs.io/en/latest/index.html) for arithmetic in hyperbolic space. 66 | 67 | > The current release of `sentence-transformers=3.3.1` contains bugs during evaluation, which were fixed in their GitHub dev version `sentence-transformers=3.4.0.dev0`; please update the dependency manually until the official `3.4.0` is released. 
68 | 69 | ### Install from PyPI 70 | 71 | ```bash 72 | # requiring Python>=3.9 73 | pip install hierarchy_transformers 74 | ``` 75 | 76 | ### Install from GitHub 77 | 78 | ```bash 79 | pip install git+https://github.com/KRR-Oxford/HierarchyTransformers.git 80 | ``` 81 | 82 | ## Huggingface Hub 83 | 84 | Our HiT models and datasets are released on the [HuggingFace Hub](https://huggingface.co/Hierarchy-Transformers). 85 | 86 | ### Get Started 87 | 88 | ```python 89 | from hierarchy_transformers import HierarchyTransformer 90 | 91 | # load the model 92 | model = HierarchyTransformer.from_pretrained('Hierarchy-Transformers/HiT-MiniLM-L12-WordNetNoun') 93 | 94 | # entity names to be encoded. 95 | entity_names = ["computer", "personal computer", "fruit", "berry"] 96 | 97 | # get the entity embeddings 98 | entity_embeddings = model.encode(entity_names) 99 | ``` 100 | 101 | ### Default Probing for Subsumption Prediction 102 | 103 | Use the entity embeddings to predict the subsumption relationships between them. 
104 | 105 | ```python 106 | # suppose we want to compare "personal computer" and "computer", "berry" and "fruit" 107 | child_entity_embeddings = model.encode(["personal computer", "berry"], convert_to_tensor=True) 108 | parent_entity_embeddings = model.encode(["computer", "fruit"], convert_to_tensor=True) 109 | 110 | # compute the hyperbolic distances and norms of entity embeddings 111 | dists = model.manifold.dist(child_entity_embeddings, parent_entity_embeddings) 112 | child_norms = model.manifold.dist0(child_entity_embeddings) 113 | parent_norms = model.manifold.dist0(parent_entity_embeddings) 114 | 115 | # use the empirical function for subsumption prediction proposed in the paper 116 | # `centri_score_weight` and the overall threshold are determined on the validation set 117 | subsumption_scores = - (dists + centri_score_weight * (parent_norms - child_norms)) 118 | ``` 119 | 120 | ### Train Your Own Models 121 | 122 | Use the example scripts in our [repository](https://github.com/KRR-Oxford/HierarchyTransformers/tree/main/scripts) to reproduce existing models and train/evaluate your own models. 123 | 124 | ## License 125 | 126 | 127 | 128 | Copyright 2023 Yuan He. 129 | All rights reserved. 130 | 131 | Licensed under the Apache License, Version 2.0 (the "License"); 132 | you may not use this file except in compliance with the License. 133 | You may obtain a copy of the License at *<http://www.apache.org/licenses/LICENSE-2.0>* 134 | 135 | Unless required by applicable law or agreed to in writing, software 136 | distributed under the License is distributed on an "AS IS" BASIS, 137 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 138 | See the License for the specific language governing permissions and 139 | limitations under the License. 
140 | 141 | ## Citation 142 | 143 | If you find this repository or the released models useful, please cite our publication: 144 | 145 | *Yuan He, Zhangdie Yuan, Jiaoyan Chen, Ian Horrocks.* **Language Models as Hierarchy Encoders.** Advances in Neural Information Processing Systems 37 (NeurIPS 2024). /[arxiv](https://arxiv.org/abs/2401.11374)/ /[neurips](https://neurips.cc/virtual/2024/poster/95913)/ 146 | 147 | ``` 148 | @inproceedings{NEURIPS2024_1a970a3e, 149 | author = {He, Yuan and Yuan, Moy and Chen, Jiaoyan and Horrocks, Ian}, 150 | booktitle = {Advances in Neural Information Processing Systems}, 151 | editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C. Zhang}, 152 | pages = {14690--14711}, 153 | publisher = {Curran Associates, Inc.}, 154 | title = {Language Models as Hierarchy Encoders}, 155 | url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/1a970a3e62ac31c76ec3cea3a9f68fdf-Paper-Conference.pdf}, 156 | volume = {37}, 157 | year = {2024} 158 | } 159 | ``` 160 | -------------------------------------------------------------------------------- /src/hierarchy_transformers/losses/hit_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import annotations 15 | 16 | import logging 17 | from collections.abc import Iterable 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | from geoopt.manifolds import PoincareBall 22 | 23 | from hierarchy_transformers.models import HierarchyTransformer 24 | from hierarchy_transformers.utils import format_citation 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class HierarchyTransformerLoss(torch.nn.Module): 30 | """Hyperbolic loss that linearly combines hperbolic clustering loss and hyperbolic Centripetal loss and applies weights for joint optimisation.""" 31 | 32 | def __init__( 33 | self, 34 | model: HierarchyTransformer, 35 | clustering_loss_weight: float = 1.0, 36 | clustering_loss_margin: float = 5.0, 37 | centripetal_loss_weight: float = 1.0, 38 | centripetal_loss_margin: float = 0.5, 39 | ): 40 | super().__init__() 41 | 42 | self.model = model 43 | self.manifold = self.model.manifold 44 | self.cluster_loss = HyperbolicClusteringLoss(self.model.manifold, clustering_loss_margin) 45 | self.centri_loss = HyperbolicCentripetalLoss(self.model.manifold, centripetal_loss_margin) 46 | self.cluster_weight = clustering_loss_weight 47 | self.centri_weight = centripetal_loss_weight 48 | 49 | def get_config_dict(self): 50 | # distance_metric_name = self.distance_metric.__name__ 51 | config = {"distance_metric": f"PoincareBall(c={self.manifold.c}).dist and dist0"} 52 | config[HyperbolicClusteringLoss.__name__] = { 53 | "weight": self.cluster_weight, 54 | **self.cluster_loss.get_config_dict(), 55 | } 56 | config[HyperbolicCentripetalLoss.__name__] = { 57 | "weight": self.centri_weight, 58 | **self.centri_loss.get_config_dict(), 59 | } 60 | return config 61 | 62 | def forward(self, sentence_features: Iterable[dict[str, torch.Tensor]], labels: torch.Tensor): 63 | """Forward propagation that follows [`sentence_transformers.losses`](https://github.com/UKPLab/sentence-transformers/tree/master/sentence_transformers/losses).""" 64 | 
reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features] 65 | assert len(reps) == 3 66 | rep_anchor, rep_positive, rep_negative = reps 67 | 68 | # compute and combine hyperbolic clustering and centripetal losses 69 | cluster_loss = self.cluster_loss(rep_anchor, rep_positive, rep_negative) 70 | centri_loss = self.centri_loss(rep_anchor, rep_positive, rep_negative) 71 | combined_loss = self.cluster_weight * cluster_loss + self.centri_weight * centri_loss 72 | 73 | return { 74 | "loss": combined_loss, 75 | "cluster_loss": cluster_loss, 76 | "centri_loss": centri_loss, 77 | } 78 | 79 | @property 80 | def citation(self) -> str: 81 | return format_citation( 82 | """ 83 | @article{he2024language, 84 | title={Language models as hierarchy encoders}, 85 | author={He, Yuan and Yuan, Zhangdie and Chen, Jiaoyan and Horrocks, Ian}, 86 | journal={arXiv preprint arXiv:2401.11374}, 87 | year={2024} 88 | } 89 | """ 90 | ) 91 | 92 | 93 | class HyperbolicClusteringLoss(torch.nn.Module): 94 | r"""Hyperbolic loss that clusters entities in subsumptions. 95 | 96 | Essentially, this loss is expected to achieve: 97 | 98 | $$d(child, parent) < d(child, negative)$$ 99 | 100 | Inputs are presented in `(rep_anchor, rep_positive, rep_negative)`. 101 | """ 102 | 103 | def __init__(self, manifold: PoincareBall, margin: float): 104 | super().__init__() 105 | self.manifold = manifold 106 | self.margin = margin 107 | 108 | def get_config_dict(self): 109 | config = { 110 | "distance_metric": f"PoincareBall(c={self.manifold.c}).dist", 111 | "margin": self.margin, 112 | } 113 | return config 114 | 115 | def forward(self, rep_anchor: torch.Tensor, rep_positive: torch.Tensor, rep_negative: torch.Tensor): 116 | """Forward propagation. 117 | 118 | Args: 119 | rep_anchor (torch.Tensor): The input tensor for child entities. 120 | rep_positive (torch.Tensor): The input tensor for parent entities. 
121 | rep_negative (torch.Tensor): The input tensor for negative parent entities. 122 | """ 123 | distances_positive = self.manifold.dist(rep_anchor, rep_positive) 124 | distances_negative = self.manifold.dist(rep_anchor, rep_negative) 125 | cluster_triplet_loss = F.relu(distances_positive - distances_negative + self.margin) 126 | return cluster_triplet_loss.mean() 127 | 128 | @property 129 | def citation(self) -> str: 130 | return format_citation( 131 | """ 132 | @article{he2024language, 133 | title={Language models as hierarchy encoders}, 134 | author={He, Yuan and Yuan, Zhangdie and Chen, Jiaoyan and Horrocks, Ian}, 135 | journal={arXiv preprint arXiv:2401.11374}, 136 | year={2024} 137 | } 138 | """ 139 | ) 140 | 141 | 142 | class HyperbolicCentripetalLoss(torch.nn.Module): 143 | r"""Hyperbolic loss that regulates the norms of child and parent entities. 144 | 145 | Essentially, this loss is expected to achieve: 146 | 147 | $$d(child, origin) > d(parent, origin)$$ 148 | 149 | Inputs are presented in `(rep_anchor, rep_positive, rep_negative)` but only `(rep_anchor, rep_positive)` pairs are involved in this loss. 150 | """ 151 | 152 | def __init__(self, manifold: PoincareBall, margin: float): 153 | super().__init__() 154 | self.manifold = manifold 155 | self.margin = margin 156 | 157 | def get_config_dict(self): 158 | config = { 159 | "distance_metric": f"PoincareBall(c={self.manifold.c}).dist0", 160 | "margin": self.margin, 161 | } 162 | return config 163 | 164 | def forward(self, rep_anchor: torch.Tensor, rep_positive: torch.Tensor, rep_negative: torch.Tensor): 165 | """Forward propagation. 166 | 167 | Args: 168 | rep_anchor (torch.Tensor): The input tensor for child entities. 169 | rep_positive (torch.Tensor): The input tensor for parent entities. 170 | rep_negative (torch.Tensor): The input tensor for negative parent entities (actually not required in this loss). 
171 | """ 172 | rep_anchor_hyper_norms = self.manifold.dist0(rep_anchor) 173 | rep_positive_hyper_norms = self.manifold.dist0(rep_positive) 174 | # child further than parent w.r.t. origin 175 | centri_triplet_loss = F.relu(self.margin + rep_positive_hyper_norms - rep_anchor_hyper_norms) 176 | return centri_triplet_loss.mean() 177 | 178 | @property 179 | def citation(self) -> str: 180 | return format_citation( 181 | """ 182 | @article{he2024language, 183 | title={Language models as hierarchy encoders}, 184 | author={He, Yuan and Yuan, Zhangdie and Chen, Jiaoyan and Horrocks, Ian}, 185 | journal={arXiv preprint arXiv:2401.11374}, 186 | year={2024} 187 | } 188 | """ 189 | ) 190 | -------------------------------------------------------------------------------- /src/hierarchy_transformers/evaluation/static_embed_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from __future__ import annotations 15 | 16 | import logging 17 | import os.path 18 | import warnings 19 | 20 | import pandas as pd 21 | import torch 22 | from torch.utils.data import DataLoader 23 | 24 | from hierarchy_transformers.losses import HyperbolicEntailmentConeStaticLoss, PoincareEmbeddingStaticLoss 25 | from hierarchy_transformers.models import PoincareStaticEmbedding 26 | 27 | from .metrics import evaluate_by_threshold, grid_search 28 | 29 | logger = logging.getLogger(__name__) 30 | 31 | 32 | class PoincareStaticEmbeddingEvaluator: 33 | """Evaluating hyperbolic static embedding models ([1] and [2]) for predicting entity hierarchical relationships. 34 | 35 | - [1] Poincaré Embedding by [Nickel et al., NeurIPS 2017](https://arxiv.org/abs/1705.08039). 36 | - [2] Hyperbolic Entailment Cone by [Ganea et al., ICML 2018](https://arxiv.org/abs/1804.01882). 37 | 38 | both of which lie in a unit Poincaré ball. According to [2], it is better to apply the entailment cone loss in the post-training phase of a Poincaré embedding model in [1]. 39 | 40 | The main evaluation metrics are `Precision`, `Recall`, and `F-score`, with overall accuracy (`ACC`) and accuracy on negatives (`ACC-`) additionally reported. The results are written in a `.csv`. If a result file already exists, then values are appended. 41 | 42 | The labels need to be `0` for unrelated pairs and `1` for related pairs. 43 | 44 | Args: 45 | examples (list[int]): List of input examples containing entity IDs. Each example is formatted as `[child_id, parent_id, *negative_ids]`. 46 | batch_size (int): Evaluation batch size. 47 | truth_label (int, optional): Specify which label represents the truth. Defaults to `1`. 
48 | """ 49 | 50 | def __init__(self, eval_examples: list, batch_size: int, truth_label: int = 1): 51 | self.examples = eval_examples 52 | self.batch_size = batch_size 53 | self.truth_label = truth_label 54 | # result file 55 | self.results = pd.DataFrame( 56 | columns=["threshold", "precision", "recall", "f1", "accuracy", "accuracy_on_negatives"] 57 | ) 58 | 59 | def inference( 60 | self, 61 | model: PoincareStaticEmbedding, 62 | loss: PoincareEmbeddingStaticLoss | HyperbolicEntailmentConeStaticLoss, 63 | device: torch.device, 64 | ): 65 | """The probing method of the pre-trained hyperbolic static embedding models. It output scores that indicate hierarchical relationships between entities.""" 66 | 67 | # set up scoring function according to input loss 68 | # NOTE: both scores are smaller the better 69 | if isinstance(loss, PoincareEmbeddingStaticLoss): 70 | # distance scoring from [Nickel et al., NeurIPS 2017] 71 | def score_func(subject: torch.Tensor, objects: torch.Tensor, norm_score_weight: float = 1000.0): 72 | dists = loss.manifold.dist(subject, objects) 73 | subject_norms = subject.norm(dim=-1) 74 | objects_norms = objects.norm(dim=-1) 75 | return (1 + norm_score_weight * (objects_norms - subject_norms)) * dists 76 | 77 | elif isinstance(loss, HyperbolicEntailmentConeStaticLoss): 78 | # hyperbolic entailment cone scoring from [Ganea et al., ICML 2018] 79 | score_func = lambda subject, objects: loss.energy(objects, subject) 80 | else: 81 | raise ValueError(f"Unknown loss function type: {type(loss)}.") 82 | 83 | # set model to eval mode 84 | model.eval() 85 | 86 | # make predictions 87 | dataloader = DataLoader(torch.tensor(self.examples).to(device), shuffle=False, batch_size=self.batch_size) 88 | num_negatives = len(self.examples[0]) - 2 # each example is formatted as [child, parent, *negatives] 89 | scores = [] 90 | labels = [] 91 | with torch.no_grad(): 92 | for batch in dataloader: 93 | subject, objects = model(batch) 94 | cur_scores = score_func(subject, 
objects) 95 | scores.append(cur_scores.reshape((-1,))) 96 | scores = torch.concat(scores, dim=0) 97 | labels = torch.tensor( 98 | ([self.truth_label] + [1 - self.truth_label] * num_negatives) * (int(len(scores) / (1 + num_negatives))) 99 | ).to(scores.device) 100 | assert len(labels) == len(scores) 101 | 102 | return scores, labels 103 | 104 | def __call__( 105 | self, 106 | model: PoincareStaticEmbedding, 107 | loss: PoincareEmbeddingStaticLoss | HyperbolicEntailmentConeStaticLoss, 108 | device: torch.device, 109 | output_path: str | None = None, 110 | epoch: int = -1, 111 | steps: int = -1, 112 | best_threshold: float | None = None, 113 | ): 114 | """Compute the evaluation metrics for the given model. 115 | 116 | Args: 117 | model (HierarchyTransformer): The model to evaluate. 118 | loss (Union[PoincareEmbeddingStaticLoss, HyperbolicEntailmentConeStaticLoss]): The training loss function decides which scoring function to be used. 119 | device (torch.device): The torch device used for evaluation. 120 | output_path (str, optional): Path to save the evaluation results `.csv` file. Defaults to `None`. 121 | epoch (int, optional): The epoch number. Defaults to `-1`. 122 | steps (int, optional): The number of steps. Defaults to `-1`. 123 | best_threshold (float, optional): The best overall threshold searched on a validation set (used for testing). Defaults to `None`. 124 | 125 | Returns: 126 | Dict[str, float]: A dictionary containing the evaluation metrics. 
127 | """ 128 | 129 | if best_threshold: 130 | # Testing with pre-defined hyperparameters 131 | logger.info(f"Evaluate on given hyperparemeters `best_threshold={best_threshold}`.") 132 | 133 | # Compute the scores 134 | scores, labels = self.inference(model=model, loss=loss, device=device) 135 | 136 | # Compute the evaluation metrics 137 | best_results = evaluate_by_threshold( 138 | scores=scores, 139 | labels=labels, 140 | threshold=best_threshold, 141 | truth_label=self.truth_label, 142 | smaller_scores_better=True, 143 | ) 144 | 145 | # log the results 146 | if os.path.exists(os.path.join(output_path, "results.tsv")): 147 | self.results = pd.read_csv(os.path.join(output_path, "results.tsv"), sep="\t", index_col=0) 148 | else: 149 | warnings.warn("No previous `results.tsv` detected.") 150 | self.results.loc["testing"] = best_results 151 | else: 152 | # Validation with no pre-defined hyerparameters 153 | logger.info("Evaluate with grid search on hyperparameters `best_threshold` (overall threshold)") 154 | best_f1 = -1.0 155 | best_results = None 156 | 157 | # Compute the scores 158 | scores, labels = self.inference(model=model, loss=loss, device=device) 159 | 160 | # Perform grid search on hyperparameters 161 | cur_best_results = grid_search( 162 | scores=scores, 163 | labels=labels, 164 | threshold_granularity=1 if isinstance(loss, PoincareEmbeddingStaticLoss) else 100, 165 | truth_label=self.truth_label, 166 | smaller_scores_better=True, 167 | primary_metric="f1", 168 | best_primary_metric_value=best_f1, 169 | preformatted_best_results={}, 170 | ) 171 | if cur_best_results: 172 | best_results = cur_best_results 173 | best_f1 = best_results["f1"] 174 | 175 | idx = f"epoch={epoch}" if epoch != "validation" else epoch 176 | self.results.loc[idx] = best_results 177 | 178 | self.results.to_csv(os.path.join(output_path, "results.tsv"), sep="\t") 179 | 180 | logger.info(f"Eval results: {best_results}") 181 | 182 | return best_results 183 | 
-------------------------------------------------------------------------------- /src/hierarchy_transformers/evaluation/hit_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Yuan He 2 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from __future__ import annotations 15 | 16 | import logging 17 | import os.path 18 | import warnings 19 | 20 | import pandas as pd 21 | import torch 22 | from sentence_transformers.evaluation import SentenceEvaluator 23 | 24 | from hierarchy_transformers import HierarchyTransformer 25 | 26 | from .metrics import evaluate_by_threshold, grid_search 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | class HierarchyTransformerEvaluator(SentenceEvaluator): 32 | """Evaluating HiT models for predicting entity hierarchical relationships. 33 | 34 | The main evaluation metrics are `Precision`, `Recall`, and `F-score`, with overall accuracy (`ACC`) and accuracy on negatives (`ACC-`) additionally reported. The results are written in a `.csv`. If a result file already exists, then values are appended. 35 | 36 | The labels need to be `0` for unrelated pairs and `1` for related pairs. 37 | 38 | Args: 39 | child_entities (list[str]): List of child entity names. 40 | parent_entities (list[str]): List of parent entity names. 41 | labels (list[int]): List of reference labels. 42 | batch_size (int): Evaluation batch size. 
43 | truth_label (int, optional): Specify which label represents the truth. Defaults to `1`. 44 | """ 45 | 46 | def __init__( 47 | self, 48 | child_entities: list[str], 49 | parent_entities: list[str], 50 | labels: list[int], 51 | batch_size: int, 52 | truth_label: int = 1, 53 | ): 54 | super().__init__() 55 | # set primary metric for model selection 56 | self.primary_metric = "f1" 57 | # input evaluation examples 58 | self.child_entities = child_entities 59 | self.parent_entities = parent_entities 60 | self.labels = labels 61 | # eval batch size 62 | self.batch_size = batch_size 63 | # truth reference label 64 | self.truth_label = truth_label 65 | # result file 66 | self.results = pd.DataFrame( 67 | columns=["centri_weight", "threshold", "precision", "recall", "f1", "accuracy", "accuracy_on_negatives"] 68 | ) 69 | # NOTE: static transformation staticmethod to do 70 | 71 | def inference( 72 | self, 73 | model: HierarchyTransformer, 74 | centri_weight: float, 75 | child_embeds: torch.Tensor | None = None, 76 | parent_embeds: torch.Tensor | None = None, 77 | ): 78 | """The default probing method of the HiT model. It output scores that indicate hierarchical relationships between entities. 79 | 80 | Optional `child_embeds` and `parent_embeds` are used to save time from repetitive encoding. 
81 | """ 82 | if child_embeds is None: 83 | logger.info("Encode child entities.") 84 | child_embeds = model.encode( 85 | sentences=self.child_entities, batch_size=self.batch_size, convert_to_tensor=True 86 | ) 87 | if parent_embeds is None: 88 | logger.info("Encode parent entities.") 89 | parent_embeds = model.encode( 90 | sentences=self.parent_entities, batch_size=self.batch_size, convert_to_tensor=True 91 | ) 92 | dists = model.manifold.dist(child_embeds, parent_embeds) 93 | child_norms = model.manifold.dist0(child_embeds) 94 | parent_norms = model.manifold.dist0(parent_embeds) 95 | return -(dists + centri_weight * (parent_norms - child_norms)) 96 | 97 | def __call__( 98 | self, 99 | model: HierarchyTransformer, 100 | output_path: str | None = None, 101 | epoch: int = -1, 102 | steps: int = -1, 103 | best_centri_weight: float | None = None, 104 | best_threshold: float | None = None, 105 | ) -> dict[str, float]: 106 | """Compute the evaluation metrics for the given model. 107 | 108 | Args: 109 | model (HierarchyTransformer): The model to evaluate. 110 | output_path (str, optional): Path to save the evaluation results `.csv` file. Defaults to `None`. 111 | epoch (int, optional): The epoch number. Defaults to `-1`. 112 | steps (int, optional): The number of steps. Defaults to `-1`. 113 | best_centri_weight (float, optional): The best centripetal score weight searched on a validation set (used for testing). Defaults to `None`. 114 | best_threshold (float, optional): The best overall threshold searched on a validation set (used for testing). Defaults to `None`. 115 | 116 | Returns: 117 | Dict[str, float]: A dictionary containing the evaluation metrics. 
118 | """ 119 | 120 | # best thresholds and metric searched on validation sets 121 | assert ( 122 | type(best_centri_weight) is type(best_threshold) 123 | ), "Inconsistent types of hyperparameters 'best_centri_weight' (centripetal score weight) and 'best_threshold' (overall threshold)" 124 | 125 | logger.info("Encode child entities.") 126 | child_embeds = model.encode(sentences=self.child_entities, batch_size=self.batch_size, convert_to_tensor=True) 127 | logger.info("Encode parent entities.") 128 | parent_embeds = model.encode( 129 | sentences=self.parent_entities, batch_size=self.batch_size, convert_to_tensor=True 130 | ) 131 | 132 | if best_centri_weight and best_threshold: 133 | # Testing with pre-defined hyperparameters 134 | logger.info( 135 | f"Evaluate on given hyperparemeters `best_centri_weight={best_centri_weight}` (centripetal score weight) and `best_threshold={best_threshold}` (overall threshold)." 136 | ) 137 | 138 | # Compute the scores 139 | scores = self.inference( 140 | model=model, 141 | centri_weight=best_centri_weight, 142 | child_embeds=child_embeds, 143 | parent_embeds=parent_embeds, 144 | ) 145 | 146 | # Compute the evaluation metrics 147 | best_results = {"centri_weight": best_centri_weight} 148 | best_results.update( 149 | evaluate_by_threshold( 150 | scores=scores, 151 | labels=torch.tensor(self.labels).to(scores.device), 152 | threshold=best_threshold, 153 | truth_label=self.truth_label, 154 | smaller_scores_better=False, 155 | ) 156 | ) 157 | 158 | # log the results 159 | if os.path.exists(os.path.join(output_path, "results.tsv")): 160 | self.results = pd.read_csv(os.path.join(output_path, "results.tsv"), sep="\t", index_col=0) 161 | else: 162 | warnings.warn("No previous `results.tsv` detected.") 163 | self.results.loc["testing"] = best_results 164 | else: 165 | # Validation with no pre-defined hyerparameters 166 | logger.info( 167 | "Evaluate with grid search on hyperparameters `best_centri_weight` (centripetal score weight) and 
`best_threshold` (overall threshold)." 168 | ) 169 | best_f1 = -1.0 170 | best_results = None 171 | is_updated = True 172 | 173 | for centri_weight in range(50): 174 | # early stop if increasing the centri score weight does not help 175 | if not is_updated: 176 | break 177 | is_updated = False 178 | 179 | centri_weight /= 10 180 | 181 | # Compute the scores 182 | scores = self.inference( 183 | model=model, centri_weight=centri_weight, child_embeds=child_embeds, parent_embeds=parent_embeds 184 | ) 185 | 186 | # Perform grid search on hyperparameters 187 | cur_best_results = grid_search( 188 | scores=scores, 189 | labels=torch.tensor(self.labels).to(scores.device), 190 | threshold_granularity=100, 191 | truth_label=self.truth_label, 192 | smaller_scores_better=False, 193 | primary_metric="f1", 194 | best_primary_metric_value=best_f1, 195 | preformatted_best_results={"centri_weight": centri_weight}, 196 | ) 197 | if cur_best_results: 198 | best_results = cur_best_results 199 | best_f1 = best_results["f1"] 200 | is_updated = True 201 | 202 | idx = f"epoch={epoch}" if epoch != "validation" else epoch 203 | self.results.loc[idx] = best_results 204 | 205 | self.results.to_csv(os.path.join(output_path, "results.tsv"), sep="\t") 206 | 207 | logger.info(f"Eval results: {best_results}") 208 | 209 | return best_results 210 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Yuan He 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /docs/assets/images/neurips.svg: -------------------------------------------------------------------------------- 1 | 2 | 16 | 18 | 19 | 21 | image/svg+xml 22 | 24 | 25 | 26 | 27 | 29 | 59 | 64 | 71 | 78 | 79 | NeurIPS 91 | 92 | -------------------------------------------------------------------------------- /docs/static/js/bulma-slider.js: -------------------------------------------------------------------------------- 1 | (function webpackUniversalModuleDefinition(root, factory) { 2 | if(typeof exports === 'object' && typeof module === 'object') 3 | module.exports = factory(); 4 | else if(typeof define === 'function' && define.amd) 5 | define([], factory); 6 | else if(typeof exports === 'object') 7 | exports["bulmaSlider"] = factory(); 8 | else 9 | root["bulmaSlider"] = factory(); 10 | })(typeof self !== 'undefined' ? 
self : this, function() { 11 | return /******/ (function(modules) { // webpackBootstrap 12 | /******/ // The module cache 13 | /******/ var installedModules = {}; 14 | /******/ 15 | /******/ // The require function 16 | /******/ function __webpack_require__(moduleId) { 17 | /******/ 18 | /******/ // Check if module is in cache 19 | /******/ if(installedModules[moduleId]) { 20 | /******/ return installedModules[moduleId].exports; 21 | /******/ } 22 | /******/ // Create a new module (and put it into the cache) 23 | /******/ var module = installedModules[moduleId] = { 24 | /******/ i: moduleId, 25 | /******/ l: false, 26 | /******/ exports: {} 27 | /******/ }; 28 | /******/ 29 | /******/ // Execute the module function 30 | /******/ modules[moduleId].call(module.exports, module, module.exports, __webpack_require__); 31 | /******/ 32 | /******/ // Flag the module as loaded 33 | /******/ module.l = true; 34 | /******/ 35 | /******/ // Return the exports of the module 36 | /******/ return module.exports; 37 | /******/ } 38 | /******/ 39 | /******/ 40 | /******/ // expose the modules object (__webpack_modules__) 41 | /******/ __webpack_require__.m = modules; 42 | /******/ 43 | /******/ // expose the module cache 44 | /******/ __webpack_require__.c = installedModules; 45 | /******/ 46 | /******/ // define getter function for harmony exports 47 | /******/ __webpack_require__.d = function(exports, name, getter) { 48 | /******/ if(!__webpack_require__.o(exports, name)) { 49 | /******/ Object.defineProperty(exports, name, { 50 | /******/ configurable: false, 51 | /******/ enumerable: true, 52 | /******/ get: getter 53 | /******/ }); 54 | /******/ } 55 | /******/ }; 56 | /******/ 57 | /******/ // getDefaultExport function for compatibility with non-harmony modules 58 | /******/ __webpack_require__.n = function(module) { 59 | /******/ var getter = module && module.__esModule ? 
60 | /******/ function getDefault() { return module['default']; } : 61 | /******/ function getModuleExports() { return module; }; 62 | /******/ __webpack_require__.d(getter, 'a', getter); 63 | /******/ return getter; 64 | /******/ }; 65 | /******/ 66 | /******/ // Object.prototype.hasOwnProperty.call 67 | /******/ __webpack_require__.o = function(object, property) { return Object.prototype.hasOwnProperty.call(object, property); }; 68 | /******/ 69 | /******/ // __webpack_public_path__ 70 | /******/ __webpack_require__.p = ""; 71 | /******/ 72 | /******/ // Load entry module and return exports 73 | /******/ return __webpack_require__(__webpack_require__.s = 0); 74 | /******/ }) 75 | /************************************************************************/ 76 | /******/ ([ 77 | /* 0 */ 78 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 79 | 80 | "use strict"; 81 | Object.defineProperty(__webpack_exports__, "__esModule", { value: true }); 82 | /* harmony export (binding) */ __webpack_require__.d(__webpack_exports__, "isString", function() { return isString; }); 83 | /* harmony import */ var __WEBPACK_IMPORTED_MODULE_0__events__ = __webpack_require__(1); 84 | var _extends = Object.assign || function (target) { for (var i = 1; i < arguments.length; i++) { var source = arguments[i]; for (var key in source) { if (Object.prototype.hasOwnProperty.call(source, key)) { target[key] = source[key]; } } } return target; }; 85 | 86 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, 
staticProps); return Constructor; }; }(); 87 | 88 | var _typeof = typeof Symbol === "function" && typeof Symbol.iterator === "symbol" ? function (obj) { return typeof obj; } : function (obj) { return obj && typeof Symbol === "function" && obj.constructor === Symbol && obj !== Symbol.prototype ? "symbol" : typeof obj; }; 89 | 90 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 91 | 92 | function _possibleConstructorReturn(self, call) { if (!self) { throw new ReferenceError("this hasn't been initialised - super() hasn't been called"); } return call && (typeof call === "object" || typeof call === "function") ? call : self; } 93 | 94 | function _inherits(subClass, superClass) { if (typeof superClass !== "function" && superClass !== null) { throw new TypeError("Super expression must either be null or a function, not " + typeof superClass); } subClass.prototype = Object.create(superClass && superClass.prototype, { constructor: { value: subClass, enumerable: false, writable: true, configurable: true } }); if (superClass) Object.setPrototypeOf ? Object.setPrototypeOf(subClass, superClass) : subClass.__proto__ = superClass; } 95 | 96 | 97 | 98 | var isString = function isString(unknown) { 99 | return typeof unknown === 'string' || !!unknown && (typeof unknown === 'undefined' ? 'undefined' : _typeof(unknown)) === 'object' && Object.prototype.toString.call(unknown) === '[object String]'; 100 | }; 101 | 102 | var bulmaSlider = function (_EventEmitter) { 103 | _inherits(bulmaSlider, _EventEmitter); 104 | 105 | function bulmaSlider(selector) { 106 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 107 | 108 | _classCallCheck(this, bulmaSlider); 109 | 110 | var _this = _possibleConstructorReturn(this, (bulmaSlider.__proto__ || Object.getPrototypeOf(bulmaSlider)).call(this)); 111 | 112 | _this.element = typeof selector === 'string' ? 
document.querySelector(selector) : selector; 113 | // An invalid selector or non-DOM node has been provided. 114 | if (!_this.element) { 115 | throw new Error('An invalid selector or non-DOM node has been provided.'); 116 | } 117 | 118 | _this._clickEvents = ['click']; 119 | /// Set default options and merge with instance defined 120 | _this.options = _extends({}, options); 121 | 122 | _this.onSliderInput = _this.onSliderInput.bind(_this); 123 | 124 | _this.init(); 125 | return _this; 126 | } 127 | 128 | /** 129 | * Initiate all DOM element containing selector 130 | * @method 131 | * @return {Array} Array of all slider instances 132 | */ 133 | 134 | 135 | _createClass(bulmaSlider, [{ 136 | key: 'init', 137 | 138 | 139 | /** 140 | * Initiate plugin 141 | * @method init 142 | * @return {void} 143 | */ 144 | value: function init() { 145 | this._id = 'bulmaSlider' + new Date().getTime() + Math.floor(Math.random() * Math.floor(9999)); 146 | this.output = this._findOutputForSlider(); 147 | 148 | this._bindEvents(); 149 | 150 | if (this.output) { 151 | if (this.element.classList.contains('has-output-tooltip')) { 152 | // Get new output position 153 | var newPosition = this._getSliderOutputPosition(); 154 | 155 | // Set output position 156 | this.output.style['left'] = newPosition.position; 157 | } 158 | } 159 | 160 | this.emit('bulmaslider:ready', this.element.value); 161 | } 162 | }, { 163 | key: '_findOutputForSlider', 164 | value: function _findOutputForSlider() { 165 | var _this2 = this; 166 | 167 | var result = null; 168 | var outputs = document.getElementsByTagName('output') || []; 169 | 170 | Array.from(outputs).forEach(function (output) { 171 | if (output.htmlFor == _this2.element.getAttribute('id')) { 172 | result = output; 173 | return true; 174 | } 175 | }); 176 | return result; 177 | } 178 | }, { 179 | key: '_getSliderOutputPosition', 180 | value: function _getSliderOutputPosition() { 181 | // Update output position 182 | var newPlace, minValue; 183 | 184 | 
var style = window.getComputedStyle(this.element, null); 185 | // Measure width of range input 186 | var sliderWidth = parseInt(style.getPropertyValue('width'), 10); 187 | 188 | // Figure out placement percentage between left and right of input 189 | if (!this.element.getAttribute('min')) { 190 | minValue = 0; 191 | } else { 192 | minValue = this.element.getAttribute('min'); 193 | } 194 | var newPoint = (this.element.value - minValue) / (this.element.getAttribute('max') - minValue); 195 | 196 | // Prevent bubble from going beyond left or right (unsupported browsers) 197 | if (newPoint < 0) { 198 | newPlace = 0; 199 | } else if (newPoint > 1) { 200 | newPlace = sliderWidth; 201 | } else { 202 | newPlace = sliderWidth * newPoint; 203 | } 204 | 205 | return { 206 | 'position': newPlace + 'px' 207 | }; 208 | } 209 | 210 | /** 211 | * Bind all events 212 | * @method _bindEvents 213 | * @return {void} 214 | */ 215 | 216 | }, { 217 | key: '_bindEvents', 218 | value: function _bindEvents() { 219 | if (this.output) { 220 | // Add event listener to update output when slider value change 221 | this.element.addEventListener('input', this.onSliderInput, false); 222 | } 223 | } 224 | }, { 225 | key: 'onSliderInput', 226 | value: function onSliderInput(e) { 227 | e.preventDefault(); 228 | 229 | if (this.element.classList.contains('has-output-tooltip')) { 230 | // Get new output position 231 | var newPosition = this._getSliderOutputPosition(); 232 | 233 | // Set output position 234 | this.output.style['left'] = newPosition.position; 235 | } 236 | 237 | // Check for prefix and postfix 238 | var prefix = this.output.hasAttribute('data-prefix') ? this.output.getAttribute('data-prefix') : ''; 239 | var postfix = this.output.hasAttribute('data-postfix') ? 
this.output.getAttribute('data-postfix') : ''; 240 | 241 | // Update output with slider value 242 | this.output.value = prefix + this.element.value + postfix; 243 | 244 | this.emit('bulmaslider:ready', this.element.value); 245 | } 246 | }], [{ 247 | key: 'attach', 248 | value: function attach() { 249 | var _this3 = this; 250 | 251 | var selector = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : 'input[type="range"].slider'; 252 | var options = arguments.length > 1 && arguments[1] !== undefined ? arguments[1] : {}; 253 | 254 | var instances = new Array(); 255 | 256 | var elements = isString(selector) ? document.querySelectorAll(selector) : Array.isArray(selector) ? selector : [selector]; 257 | elements.forEach(function (element) { 258 | if (typeof element[_this3.constructor.name] === 'undefined') { 259 | var instance = new bulmaSlider(element, options); 260 | element[_this3.constructor.name] = instance; 261 | instances.push(instance); 262 | } else { 263 | instances.push(element[_this3.constructor.name]); 264 | } 265 | }); 266 | 267 | return instances; 268 | } 269 | }]); 270 | 271 | return bulmaSlider; 272 | }(__WEBPACK_IMPORTED_MODULE_0__events__["a" /* default */]); 273 | 274 | /* harmony default export */ __webpack_exports__["default"] = (bulmaSlider); 275 | 276 | /***/ }), 277 | /* 1 */ 278 | /***/ (function(module, __webpack_exports__, __webpack_require__) { 279 | 280 | "use strict"; 281 | var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length; i++) { var descriptor = props[i]; descriptor.enumerable = descriptor.enumerable || false; descriptor.configurable = true; if ("value" in descriptor) descriptor.writable = true; Object.defineProperty(target, descriptor.key, descriptor); } } return function (Constructor, protoProps, staticProps) { if (protoProps) defineProperties(Constructor.prototype, protoProps); if (staticProps) defineProperties(Constructor, staticProps); return Constructor; }; 
}(); 282 | 283 | function _classCallCheck(instance, Constructor) { if (!(instance instanceof Constructor)) { throw new TypeError("Cannot call a class as a function"); } } 284 | 285 | var EventEmitter = function () { 286 | function EventEmitter() { 287 | var listeners = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : []; 288 | 289 | _classCallCheck(this, EventEmitter); 290 | 291 | this._listeners = new Map(listeners); 292 | this._middlewares = new Map(); 293 | } 294 | 295 | _createClass(EventEmitter, [{ 296 | key: "listenerCount", 297 | value: function listenerCount(eventName) { 298 | if (!this._listeners.has(eventName)) { 299 | return 0; 300 | } 301 | 302 | var eventListeners = this._listeners.get(eventName); 303 | return eventListeners.length; 304 | } 305 | }, { 306 | key: "removeListeners", 307 | value: function removeListeners() { 308 | var _this = this; 309 | 310 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 311 | var middleware = arguments.length > 1 && arguments[1] !== undefined ? 
arguments[1] : false; 312 | 313 | if (eventName !== null) { 314 | if (Array.isArray(eventName)) { 315 | name.forEach(function (e) { 316 | return _this.removeListeners(e, middleware); 317 | }); 318 | } else { 319 | this._listeners.delete(eventName); 320 | 321 | if (middleware) { 322 | this.removeMiddleware(eventName); 323 | } 324 | } 325 | } else { 326 | this._listeners = new Map(); 327 | } 328 | } 329 | }, { 330 | key: "middleware", 331 | value: function middleware(eventName, fn) { 332 | var _this2 = this; 333 | 334 | if (Array.isArray(eventName)) { 335 | name.forEach(function (e) { 336 | return _this2.middleware(e, fn); 337 | }); 338 | } else { 339 | if (!Array.isArray(this._middlewares.get(eventName))) { 340 | this._middlewares.set(eventName, []); 341 | } 342 | 343 | this._middlewares.get(eventName).push(fn); 344 | } 345 | } 346 | }, { 347 | key: "removeMiddleware", 348 | value: function removeMiddleware() { 349 | var _this3 = this; 350 | 351 | var eventName = arguments.length > 0 && arguments[0] !== undefined ? arguments[0] : null; 352 | 353 | if (eventName !== null) { 354 | if (Array.isArray(eventName)) { 355 | name.forEach(function (e) { 356 | return _this3.removeMiddleware(e); 357 | }); 358 | } else { 359 | this._middlewares.delete(eventName); 360 | } 361 | } else { 362 | this._middlewares = new Map(); 363 | } 364 | } 365 | }, { 366 | key: "on", 367 | value: function on(name, callback) { 368 | var _this4 = this; 369 | 370 | var once = arguments.length > 2 && arguments[2] !== undefined ? 
arguments[2] : false; 371 | 372 | if (Array.isArray(name)) { 373 | name.forEach(function (e) { 374 | return _this4.on(e, callback); 375 | }); 376 | } else { 377 | name = name.toString(); 378 | var split = name.split(/,|, | /); 379 | 380 | if (split.length > 1) { 381 | split.forEach(function (e) { 382 | return _this4.on(e, callback); 383 | }); 384 | } else { 385 | if (!Array.isArray(this._listeners.get(name))) { 386 | this._listeners.set(name, []); 387 | } 388 | 389 | this._listeners.get(name).push({ once: once, callback: callback }); 390 | } 391 | } 392 | } 393 | }, { 394 | key: "once", 395 | value: function once(name, callback) { 396 | this.on(name, callback, true); 397 | } 398 | }, { 399 | key: "emit", 400 | value: function emit(name, data) { 401 | var _this5 = this; 402 | 403 | var silent = arguments.length > 2 && arguments[2] !== undefined ? arguments[2] : false; 404 | 405 | name = name.toString(); 406 | var listeners = this._listeners.get(name); 407 | var middlewares = null; 408 | var doneCount = 0; 409 | var execute = silent; 410 | 411 | if (Array.isArray(listeners)) { 412 | listeners.forEach(function (listener, index) { 413 | // Start Middleware checks unless we're doing a silent emit 414 | if (!silent) { 415 | middlewares = _this5._middlewares.get(name); 416 | // Check and execute Middleware 417 | if (Array.isArray(middlewares)) { 418 | middlewares.forEach(function (middleware) { 419 | middleware(data, function () { 420 | var newData = arguments.length > 0 && arguments[0] !== undefined ? 
arguments[0] : null; 421 | 422 | if (newData !== null) { 423 | data = newData; 424 | } 425 | doneCount++; 426 | }, name); 427 | }); 428 | 429 | if (doneCount >= middlewares.length) { 430 | execute = true; 431 | } 432 | } else { 433 | execute = true; 434 | } 435 | } 436 | 437 | // If Middleware checks have been passed, execute 438 | if (execute) { 439 | if (listener.once) { 440 | listeners[index] = null; 441 | } 442 | listener.callback(data); 443 | } 444 | }); 445 | 446 | // Dirty way of removing used Events 447 | while (listeners.indexOf(null) !== -1) { 448 | listeners.splice(listeners.indexOf(null), 1); 449 | } 450 | } 451 | } 452 | }]); 453 | 454 | return EventEmitter; 455 | }(); 456 | 457 | /* harmony default export */ __webpack_exports__["a"] = (EventEmitter); 458 | 459 | /***/ }) 460 | /******/ ])["default"]; 461 | }); -------------------------------------------------------------------------------- /docs/static/css/bulma-slider.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround { 2 | from { 3 | -webkit-transform: rotate(0); 4 | transform: rotate(0); 5 | } 6 | to { 7 | -webkit-transform: rotate(359deg); 8 | transform: rotate(359deg); 9 | } 10 | } 11 | @keyframes spinAround { 12 | from { 13 | -webkit-transform: rotate(0); 14 | transform: rotate(0); 15 | } 16 | to { 17 | -webkit-transform: rotate(359deg); 18 | transform: rotate(359deg); 19 | } 20 | } 21 | input[type="range"].slider { 22 | -webkit-appearance: none; 23 | -moz-appearance: none; 24 | appearance: none; 25 | margin: 1rem 0; 26 | background: 0 0; 27 | touch-action: none; 28 | } 29 | input[type="range"].slider.is-fullwidth { 30 | display: block; 31 | width: 100%; 32 | } 33 | input[type="range"].slider:focus { 34 | outline: 0; 35 | } 36 | input[type="range"].slider:not( 37 | [orient="vertical"] 38 | )::-webkit-slider-runnable-track { 39 | width: 100%; 40 | } 41 | input[type="range"].slider:not([orient="vertical"])::-moz-range-track { 
42 | width: 100%; 43 | } 44 | input[type="range"].slider:not([orient="vertical"])::-ms-track { 45 | width: 100%; 46 | } 47 | input[type="range"].slider:not([orient="vertical"]).has-output + output, 48 | input[type="range"].slider:not([orient="vertical"]).has-output-tooltip 49 | + output { 50 | width: 3rem; 51 | background: #4a4a4a; 52 | border-radius: 4px; 53 | padding: 0.4rem 0.8rem; 54 | font-size: 0.75rem; 55 | line-height: 0.75rem; 56 | text-align: center; 57 | text-overflow: ellipsis; 58 | white-space: nowrap; 59 | color: #fff; 60 | overflow: hidden; 61 | pointer-events: none; 62 | z-index: 200; 63 | } 64 | input[type="range"].slider:not([orient="vertical"]).has-output-tooltip:disabled 65 | + output, 66 | input[type="range"].slider:not([orient="vertical"]).has-output:disabled 67 | + output { 68 | opacity: 0.5; 69 | } 70 | input[type="range"].slider:not([orient="vertical"]).has-output { 71 | display: inline-block; 72 | vertical-align: middle; 73 | width: calc(100% - (4.2rem)); 74 | } 75 | input[type="range"].slider:not([orient="vertical"]).has-output + output { 76 | display: inline-block; 77 | margin-left: 0.75rem; 78 | vertical-align: middle; 79 | } 80 | input[type="range"].slider:not([orient="vertical"]).has-output-tooltip { 81 | display: block; 82 | } 83 | input[type="range"].slider:not([orient="vertical"]).has-output-tooltip 84 | + output { 85 | position: absolute; 86 | left: 0; 87 | top: -0.1rem; 88 | } 89 | input[type="range"].slider[orient="vertical"] { 90 | -webkit-appearance: slider-vertical; 91 | -moz-appearance: slider-vertical; 92 | appearance: slider-vertical; 93 | -webkit-writing-mode: bt-lr; 94 | -ms-writing-mode: bt-lr; 95 | writing-mode: bt-lr; 96 | } 97 | input[type="range"].slider[orient="vertical"]::-webkit-slider-runnable-track { 98 | height: 100%; 99 | } 100 | input[type="range"].slider[orient="vertical"]::-moz-range-track { 101 | height: 100%; 102 | } 103 | input[type="range"].slider[orient="vertical"]::-ms-track { 104 | height: 100%; 105 
| } 106 | input[type="range"].slider::-webkit-slider-runnable-track { 107 | cursor: pointer; 108 | animate: 0.2s; 109 | box-shadow: 0 0 0 #7a7a7a; 110 | background: #dbdbdb; 111 | border-radius: 4px; 112 | border: 0 solid #7a7a7a; 113 | } 114 | input[type="range"].slider::-moz-range-track { 115 | cursor: pointer; 116 | animate: 0.2s; 117 | box-shadow: 0 0 0 #7a7a7a; 118 | background: #dbdbdb; 119 | border-radius: 4px; 120 | border: 0 solid #7a7a7a; 121 | } 122 | input[type="range"].slider::-ms-track { 123 | cursor: pointer; 124 | animate: 0.2s; 125 | box-shadow: 0 0 0 #7a7a7a; 126 | background: #dbdbdb; 127 | border-radius: 4px; 128 | border: 0 solid #7a7a7a; 129 | } 130 | input[type="range"].slider::-ms-fill-lower { 131 | background: #dbdbdb; 132 | border-radius: 4px; 133 | } 134 | input[type="range"].slider::-ms-fill-upper { 135 | background: #dbdbdb; 136 | border-radius: 4px; 137 | } 138 | input[type="range"].slider::-webkit-slider-thumb { 139 | box-shadow: none; 140 | border: 1px solid #b5b5b5; 141 | border-radius: 4px; 142 | background: #fff; 143 | cursor: pointer; 144 | } 145 | input[type="range"].slider::-moz-range-thumb { 146 | box-shadow: none; 147 | border: 1px solid #b5b5b5; 148 | border-radius: 4px; 149 | background: #fff; 150 | cursor: pointer; 151 | } 152 | input[type="range"].slider::-ms-thumb { 153 | box-shadow: none; 154 | border: 1px solid #b5b5b5; 155 | border-radius: 4px; 156 | background: #fff; 157 | cursor: pointer; 158 | } 159 | input[type="range"].slider::-webkit-slider-thumb { 160 | -webkit-appearance: none; 161 | appearance: none; 162 | } 163 | input[type="range"].slider.is-circle::-webkit-slider-thumb { 164 | border-radius: 290486px; 165 | } 166 | input[type="range"].slider.is-circle::-moz-range-thumb { 167 | border-radius: 290486px; 168 | } 169 | input[type="range"].slider.is-circle::-ms-thumb { 170 | border-radius: 290486px; 171 | } 172 | input[type="range"].slider:active::-webkit-slider-thumb { 173 | -webkit-transform: scale(1.25); 174 
| transform: scale(1.25); 175 | } 176 | input[type="range"].slider:active::-moz-range-thumb { 177 | transform: scale(1.25); 178 | } 179 | input[type="range"].slider:active::-ms-thumb { 180 | transform: scale(1.25); 181 | } 182 | input[type="range"].slider:disabled { 183 | opacity: 0.5; 184 | cursor: not-allowed; 185 | } 186 | input[type="range"].slider:disabled::-webkit-slider-thumb { 187 | cursor: not-allowed; 188 | -webkit-transform: scale(1); 189 | transform: scale(1); 190 | } 191 | input[type="range"].slider:disabled::-moz-range-thumb { 192 | cursor: not-allowed; 193 | transform: scale(1); 194 | } 195 | input[type="range"].slider:disabled::-ms-thumb { 196 | cursor: not-allowed; 197 | transform: scale(1); 198 | } 199 | input[type="range"].slider:not([orient="vertical"]) { 200 | min-height: calc((1rem + 2px) * 1.25); 201 | } 202 | input[type="range"].slider:not( 203 | [orient="vertical"] 204 | )::-webkit-slider-runnable-track { 205 | height: 0.5rem; 206 | } 207 | input[type="range"].slider:not([orient="vertical"])::-moz-range-track { 208 | height: 0.5rem; 209 | } 210 | input[type="range"].slider:not([orient="vertical"])::-ms-track { 211 | height: 0.5rem; 212 | } 213 | input[type="range"].slider[orient="vertical"]::-webkit-slider-runnable-track { 214 | width: 0.5rem; 215 | } 216 | input[type="range"].slider[orient="vertical"]::-moz-range-track { 217 | width: 0.5rem; 218 | } 219 | input[type="range"].slider[orient="vertical"]::-ms-track { 220 | width: 0.5rem; 221 | } 222 | input[type="range"].slider::-webkit-slider-thumb { 223 | height: 1rem; 224 | width: 1rem; 225 | } 226 | input[type="range"].slider::-moz-range-thumb { 227 | height: 1rem; 228 | width: 1rem; 229 | } 230 | input[type="range"].slider::-ms-thumb { 231 | height: 1rem; 232 | width: 1rem; 233 | } 234 | input[type="range"].slider::-ms-thumb { 235 | margin-top: 0; 236 | } 237 | input[type="range"].slider::-webkit-slider-thumb { 238 | margin-top: -0.25rem; 239 | } 240 | 
/*
 * Range-slider size modifiers.
 *
 * Every vendor pseudo-element (::-webkit-*, ::-moz-*, ::-ms-*) keeps its own
 * rule on purpose: a selector list containing a pseudo-element the browser
 * does not recognise is discarded as a whole, so they must not be grouped.
 */

/* Default vertical slider: WebKit thumb offset. */
input[type="range"].slider[orient="vertical"]::-webkit-slider-thumb {
  margin-left: -0.25rem;
  margin-top: auto;
}

/* ---- .is-small: 0.375rem track, 0.75rem thumb ---- */
input[type="range"].slider.is-small:not([orient="vertical"]) {
  min-height: calc((0.75rem + 2px) * 1.25);
}

/* Horizontal track thickness. */
input[type="range"].slider.is-small:not([orient="vertical"])::-webkit-slider-runnable-track {
  height: 0.375rem;
}
input[type="range"].slider.is-small:not([orient="vertical"])::-moz-range-track {
  height: 0.375rem;
}
input[type="range"].slider.is-small:not([orient="vertical"])::-ms-track {
  height: 0.375rem;
}

/* Vertical track thickness. */
input[type="range"].slider.is-small[orient="vertical"]::-webkit-slider-runnable-track {
  width: 0.375rem;
}
input[type="range"].slider.is-small[orient="vertical"]::-moz-range-track {
  width: 0.375rem;
}
input[type="range"].slider.is-small[orient="vertical"]::-ms-track {
  width: 0.375rem;
}

/* Thumb size; the WebKit thumb is pulled up by a negative top margin. */
input[type="range"].slider.is-small::-webkit-slider-thumb {
  width: 0.75rem;
  height: 0.75rem;
  margin-top: -0.1875rem;
}
input[type="range"].slider.is-small::-moz-range-thumb {
  width: 0.75rem;
  height: 0.75rem;
}
input[type="range"].slider.is-small::-ms-thumb {
  width: 0.75rem;
  height: 0.75rem;
  margin-top: 0;
}
input[type="range"].slider.is-small[orient="vertical"]::-webkit-slider-thumb {
  margin-left: -0.1875rem;
  margin-top: auto;
}

/* ---- .is-medium: 0.625rem track, 1.25rem thumb ---- */
input[type="range"].slider.is-medium:not([orient="vertical"]) {
  min-height: calc((1.25rem + 2px) * 1.25);
}

/* Horizontal track thickness. */
input[type="range"].slider.is-medium:not([orient="vertical"])::-webkit-slider-runnable-track {
  height: 0.625rem;
}
input[type="range"].slider.is-medium:not([orient="vertical"])::-moz-range-track {
  height: 0.625rem;
}
input[type="range"].slider.is-medium:not([orient="vertical"])::-ms-track {
  height: 0.625rem;
}

/* Vertical track thickness. */
input[type="range"].slider.is-medium[orient="vertical"]::-webkit-slider-runnable-track {
  width: 0.625rem;
}
input[type="range"].slider.is-medium[orient="vertical"]::-moz-range-track {
  width: 0.625rem;
}
input[type="range"].slider.is-medium[orient="vertical"]::-ms-track {
  width: 0.625rem;
}

/* Thumb size; the WebKit thumb is pulled up by a negative top margin. */
input[type="range"].slider.is-medium::-webkit-slider-thumb {
  width: 1.25rem;
  height: 1.25rem;
  margin-top: -0.3125rem;
}
input[type="range"].slider.is-medium::-moz-range-thumb {
  width: 1.25rem;
  height: 1.25rem;
}
input[type="range"].slider.is-medium::-ms-thumb {
  width: 1.25rem;
  height: 1.25rem;
  margin-top: 0;
}
input[type="range"].slider.is-medium[orient="vertical"]::-webkit-slider-thumb {
  margin-left: -0.3125rem;
  margin-top: auto;
}

/* ---- .is-large: 0.75rem track ---- */
input[type="range"].slider.is-large:not([orient="vertical"]) {
  min-height: calc((1.5rem + 2px) * 1.25);
}

/* Horizontal track thickness. */
input[type="range"].slider.is-large:not([orient="vertical"])::-webkit-slider-runnable-track {
  height: 0.75rem;
}
input[type="range"].slider.is-large:not([orient="vertical"])::-moz-range-track {
  height: 0.75rem;
}
input[type="range"].slider.is-large:not([orient="vertical"])::-ms-track {
  height: 0.75rem;
}

/* Vertical track thickness (WebKit and Gecko). */
input[type="range"].slider.is-large[orient="vertical"]::-webkit-slider-runnable-track {
  width: 0.75rem;
}
input[type="range"].slider.is-large[orient="vertical"]::-moz-range-track {
  width: 0.75rem;
}
/* ---- .is-large: vertical track for legacy Edge/IE, 1.5rem thumb ---- */
input[type="range"].slider.is-large[orient="vertical"]::-ms-track {
  width: 0.75rem;
}

/*
 * Thumb size. Each vendor pseudo-element keeps its own rule: a selector list
 * containing a pseudo-element the browser does not recognise is discarded as
 * a whole, so they must not be grouped.
 */
input[type="range"].slider.is-large::-webkit-slider-thumb {
  height: 1.5rem;
  width: 1.5rem;
  margin-top: -0.375rem;
}
input[type="range"].slider.is-large::-moz-range-thumb {
  height: 1.5rem;
  width: 1.5rem;
}
input[type="range"].slider.is-large::-ms-thumb {
  height: 1.5rem;
  width: 1.5rem;
  margin-top: 0;
}
input[type="range"].slider.is-large[orient="vertical"]::-webkit-slider-thumb {
  margin-top: auto;
  margin-left: -0.375rem;
}

/* ---- .is-white: #fff track, #0a0a0a text on the output bubble ---- */
input[type="range"].slider.is-white::-moz-range-track {
  background: #fff !important;
}
input[type="range"].slider.is-white::-webkit-slider-runnable-track {
  background: #fff !important;
}
input[type="range"].slider.is-white::-ms-track {
  background: #fff !important;
}
input[type="range"].slider.is-white::-ms-fill-lower {
  background: #fff;
}
input[type="range"].slider.is-white::-ms-fill-upper {
  background: #fff;
}
/*
 * `.has-output-tooltip` is compounded onto the input itself. An <input> is a
 * void element with no descendants, so a descendant selector
 * (`.is-white .has-output-tooltip`) can never match.
 */
input[type="range"].slider.is-white.has-output-tooltip + output,
input[type="range"].slider.is-white.has-output + output {
  background-color: #fff;
  color: #0a0a0a;
}

/* ---- .is-black: #0a0a0a track and fill ---- */
input[type="range"].slider.is-black::-moz-range-track {
  background: #0a0a0a !important;
}
input[type="range"].slider.is-black::-webkit-slider-runnable-track {
  background: #0a0a0a !important;
}
input[type="range"].slider.is-black::-ms-track {
  background: #0a0a0a !important;
}
input[type="range"].slider.is-black::-ms-fill-lower {
  background: #0a0a0a;
}
input[type="range"].slider.is-black::-ms-fill-upper {
  background: #0a0a0a;
}
/*
 * Colour modifiers: output bubble and track colours.
 *
 * In every output rule `.has-output-tooltip` is compounded onto the input
 * itself. An <input> is a void element with no descendants, so a descendant
 * selector (`.is-black .has-output-tooltip`) can never match.
 *
 * Each vendor pseudo-element keeps its own rule: a selector list containing a
 * pseudo-element the browser does not recognise is discarded as a whole.
 */

/* ---- .is-black: output bubble ---- */
input[type="range"].slider.is-black.has-output-tooltip + output,
input[type="range"].slider.is-black.has-output + output {
  background-color: #0a0a0a;
  color: #fff;
}

/* ---- .is-light: #f5f5f5 track, #363636 text on the output bubble ---- */
input[type="range"].slider.is-light::-moz-range-track {
  background: #f5f5f5 !important;
}
input[type="range"].slider.is-light::-webkit-slider-runnable-track {
  background: #f5f5f5 !important;
}
input[type="range"].slider.is-light::-ms-track {
  background: #f5f5f5 !important;
}
input[type="range"].slider.is-light::-ms-fill-lower {
  background: #f5f5f5;
}
input[type="range"].slider.is-light::-ms-fill-upper {
  background: #f5f5f5;
}
input[type="range"].slider.is-light.has-output-tooltip + output,
input[type="range"].slider.is-light.has-output + output {
  background-color: #f5f5f5;
  color: #363636;
}

/* ---- .is-dark: #363636 track, #f5f5f5 text on the output bubble ---- */
input[type="range"].slider.is-dark::-moz-range-track {
  background: #363636 !important;
}
input[type="range"].slider.is-dark::-webkit-slider-runnable-track {
  background: #363636 !important;
}
input[type="range"].slider.is-dark::-ms-track {
  background: #363636 !important;
}
input[type="range"].slider.is-dark::-ms-fill-lower {
  background: #363636;
}
input[type="range"].slider.is-dark::-ms-fill-upper {
  background: #363636;
}
input[type="range"].slider.is-dark.has-output-tooltip + output,
input[type="range"].slider.is-dark.has-output + output {
  background-color: #363636;
  color: #f5f5f5;
}

/* ---- .is-primary: #00d1b2 track ---- */
input[type="range"].slider.is-primary::-moz-range-track {
  background: #00d1b2 !important;
}
input[type="range"].slider.is-primary::-webkit-slider-runnable-track {
  background: #00d1b2 !important;
}
input[type="range"].slider.is-primary::-ms-track {
  background: #00d1b2 !important;
}
/*
 * Colour modifiers: track, fill and output bubble colours.
 *
 * In every output rule `.has-output-tooltip` is compounded onto the input
 * itself. An <input> is a void element with no descendants, so a descendant
 * selector (`.is-primary .has-output-tooltip`) can never match.
 *
 * Each vendor pseudo-element keeps its own rule: a selector list containing a
 * pseudo-element the browser does not recognise is discarded as a whole.
 */

/* ---- .is-primary: #00d1b2 fill, #fff text on the output bubble ---- */
input[type="range"].slider.is-primary::-ms-fill-lower {
  background: #00d1b2;
}
input[type="range"].slider.is-primary::-ms-fill-upper {
  background: #00d1b2;
}
input[type="range"].slider.is-primary.has-output-tooltip + output,
input[type="range"].slider.is-primary.has-output + output {
  background-color: #00d1b2;
  color: #fff;
}

/* ---- .is-link: #3273dc track, #fff text on the output bubble ---- */
input[type="range"].slider.is-link::-moz-range-track {
  background: #3273dc !important;
}
input[type="range"].slider.is-link::-webkit-slider-runnable-track {
  background: #3273dc !important;
}
input[type="range"].slider.is-link::-ms-track {
  background: #3273dc !important;
}
input[type="range"].slider.is-link::-ms-fill-lower {
  background: #3273dc;
}
input[type="range"].slider.is-link::-ms-fill-upper {
  background: #3273dc;
}
input[type="range"].slider.is-link.has-output-tooltip + output,
input[type="range"].slider.is-link.has-output + output {
  background-color: #3273dc;
  color: #fff;
}

/* ---- .is-info: #209cee track, #fff text on the output bubble ---- */
input[type="range"].slider.is-info::-moz-range-track {
  background: #209cee !important;
}
input[type="range"].slider.is-info::-webkit-slider-runnable-track {
  background: #209cee !important;
}
input[type="range"].slider.is-info::-ms-track {
  background: #209cee !important;
}
input[type="range"].slider.is-info::-ms-fill-lower {
  background: #209cee;
}
input[type="range"].slider.is-info::-ms-fill-upper {
  background: #209cee;
}
input[type="range"].slider.is-info.has-output-tooltip + output,
input[type="range"].slider.is-info.has-output + output {
  background-color: #209cee;
  color: #fff;
}

/* ---- .is-success: #23d160 track, #fff text on the output bubble ---- */
input[type="range"].slider.is-success::-moz-range-track {
  background: #23d160 !important;
}
input[type="range"].slider.is-success::-webkit-slider-runnable-track {
  background: #23d160 !important;
}
input[type="range"].slider.is-success::-ms-track {
  background: #23d160 !important;
}
input[type="range"].slider.is-success::-ms-fill-lower {
  background: #23d160;
}
input[type="range"].slider.is-success::-ms-fill-upper {
  background: #23d160;
}
input[type="range"].slider.is-success.has-output-tooltip + output,
input[type="range"].slider.is-success.has-output + output {
  background-color: #23d160;
  color: #fff;
}

/* ---- .is-warning: #ffdd57 track, translucent black text on the bubble ---- */
input[type="range"].slider.is-warning::-moz-range-track {
  background: #ffdd57 !important;
}
input[type="range"].slider.is-warning::-webkit-slider-runnable-track {
  background: #ffdd57 !important;
}
input[type="range"].slider.is-warning::-ms-track {
  background: #ffdd57 !important;
}
input[type="range"].slider.is-warning::-ms-fill-lower {
  background: #ffdd57;
}
input[type="range"].slider.is-warning::-ms-fill-upper {
  background: #ffdd57;
}
input[type="range"].slider.is-warning.has-output-tooltip + output,
input[type="range"].slider.is-warning.has-output + output {
  background-color: #ffdd57;
  color: rgba(0, 0, 0, 0.7);
}

/* ---- .is-danger: #ff3860 track, #fff text on the output bubble ---- */
input[type="range"].slider.is-danger::-moz-range-track {
  background: #ff3860 !important;
}
input[type="range"].slider.is-danger::-webkit-slider-runnable-track {
  background: #ff3860 !important;
}
input[type="range"].slider.is-danger::-ms-track {
  background: #ff3860 !important;
}
input[type="range"].slider.is-danger::-ms-fill-lower {
  background: #ff3860;
}
input[type="range"].slider.is-danger::-ms-fill-upper {
  background: #ff3860;
}
input[type="range"].slider.is-danger.has-output-tooltip + output,
input[type="range"].slider.is-danger.has-output + output {
  background-color: #ff3860;
  color: #fff;
}

--------------------------------------------------------------------------------