├── LICENSE ├── README.md ├── __init__.py ├── cifar10 ├── .DS_Store ├── __init__.py ├── class_hierarchy.txt ├── coarse_map.txt ├── data.py ├── dataset_class_info.json ├── hyperparameters.json └── node_names.txt ├── cifar100 ├── .DS_Store ├── __init__.py ├── class_hierarchy.txt ├── coarse_map.txt ├── data.py ├── dataset_class_info.json ├── hyperparameters.json └── node_names.txt ├── data_loader.py ├── environment.yml ├── geom ├── __init__.py ├── euclidean.py ├── horo.py ├── hyperboloid.py ├── minkowski.py ├── nn.py ├── pmath.py └── poincare.py ├── hierarchy ├── .DS_Store ├── __init__.py └── data.py ├── imagenet100 ├── __init__.py ├── class_hierarchy.txt ├── class_list.txt ├── coarse_map.json ├── data.py ├── dataset_class_info.json ├── hyperparameters.json └── node_names.txt ├── images ├── .DS_Store ├── algorithm.png └── cifar10_viz.png ├── loss.py ├── main.py ├── main_evaluate.py ├── model.py ├── param.py ├── svhn ├── __init__.py ├── data.py └── select_svhn_data.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Aditya Sinha 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [NeurIPS'24] HypStructure 2 | [Paper](https://arxiv.org/pdf/2412.01023) | [NeurIPS Virtual (Video)](https://neurips.cc/virtual/2024/poster/93170) | [Poster](https://neurips.cc/media/PosterPDFs/NeurIPS%202024/93170.png?t=1731743482.698525) | [Slides](https://neurips.cc/media/neurips-2024/Slides/93170.pdf) | [Code](https://github.com/uiuctml/HypStructure) 3 | 4 | This is the official repository for HypStructure: Hyperbolic Structured regularization proposed in our NeurIPS 2024 paper [Learning Structured Representations with Hyperbolic Embeddings](https://arxiv.org/pdf/2412.01023). 5 | 6 | --- 7 | ## Abstract 8 | Most real-world datasets consist of a natural hierarchy between classes or an inherent label structure that is either already available or can be constructed cheaply. However, most existing representation learning methods ignore this hierarchy, treating labels as permutation invariant. 
Recent work [Zeng et al., 2023](https://openreview.net/pdf?id=7J-30ilaUZM) proposes using this structured information explicitly, but the use of Euclidean distance may distort the underlying semantic context [Chen et al., 2013](https://arxiv.org/pdf/1201.1717). In this work, motivated by the advantage of hyperbolic spaces in modeling hierarchical relationships, we propose HypStructure: a novel Hyperbolic Structured regularization approach to accurately embed the label hierarchy into the learned representations. HypStructure is a simple-yet-effective regularizer that consists of a hyperbolic tree-based representation loss along with a centering loss, and can be combined with any standard task loss to learn hierarchy-informed features. Extensive experiments on several large-scale vision benchmarks demonstrate the efficacy of HypStructure in reducing distortion and boosting generalization performance, especially in low-dimensional scenarios. For a better understanding of the structured representation, we perform an eigenvalue analysis that links the representation geometry to the improved Out-of-Distribution (OOD) detection performance seen empirically. 9 | 10 | --- 11 | ## Motivation 12 | 13 | This work extends [Zeng et al., 2023](https://github.com/uiuctml/HierarchyCPCC) to structured regularization in hyperbolic space. 14 | 15 | Visualization of the learned representation for the CIFAR-10 dataset in hyperbolic space: 16 | 17 | ![Viz CIFAR10](images/cifar10_viz.png) 18 | 19 | This repository implements the following algorithm in `main.py` and `loss.py`. 20 | 21 | ![Algorithm](images/algorithm.png) 22 | 23 | --- 24 | 25 | ## Setup 26 | - To get started, clone the repository and install the required dependencies using conda: 27 | ``` 28 | git clone https://github.com/uiuctml/HypStructure 29 | conda env create -n hierarchy --file environment.yml 30 | ``` 31 | 32 | ## Datasets 33 | 34 | We use CIFAR10, CIFAR100 and IMAGENET100 as the In-Distribution (ID) datasets. The Out-of-Distribution (OOD) datasets are SVHN, Textures, Places365, LSUN and iSUN for CIFAR10/100, and SUN, Places, Textures and iNaturalist for IMAGENET100. The ID datasets are constructed in `data/`, whereas the OOD datasets have to be downloaded and placed in the respective `ood_datasets/<ood_dataset_name>/` folder. 35 | 36 | - ID Datasets: Following the setup in [Zeng et al., 2023](https://github.com/uiuctml/HierarchyCPCC), we use the datasets from `torchvision.datasets` along with the hierarchies defined in the files `<dataset>/class_hierarchy.txt`. For IMAGENET100, we construct the subset using the 100 class names defined in prior work [MCM](https://github.com/deeplearning-wisc/MCM/blob/main/data/ImageNet100/class_list.txt).
37 | 38 | - OOD Datasets: Following the dataset source and pre-processing setup from [CIDER](https://github.com/deeplearning-wisc/cider), the OOD datasets are: 39 | - SVHN 40 | - Textures 41 | - Places365 42 | - LSUN 43 | - iSUN 44 | - iNaturalist 45 | - SUN 46 | - Places 47 | - Textures 48 | 49 | The OOD datasets used with IMAGENET100 as the ID dataset are curated and de-duplicated against ImageNet-1k, as created by [Huang et al., 2021](https://github.com/deeplearning-wisc/large_scale_ood). 50 | 51 | 52 | ## Training and Evaluation 53 | The primary training code is in `main.py`, which can be run with the following command: 54 | ``` 55 | CUDA_VISIBLE_DEVICES=0 python main.py --timestamp <timestamp> --num_workers 1 56 | ``` 57 | 58 | Several arguments specify the experimental setup and the hyperparameters; we describe the main ones below, while the rest are self-explanatory: 59 | 60 | 61 | - `--timestamp`: a unique id to identify your experiment. You can use `datetime.now().strftime("%m%d%Y%H%M%S")` if you want to use the timestamp as the identifier. 62 | 63 | - `--dataset`: the In-Distribution (ID) dataset for the experiment: `CIFAR10`, `CIFAR100`, or `IMAGENET100`. Out-of-distribution evaluation is run on the corresponding OOD datasets automatically. 64 | 65 | - `--exp_name`: the loss function; the code currently supports two options: 66 | - `ERM`: empirical risk minimization 67 | - `SupCon`: Supervised Contrastive Loss (SupCon) [[Khosla et al., 2020](https://arxiv.org/pdf/2004.11362)], adapted from the original [SupCon implementation](https://github.com/google-research/google-research/tree/master/supcon). 68 | 69 | - `--model_name`: the backbone model architecture, e.g. `resnet34`. 70 | 71 | - `--cpcc`: `0` to use a loss function without any CPCC-based regularization, `1` to use loss functions with CPCC-based regularization (see the illustrative sketch below). 72 | 73 | - `--cpcc_metric`: the metric on the feature space. Available choices include `l2`, `l1`, `poincare`, `poincare_exp` and `poincare_mean`, where `poincare_exp` refers to the Euclidean-averaging variant followed by a hyperbolic mapping [(see Section 3.2)](https://arxiv.org/pdf/2412.01023), whereas `poincare_mean` refers to the hyperbolic-averaging variant. We empirically observe that the `poincare_mean` variant performs better for ID classification, whereas the `poincare_exp` variant demonstrates better OOD detection performance. 74 | 75 | 76 | - `--center`: setting this to `1` uses the centering loss defined in the paper, whereas setting it to `0` disables the centering loss. 77 | 78 | - `--lamb`: regularization weight of the CPCC-based regularizer. Default value is 1. 79 | 80 | All default training hyperparameters are saved in the `hyperparameters.json` files under the corresponding dataset folders.
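For intuition, here is a minimal, self-contained sketch of the two ingredients that these flags control: a CPCC-style term that correlates pairwise label-tree distances with pairwise (hyperbolic) feature distances, and a centering penalty. The helper names (`poincare_distance`, `cpcc`, `centering`) and the schematic objective at the end are illustrative assumptions, not the repository's exact API; the actual implementation lives in `loss.py`.

```
import torch

def poincare_distance(x, y, eps=1e-15):
    # Geodesic distance in the Poincare ball (points must have norm < 1):
    # d(x, y) = arccosh(1 + 2 * ||x - y||^2 / ((1 - ||x||^2) * (1 - ||y||^2)))
    sq_diff = ((x - y) ** 2).sum(dim=-1)
    denom = ((1 - (x ** 2).sum(dim=-1)) * (1 - (y ** 2).sum(dim=-1))).clamp_min(eps)
    return torch.acosh(1 + 2 * sq_diff / denom)

def cpcc(tree_dists, feat_dists, eps=1e-15):
    # Pearson correlation between pairwise tree distances and pairwise
    # feature distances, both flattened to 1-D vectors of equal length.
    t = tree_dists - tree_dists.mean()
    f = feat_dists - feat_dists.mean()
    return (t * f).sum() / (t.norm() * f.norm()).clamp_min(eps)

def centering(feats):
    # One plausible form of the centering penalty: pull the mean
    # representation toward the origin of the ball (the root node).
    return feats.mean(dim=0).norm()

# Schematic combined objective (weights are illustrative):
# loss = task_loss - lamb * cpcc(tree_dists, feat_dists) + center_weight * centering(feats)
```

The sign convention (maximizing the distance correlation) follows the CPCC formulation of [Zeng et al., 2023](https://github.com/uiuctml/HierarchyCPCC).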
81 | 82 | Commands to run experiments for each dataset: 83 | 84 | CIFAR10: 85 | ``` 86 | CUDA_VISIBLE_DEVICES=0 python main.py --timestamp hypstructure_cifar10 --num_workers 1 --dataset CIFAR10 --exp_name SupCon --model_name resnet18 --cpcc 1 --seeds 1 --cpcc_metric poincare_mean --save_freq 20 --feature_dim 512 --center 1 87 | ``` 88 | 89 | 90 | CIFAR100: 91 | ``` 92 | CUDA_VISIBLE_DEVICES=0 python main.py --timestamp hypstructure_cifar100 --num_workers 1 --dataset CIFAR100 --exp_name SupCon --model_name resnet34 --cpcc 1 --seeds 1 --cpcc_metric poincare_mean --save_freq 20 --feature_dim 512 --center 1 93 | ``` 94 | 95 | 96 | IMAGENET100: 97 | ``` 98 | CUDA_VISIBLE_DEVICES=0 python main.py --timestamp hypstructure_imagenet100 --num_workers 4 --dataset IMAGENET100 --exp_name SupCon --model_name resnet34 --cpcc 1 --seeds 1 --cpcc_metric poincare_mean --save_freq 5 --feature_dim 512 --center 1 --lamb 0.1 99 | ``` 100 | 101 | To evaluate the saved models, run the evaluation script as: 102 | 103 | ``` 104 | python main_evaluate.py --dataset <dataset> --model_save_location /path/to/model_save_location 105 | ``` 106 | 107 | 108 | ## Acknowledgement 109 | Our work is built upon several code repositories, namely [HierarchyCPCC](https://github.com/uiuctml/HierarchyCPCC), [Hyperbolic Image Embeddings](https://github.com/leymir/hyperbolic-image-embeddings) and [Hyperbolic Vision Transformers](https://github.com/htdt/hyp_metric) for the implementation of training functions, and [CIDER](https://github.com/deeplearning-wisc/cider) and its attributed sources for the setup and processing of the OOD datasets. 110 | 111 | ## Citation 112 | If you find our work helpful, please consider citing our paper: 113 | 114 | ``` 115 | @inproceedings{ 116 | sinha2024learning, 117 | title={Learning Structured Representations with Hyperbolic Embeddings}, 118 | author={Aditya Sinha and Siqi Zeng and Makoto Yamada and Han Zhao}, 119 | booktitle={Neural Information Processing Systems}, 120 | year={2024}, 121 | url={https://arxiv.org/pdf/2412.01023} 122 | } 123 | ``` 124 | 125 | ## Contact 126 | Please contact as146@illinois.edu for any questions or comments.
127 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/__init__.py -------------------------------------------------------------------------------- /cifar10/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/cifar10/.DS_Store -------------------------------------------------------------------------------- /cifar10/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/cifar10/__init__.py -------------------------------------------------------------------------------- /cifar10/class_hierarchy.txt: -------------------------------------------------------------------------------- 1 | 9 11 2 | 9 1 3 | 11 0 4 | 11 2 5 | 11 10 6 | 11 12 7 | 1 3 8 | 1 4 9 | 1 5 10 | 1 6 11 | 1 7 12 | 1 8 13 | -------------------------------------------------------------------------------- /cifar10/coarse_map.txt: -------------------------------------------------------------------------------- 1 | 0 2 | 0 3 | 1 4 | 1 5 | 1 6 | 1 7 | 1 8 | 1 9 | 0 10 | 0 11 | -------------------------------------------------------------------------------- /cifar10/data.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from PIL import Image 3 | from torchvision.datasets import CIFAR10 4 | from hierarchy.data import Hierarchy 5 | 6 | import json, os 7 | 8 | 9 | 10 | class HierarchyCIFAR10(Hierarchy, CIFAR10): 11 | def __init__(self, *args, **kw): 12 | super(HierarchyCIFAR10, self).__init__(dataset_name='cifar10', *args, **kw) 13 | 14 | def __getitem__(self, index: int): 15 | img, target = self.data[index], int(self.targets[index]) 16 | img = Image.fromarray(img) 17 | 18 | if self.transform is not None: 19 | img = self.transform(img) 20 | return img, target 21 | 22 | def generate_files(self): 23 | 24 | hierarchy = {"root" : ['transportation','animal'], 25 | "transportation" : ['airplane','automobile','ship','truck'], 26 | "animal" : ['bird','cat','deer','dog','frog','horse'], 27 | } 28 | 29 | 30 | keys = set(hierarchy.keys()) 31 | leaf_names = set() 32 | 33 | for value_list in hierarchy.values(): 34 | for value in value_list: 35 | if value not in keys: 36 | leaf_names.add(value) 37 | 38 | leaf_names = sorted(list(leaf_names)) 39 | 40 | 41 | all_names = sorted(list(hierarchy.keys()) + leaf_names) 42 | name2wnid = {all_names[i]: str(i) for i in range(len(all_names))} 43 | leaf2labelid = {leaf_names[i]: i for i in range(len(leaf_names))} 44 | 45 | edges = [] 46 | for parent, children in hierarchy.items(): 47 | parent_wnid = name2wnid[parent] 48 | for child in children: 49 | child_wnid = name2wnid[child] 50 | edges.append((parent_wnid, child_wnid)) 51 | 52 | with open(os.path.join(self.info_dir, f'{self.dataset_name}/class_hierarchy.txt'), 'w') as file: 53 | for id1, id2 in edges: 54 | file.write(f"{id1} {id2}\n") 55 | 56 | data_class_info = [[leaf2labelid[leaf],name2wnid[leaf],leaf] for leaf in leaf_names] 57 | with open(os.path.join(self.info_dir, f"{self.dataset_name}/dataset_class_info.json"), 'w') as file: 58 | json.dump(data_class_info, file) 59 | 60 | with 
open(os.path.join(self.info_dir, f'{self.dataset_name}/node_names.txt'), 'w') as file: 61 | for key, value in name2wnid.items(): 62 | file.write(f"{value}\t{key}\n") 63 | -------------------------------------------------------------------------------- /cifar10/dataset_class_info.json: -------------------------------------------------------------------------------- 1 | [[0, "0", "airplane"], [1, "2", "automobile"], [2, "3", "bird"], [3, "4", "cat"], [4, "5", "deer"], [5, "6", "dog"], [6, "7", "frog"], [7, "8", "horse"], [8, "10", "ship"], [9, "12", "truck"]] -------------------------------------------------------------------------------- /cifar10/hyperparameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "SupCon":{ 3 | "epochs": 500, 4 | "optimizer": { 5 | "lr": 0.5, 6 | "momentum": 0.9, 7 | "weight_decay": 0.0001 8 | }, 9 | "lr_decay_rate" : 0.1 10 | }, 11 | "ERM":{ 12 | "epochs": 200, 13 | "optimizer": { 14 | "lr": 0.1, 15 | "momentum": 0.9, 16 | "weight_decay": 0.0005 17 | }, 18 | "scheduler": { 19 | "step_size": 60, 20 | "gamma": 0.2 21 | } 22 | } 23 | } -------------------------------------------------------------------------------- /cifar10/node_names.txt: -------------------------------------------------------------------------------- 1 | 0 airplane 2 | 1 animal 3 | 2 automobile 4 | 3 bird 5 | 4 cat 6 | 5 deer 7 | 6 dog 8 | 7 frog 9 | 8 horse 10 | 9 root 11 | 10 ship 12 | 11 transportation 13 | 12 truck 14 | -------------------------------------------------------------------------------- /cifar100/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/cifar100/.DS_Store -------------------------------------------------------------------------------- /cifar100/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/cifar100/__init__.py -------------------------------------------------------------------------------- /cifar100/class_hierarchy.txt: -------------------------------------------------------------------------------- 1 | 86 2 2 | 86 34 3 | 86 36 4 | 86 37 5 | 86 40 6 | 86 44 7 | 86 45 8 | 86 46 9 | 86 49 10 | 86 50 11 | 86 51 12 | 86 52 13 | 86 60 14 | 86 65 15 | 86 72 16 | 86 83 17 | 86 94 18 | 86 109 19 | 86 113 20 | 86 114 21 | 2 5 22 | 2 32 23 | 2 69 24 | 2 89 25 | 2 116 26 | 34 1 27 | 34 35 28 | 34 82 29 | 34 90 30 | 34 110 31 | 36 68 32 | 36 77 33 | 36 87 34 | 36 100 35 | 36 111 36 | 37 10 37 | 37 11 38 | 37 17 39 | 37 30 40 | 37 76 41 | 40 0 42 | 40 64 43 | 40 67 44 | 40 71 45 | 40 101 46 | 44 23 47 | 44 26 48 | 44 48 49 | 44 104 50 | 44 105 51 | 45 6 52 | 45 21 53 | 45 27 54 | 45 102 55 | 45 115 56 | 46 7 57 | 46 8 58 | 46 15 59 | 46 19 60 | 46 25 61 | 49 4 62 | 49 54 63 | 49 55 64 | 49 106 65 | 49 118 66 | 50 13 67 | 50 18 68 | 50 43 69 | 50 84 70 | 50 93 71 | 51 24 72 | 51 38 73 | 51 62 74 | 51 75 75 | 51 88 76 | 52 16 77 | 52 20 78 | 52 22 79 | 52 33 80 | 52 47 81 | 60 39 82 | 60 78 83 | 60 79 84 | 60 81 85 | 60 92 86 | 65 28 87 | 65 57 88 | 65 95 89 | 65 97 90 | 65 120 91 | 72 3 92 | 72 12 93 | 72 41 94 | 72 58 95 | 72 119 96 | 83 29 97 | 83 31 98 | 83 56 99 | 83 96 100 | 83 112 101 | 94 42 102 | 94 63 103 | 94 80 104 | 94 91 105 | 94 98 106 | 109 59 107 | 109 66 108 | 109 70 109 | 109 74 110 | 109 117 111 | 113 9 112 | 113 14 113 | 113 61 114 | 113 
73 115 | 113 108 116 | 114 53 117 | 114 85 118 | 114 99 119 | 114 103 120 | 114 107 121 | -------------------------------------------------------------------------------- /cifar100/coarse_map.txt: -------------------------------------------------------------------------------- 1 | 4 2 | 1 3 | 14 4 | 8 5 | 0 6 | 6 7 | 7 8 | 7 9 | 18 10 | 3 11 | 3 12 | 14 13 | 9 14 | 18 15 | 7 16 | 11 17 | 3 18 | 9 19 | 7 20 | 11 21 | 6 22 | 11 23 | 5 24 | 10 25 | 7 26 | 6 27 | 13 28 | 15 29 | 3 30 | 15 31 | 0 32 | 11 33 | 1 34 | 10 35 | 12 36 | 14 37 | 16 38 | 9 39 | 11 40 | 5 41 | 5 42 | 19 43 | 8 44 | 8 45 | 15 46 | 13 47 | 14 48 | 17 49 | 18 50 | 10 51 | 16 52 | 4 53 | 17 54 | 4 55 | 2 56 | 0 57 | 17 58 | 4 59 | 18 60 | 17 61 | 10 62 | 3 63 | 2 64 | 12 65 | 12 66 | 16 67 | 12 68 | 1 69 | 9 70 | 19 71 | 2 72 | 10 73 | 0 74 | 1 75 | 16 76 | 12 77 | 9 78 | 13 79 | 15 80 | 13 81 | 16 82 | 19 83 | 2 84 | 4 85 | 6 86 | 19 87 | 5 88 | 5 89 | 8 90 | 19 91 | 18 92 | 1 93 | 2 94 | 15 95 | 6 96 | 0 97 | 17 98 | 8 99 | 14 100 | 13 101 | -------------------------------------------------------------------------------- /cifar100/data.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from PIL import Image 3 | from torchvision.datasets import CIFAR100 4 | from hierarchy.data import Hierarchy 5 | 6 | import json, os 7 | 8 | 9 | 10 | class HierarchyCIFAR100(Hierarchy, CIFAR100): 11 | def __init__(self, *args, **kw): 12 | super(HierarchyCIFAR100, self).__init__(dataset_name='cifar100', *args, **kw) 13 | 14 | def __getitem__(self, index: int): 15 | img, target = self.data[index], int(self.targets[index]) 16 | img = Image.fromarray(img) 17 | 18 | if self.transform is not None: 19 | img = self.transform(img) 20 | return img, target 21 | 22 | def generate_files(self): 23 | # Following the convention from : https://github.com/MadryLab/BREEDS-Benchmarks/tree/master/imagenet_class_hierarchy/modified 24 | 25 | # Step 1: 26 | # wordnet id = unique id for all nodes 27 | # assume each node has a unique name 28 | # hierarchy: parent -> children 29 | hierarchy = {"root" : ["aquatic_mammals", "fish", "flowers", "food_containers","fruit_and_vegetables", 30 | "household_electrical_devices", "household_furniture", "insects", "large_carnivores", "large_man-made_outdoor_things", 31 | "large_natural_outdoor_scenes","large_omnivores_and_herbivores","medium_mammals","non-insect_invertebrates","people", 32 | "reptiles","small_mammals","trees","vehicles_1","vehicles_2"], 33 | "aquatic_mammals": ["beaver", "dolphin", "otter", "seal", "whale"], 34 | "fish": ["aquarium_fish", "flatfish", "ray", "shark", "trout"], 35 | "flowers": ["orchids", "poppies", "roses", "sunflowers", "tulips"], 36 | "food_containers": ["bottles", "bowls", "cans", "cups", "plates"], 37 | "fruit_and_vegetables": ["apples", "mushrooms", "oranges", "pears", "sweet_peppers"], 38 | "household_electrical_devices": ["clock", "computer_keyboard", "lamp", "telephone", "television"], 39 | "household_furniture": ["bed", "chair", "couch", "table", "wardrobe"], 40 | "insects": ["bee", "beetle", "butterfly", "caterpillar", "cockroach"], 41 | "large_carnivores": ["bear", "leopard", "lion", "tiger", "wolf"], 42 | "large_man-made_outdoor_things": ["bridge", "castle", "house", "road", "skyscraper"], 43 | "large_natural_outdoor_scenes": ["cloud", "forest", "mountain", "plain", "sea"], 44 | "large_omnivores_and_herbivores": ["camel", "cattle", "chimpanzee", "elephant", "kangaroo"], 45 | "medium_mammals": ["fox", "porcupine", "possum", 
"raccoon", "skunk"], 46 | "non-insect_invertebrates": ["crab", "lobster", "snail", "spider", "worm"], 47 | "people": ["baby", "boy", "girl", "man", "woman"], 48 | "reptiles": ["crocodile", "dinosaur", "lizard", "snake", "turtle"], 49 | "small_mammals": ["hamster", "mouse", "rabbit", "shrew", "squirrel"], 50 | "trees": ["maple_tree", "oak_tree", "palm_tree", "pine_tree", "willow_tree"], 51 | "vehicles_1": ["bicycle", "bus", "motorcycle", "pickup_truck", "train"], 52 | "vehicles_2": ["lawn_mower", "rocket", "streetcar", "tank", "tractor"] 53 | } 54 | 55 | 56 | keys = set(hierarchy.keys()) 57 | leaf_names = set() 58 | 59 | for value_list in hierarchy.values(): 60 | for value in value_list: 61 | if value not in keys: 62 | leaf_names.add(value) 63 | 64 | leaf_names = sorted(list(leaf_names)) 65 | 66 | 67 | all_names = sorted(list(hierarchy.keys()) + leaf_names) 68 | name2wnid = {all_names[i]: str(i) for i in range(len(all_names))} 69 | 70 | 71 | # Step 2: 72 | # label_id = id from the target attribute of the orginal torchvision dataset 73 | leaf2labelid = {leaf_names[i]: i for i in range(len(leaf_names))} 74 | 75 | # Step 3: 76 | # label name = leaf_names 77 | # parent wordnet id \space children wordnet id 78 | 79 | edges = [] 80 | for parent, children in hierarchy.items(): 81 | parent_wnid = name2wnid[parent] 82 | for child in children: 83 | child_wnid = name2wnid[child] 84 | edges.append((parent_wnid, child_wnid)) 85 | 86 | with open(os.path.join(self.info_dir, f'{self.dataset_name}/class_hierarchy.txt'), 'w') as file: 87 | for id1, id2 in edges: 88 | file.write(f"{id1} {id2}\n") 89 | 90 | 91 | # dataset_class_info.json 92 | # a list, each entry is [int(label_id), wordnet id, label name] 93 | # all leaves 94 | data_class_info = [[leaf2labelid[leaf],name2wnid[leaf],leaf] for leaf in leaf_names] 95 | with open(os.path.join(self.info_dir, f"{self.dataset_name}/dataset_class_info.json"), 'w') as file: 96 | json.dump(data_class_info, file) 97 | 98 | # node_names.txt 99 | # wordnet id \tab node name 100 | with open(os.path.join(self.info_dir, f'{self.dataset_name}/node_names.txt'), 'w') as file: 101 | for key, value in name2wnid.items(): 102 | file.write(f"{value}\t{key}\n") 103 | -------------------------------------------------------------------------------- /cifar100/dataset_class_info.json: -------------------------------------------------------------------------------- 1 | [[0, "0", "apples"], [1, "1", "aquarium_fish"], [2, "3", "baby"], [3, "4", "bear"], [4, "5", "beaver"], [5, "6", "bed"], [6, "7", "bee"], [7, "8", "beetle"], [8, "9", "bicycle"], [9, "10", "bottles"], [10, "11", "bowls"], [11, "12", "boy"], [12, "13", "bridge"], [13, "14", "bus"], [14, "15", "butterfly"], [15, "16", "camel"], [16, "17", "cans"], [17, "18", "castle"], [18, "19", "caterpillar"], [19, "20", "cattle"], [20, "21", "chair"], [21, "22", "chimpanzee"], [22, "23", "clock"], [23, "24", "cloud"], [24, "25", "cockroach"], [25, "26", "computer_keyboard"], [26, "27", "couch"], [27, "28", "crab"], [28, "29", "crocodile"], [29, "30", "cups"], [30, "31", "dinosaur"], [31, "32", "dolphin"], [32, "33", "elephant"], [33, "35", "flatfish"], [34, "38", "forest"], [35, "39", "fox"], [36, "41", "girl"], [37, "42", "hamster"], [38, "43", "house"], [39, "47", "kangaroo"], [40, "48", "lamp"], [41, "53", "lawn_mower"], [42, "54", "leopard"], [43, "55", "lion"], [44, "56", "lizard"], [45, "57", "lobster"], [46, "58", "man"], [47, "59", "maple_tree"], [48, "61", "motorcycle"], [49, "62", "mountain"], [50, "63", "mouse"], [51, "64", 
"mushrooms"], [52, "66", "oak_tree"], [53, "67", "oranges"], [54, "68", "orchids"], [55, "69", "otter"], [56, "70", "palm_tree"], [57, "71", "pears"], [58, "73", "pickup_truck"], [59, "74", "pine_tree"], [60, "75", "plain"], [61, "76", "plates"], [62, "77", "poppies"], [63, "78", "porcupine"], [64, "79", "possum"], [65, "80", "rabbit"], [66, "81", "raccoon"], [67, "82", "ray"], [68, "84", "road"], [69, "85", "rocket"], [70, "87", "roses"], [71, "88", "sea"], [72, "89", "seal"], [73, "90", "shark"], [74, "91", "shrew"], [75, "92", "skunk"], [76, "93", "skyscraper"], [77, "95", "snail"], [78, "96", "snake"], [79, "97", "spider"], [80, "98", "squirrel"], [81, "99", "streetcar"], [82, "100", "sunflowers"], [83, "101", "sweet_peppers"], [84, "102", "table"], [85, "103", "tank"], [86, "104", "telephone"], [87, "105", "television"], [88, "106", "tiger"], [89, "107", "tractor"], [90, "108", "train"], [91, "110", "trout"], [92, "111", "tulips"], [93, "112", "turtle"], [94, "115", "wardrobe"], [95, "116", "whale"], [96, "117", "willow_tree"], [97, "118", "wolf"], [98, "119", "woman"], [99, "120", "worm"]] -------------------------------------------------------------------------------- /cifar100/hyperparameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "SupCon":{ 3 | "epochs": 500, 4 | "optimizer": { 5 | "lr": 0.5, 6 | "momentum": 0.9, 7 | "weight_decay": 0.0001 8 | }, 9 | "lr_decay_rate" : 0.1 10 | }, 11 | "ERM":{ 12 | "epochs": 200, 13 | "optimizer": { 14 | "lr": 0.1, 15 | "momentum": 0.9, 16 | "weight_decay": 0.0005 17 | }, 18 | "scheduler": { 19 | "step_size": 60, 20 | "gamma": 0.2 21 | } 22 | } 23 | } -------------------------------------------------------------------------------- /cifar100/node_names.txt: -------------------------------------------------------------------------------- 1 | 0 apples 2 | 1 aquarium_fish 3 | 2 aquatic_mammals 4 | 3 baby 5 | 4 bear 6 | 5 beaver 7 | 6 bed 8 | 7 bee 9 | 8 beetle 10 | 9 bicycle 11 | 10 bottles 12 | 11 bowls 13 | 12 boy 14 | 13 bridge 15 | 14 bus 16 | 15 butterfly 17 | 16 camel 18 | 17 cans 19 | 18 castle 20 | 19 caterpillar 21 | 20 cattle 22 | 21 chair 23 | 22 chimpanzee 24 | 23 clock 25 | 24 cloud 26 | 25 cockroach 27 | 26 computer_keyboard 28 | 27 couch 29 | 28 crab 30 | 29 crocodile 31 | 30 cups 32 | 31 dinosaur 33 | 32 dolphin 34 | 33 elephant 35 | 34 fish 36 | 35 flatfish 37 | 36 flowers 38 | 37 food_containers 39 | 38 forest 40 | 39 fox 41 | 40 fruit_and_vegetables 42 | 41 girl 43 | 42 hamster 44 | 43 house 45 | 44 household_electrical_devices 46 | 45 household_furniture 47 | 46 insects 48 | 47 kangaroo 49 | 48 lamp 50 | 49 large_carnivores 51 | 50 large_man-made_outdoor_things 52 | 51 large_natural_outdoor_scenes 53 | 52 large_omnivores_and_herbivores 54 | 53 lawn_mower 55 | 54 leopard 56 | 55 lion 57 | 56 lizard 58 | 57 lobster 59 | 58 man 60 | 59 maple_tree 61 | 60 medium_mammals 62 | 61 motorcycle 63 | 62 mountain 64 | 63 mouse 65 | 64 mushrooms 66 | 65 non-insect_invertebrates 67 | 66 oak_tree 68 | 67 oranges 69 | 68 orchids 70 | 69 otter 71 | 70 palm_tree 72 | 71 pears 73 | 72 people 74 | 73 pickup_truck 75 | 74 pine_tree 76 | 75 plain 77 | 76 plates 78 | 77 poppies 79 | 78 porcupine 80 | 79 possum 81 | 80 rabbit 82 | 81 raccoon 83 | 82 ray 84 | 83 reptiles 85 | 84 road 86 | 85 rocket 87 | 86 root 88 | 87 roses 89 | 88 sea 90 | 89 seal 91 | 90 shark 92 | 91 shrew 93 | 92 skunk 94 | 93 skyscraper 95 | 94 small_mammals 96 | 95 snail 97 | 96 snake 98 | 97 spider 99 | 98 squirrel 100 | 99 
streetcar 101 | 100 sunflowers 102 | 101 sweet_peppers 103 | 102 table 104 | 103 tank 105 | 104 telephone 106 | 105 television 107 | 106 tiger 108 | 107 tractor 109 | 108 train 110 | 109 trees 111 | 110 trout 112 | 111 tulips 113 | 112 turtle 114 | 113 vehicles_1 115 | 114 vehicles_2 116 | 115 wardrobe 117 | 116 whale 118 | 117 willow_tree 119 | 118 wolf 120 | 119 woman 121 | 120 worm 122 | -------------------------------------------------------------------------------- /data_loader.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import numpy as np 3 | import os 4 | 5 | import torchvision.transforms as transforms 6 | from torch.utils.data import DataLoader, Subset 7 | from torchvision.datasets import ImageFolder, CIFAR10, CIFAR100 8 | from torch.utils.data.distributed import DistributedSampler 9 | 10 | import os 11 | import os.path 12 | import numpy as np 13 | 14 | from svhn.data import SVHN 15 | from cifar10.data import HierarchyCIFAR10 16 | from cifar100.data import HierarchyCIFAR100 17 | from imagenet100.data import ImageNet100, HierarchyImageNet100 18 | 19 | 20 | class TwoCropTransform: 21 | """Create two crops of the same image""" 22 | def __init__(self, transform): 23 | self.transform = transform 24 | 25 | def __call__(self, x): 26 | return [self.transform(x), self.transform(x)] 27 | 28 | 29 | def make_dataloader(exp_name : str, num_workers : int, batch_size : int, 30 | task : str, in_dataset_name : str, 31 | ood_dataset_name : str = None, 32 | ood_root : str = '/path/to/ood_datasets/') -> DataLoader: 33 | ''' 34 | Create (a subset of) the train/test dataloader. Train & test have the same number of classes. 35 | Args: 36 | num_workers : number of workers of train and test loader.
37 | batch_size : batch size of train and test loader 38 | task : 'train','test', 'ood' 39 | ''' 40 | 41 | if in_dataset_name == 'CIFAR10': 42 | img_size = 32 43 | mean = [0.4914, 0.4822, 0.4465] 44 | std = [0.2470, 0.2435, 0.2616] 45 | elif in_dataset_name == 'CIFAR100': 46 | img_size = 32 47 | mean = [0.5071, 0.4867, 0.4408] 48 | std = [0.2675, 0.2565, 0.2761] 49 | elif in_dataset_name == 'IMAGENET100': 50 | img_size = 224 51 | mean = [0.485, 0.456, 0.406] 52 | std = [0.229, 0.224, 0.225] 53 | else: 54 | raise ValueError() 55 | normalize = transforms.Normalize(mean=mean, std=std) 56 | 57 | def make_train_dataset(exp_name, dataset_name): 58 | if exp_name == 'SupCon': 59 | # data augmentations for supcon 60 | if dataset_name == 'IMAGENET100': 61 | transform = TwoCropTransform(transforms.Compose([ 62 | transforms.RandomResizedCrop(size=224, scale=(0.4, 1.)), 63 | transforms.RandomHorizontalFlip(), 64 | transforms.RandomApply([ 65 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 66 | ], p=0.8), 67 | transforms.RandomGrayscale(p=0.2), 68 | transforms.ToTensor(), 69 | normalize, 70 | ])) 71 | else: 72 | transform = TwoCropTransform(transforms.Compose([ 73 | transforms.RandomResizedCrop(size=32, scale=(0.2, 1.)), 74 | transforms.RandomHorizontalFlip(), 75 | transforms.RandomApply([ 76 | transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 77 | ], p=0.8), 78 | transforms.RandomGrayscale(p=0.2), 79 | transforms.ToTensor(), 80 | normalize, 81 | ])) 82 | elif exp_name == 'ERM': 83 | transform = transforms.Compose([transforms.RandomCrop(32, padding=4), 84 | transforms.RandomHorizontalFlip(), 85 | transforms.RandomRotation(15), 86 | transforms.ToTensor(), 87 | normalize 88 | ]) 89 | 90 | if dataset_name == 'CIFAR10': 91 | dataset = HierarchyCIFAR10(root = './data', 92 | train = True, 93 | transform = transform, 94 | download=False 95 | ) 96 | elif dataset_name == 'CIFAR100': 97 | dataset = HierarchyCIFAR100(root = './data', 98 | train = True, 99 | transform = transform, 100 | download=False 101 | ) 102 | elif dataset_name == 'IMAGENET100': 103 | dataset = HierarchyImageNet100(root = '/data/common/ImageNet100/ImageNet100', 104 | train = True, 105 | transform = transform, 106 | ) 107 | return dataset 108 | 109 | def make_test_dataset(dataset_name): 110 | # inD 111 | transform = transforms.Compose([transforms.Resize(img_size), 112 | transforms.CenterCrop(img_size), 113 | transforms.ToTensor(), 114 | normalize 115 | ]) 116 | 117 | if dataset_name == 'CIFAR10': 118 | dataset = CIFAR10(root = './data', 119 | train = False, 120 | transform = transform) 121 | elif dataset_name == 'CIFAR100': 122 | dataset = CIFAR100(root = './data', 123 | train = False, 124 | transform = transform) 125 | elif dataset_name == 'IMAGENET100': 126 | dataset = ImageNet100(root = '/path/to/ImageNet100', 127 | train = False, 128 | transform = transform) 129 | 130 | return dataset 131 | 132 | def make_outlier_dataset_large(dataset_name): 133 | if dataset_name == 'CIFAR10': 134 | dataset = make_test_dataset('CIFAR10') 135 | elif dataset_name == 'CIFAR100': 136 | dataset = make_test_dataset('CIFAR100') 137 | else: # far-ood 138 | 139 | if dataset_name == 'Places365': 140 | dataset = ImageFolder(root= os.path.join(ood_root, 'Places'), 141 | transform=transforms.Compose([transforms.Resize(img_size), 142 | transforms.CenterCrop(img_size), transforms.ToTensor(),normalize])) 143 | elif dataset_name == 'SUN': 144 | dataset = ImageFolder(root = os.path.join(ood_root, 'SUN'), 145 | transform=transforms.Compose([transforms.Resize(img_size), 146 | 
transforms.CenterCrop(img_size),transforms.ToTensor(),normalize])) 147 | elif dataset_name == 'dtd': 148 | dataset = ImageFolder(root=os.path.join(ood_root, 'dtd', 'images'), 149 | transform=transforms.Compose([transforms.Resize(img_size), 150 | transforms.CenterCrop(img_size), transforms.ToTensor(),normalize])) 151 | elif dataset_name == 'iNaturalist': 152 | dataset = ImageFolder(root = os.path.join(ood_root, 'iNaturalist'), 153 | transform=transforms.Compose([transforms.Resize(img_size), 154 | transforms.CenterCrop(img_size),transforms.ToTensor(),normalize])) 155 | elif dataset_name == 'placesbg': 156 | dataset = ImageFolder(root = os.path.join(ood_root, 'placesbg'), 157 | transform=transforms.Compose([transforms.Resize(img_size), 158 | transforms.CenterCrop(img_size),transforms.ToTensor(),normalize])) 159 | 160 | return dataset 161 | 162 | def make_outlier_dataset(dataset_name): 163 | if dataset_name == 'CIFAR10': 164 | dataset = make_test_dataset('CIFAR10') 165 | elif dataset_name == 'CIFAR100': 166 | dataset = make_test_dataset('CIFAR100') 167 | else: # far-ood 168 | if dataset_name == 'SVHN': 169 | dataset = SVHN(root=os.path.join(ood_root, 'svhn'), split='test', 170 | transform=transforms.Compose([transforms.Resize(img_size), 171 | transforms.CenterCrop(img_size),transforms.ToTensor(), normalize]), download=False) 172 | elif dataset_name == 'Textures': 173 | dataset = ImageFolder(root=os.path.join(ood_root, 'dtd', 'images'), 174 | transform=transforms.Compose([transforms.Resize(img_size), 175 | transforms.CenterCrop(img_size), transforms.ToTensor(),normalize])) 176 | elif dataset_name == 'Places365': 177 | dataset = ImageFolder(root= os.path.join(ood_root, 'places365'), 178 | transform=transforms.Compose([transforms.Resize(img_size), 179 | transforms.CenterCrop(img_size), transforms.ToTensor(),normalize])) 180 | elif dataset_name == 'LSUN': 181 | dataset = ImageFolder(root = os.path.join(ood_root, 'LSUN'), 182 | transform=transforms.Compose([transforms.Resize(img_size), 183 | transforms.CenterCrop(img_size),transforms.ToTensor(),normalize])) 184 | elif dataset_name == 'iSUN': 185 | dataset = ImageFolder(root = os.path.join(ood_root, 'iSUN'), 186 | transform=transforms.Compose([transforms.Resize(img_size), 187 | transforms.CenterCrop(img_size),transforms.ToTensor(),normalize])) 188 | 189 | if len(dataset) > 10000: 190 | print("Sampling 10000 samples") 191 | dataset = Subset(dataset, np.random.choice(len(dataset), 10000, replace=False)) 192 | 193 | return dataset 194 | 195 | if task == 'train': 196 | dataset = make_train_dataset(exp_name, in_dataset_name) 197 | world_size = int(os.environ.get('WORLD_SIZE',0)) 198 | if world_size > 1: # DDP 199 | rank = int(os.environ['RANK']) 200 | sampler = DistributedSampler(dataset=dataset, num_replicas=world_size, rank=rank, shuffle=True) 201 | dataloader = DataLoader(dataset=dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers, pin_memory=True) 202 | else: 203 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True) 204 | elif task == 'test': 205 | dataset = make_test_dataset(in_dataset_name) 206 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 207 | elif task == 'ood': 208 | if in_dataset_name == 'IMAGENET100': 209 | dataset = make_outlier_dataset_large(ood_dataset_name) 210 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 211 | else: 212 | 
dataset = make_outlier_dataset(ood_dataset_name) 213 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True) 214 | return dataloader 215 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: torch2ood 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2024.3.11=h06a4308_0 8 | - libedit=3.1.20230828=h5eee18b_0 9 | - libffi=3.2.1=hf484d3e_1007 10 | - libgcc-ng=11.2.0=h1234567_1 11 | - libgomp=11.2.0=h1234567_1 12 | - libstdcxx-ng=11.2.0=h1234567_1 13 | - ncurses=6.4=h6a678d5_0 14 | - openssl=1.1.1w=h7f8727e_0 15 | - pip=23.3.1=py38h06a4308_0 16 | - python=3.8.0=h0371630_2 17 | - readline=7.0=h7b6447c_5 18 | - setuptools=68.2.2=py38h06a4308_0 19 | - sqlite=3.33.0=h62c20be_0 20 | - tk=8.6.12=h1ccaba5_0 21 | - wheel=0.41.2=py38h06a4308_0 22 | - xz=5.4.6=h5eee18b_0 23 | - zlib=1.2.13=h5eee18b_0 24 | - pip: 25 | - anyio==4.3.0 26 | - appdirs==1.4.4 27 | - argon2-cffi==23.1.0 28 | - argon2-cffi-bindings==21.2.0 29 | - arrow==1.3.0 30 | - asttokens==2.4.1 31 | - async-lru==2.0.4 32 | - attrs==23.2.0 33 | - babel==2.14.0 34 | - backcall==0.2.0 35 | - beautifulsoup4==4.12.3 36 | - bleach==6.1.0 37 | - blosc2==2.0.0 38 | - certifi==2024.2.2 39 | - cffi==1.16.0 40 | - charset-normalizer==3.3.2 41 | - click==8.1.7 42 | - comm==0.2.2 43 | - contourpy==1.1.1 44 | - cox==0.1.post3 45 | - cycler==0.12.1 46 | - cython==3.0.9 47 | - debugpy==1.8.1 48 | - decorator==5.1.1 49 | - defusedxml==0.7.1 50 | - dill==0.3.8 51 | - docker-pycreds==0.4.0 52 | - exceptiongroup==1.2.1 53 | - executing==2.0.1 54 | - faiss-gpu==1.7.2 55 | - fastjsonschema==2.19.1 56 | - filelock==3.13.3 57 | - fonttools==4.50.0 58 | - fqdn==1.5.1 59 | - fsspec==2024.3.1 60 | - gitdb==4.0.11 61 | - gitpython==3.1.42 62 | - gputil==1.4.0 63 | - grpcio==1.62.1 64 | - h11==0.14.0 65 | - httpcore==1.0.5 66 | - httpx==0.27.0 67 | - idna==3.6 68 | - importlib-metadata==7.1.0 69 | - importlib-resources==6.4.0 70 | - ipykernel==6.29.4 71 | - ipython==8.12.3 72 | - ipywidgets==8.1.2 73 | - isoduration==20.11.0 74 | - jedi==0.19.1 75 | - jinja2==3.1.3 76 | - joblib==1.3.2 77 | - json5==0.9.25 78 | - jsonpointer==2.4 79 | - jsonschema==4.22.0 80 | - jsonschema-specifications==2023.12.1 81 | - jupyter==1.0.0 82 | - jupyter-client==8.6.1 83 | - jupyter-console==6.6.3 84 | - jupyter-core==5.7.2 85 | - jupyter-events==0.10.0 86 | - jupyter-lsp==2.2.5 87 | - jupyter-server==2.14.0 88 | - jupyter-server-terminals==0.5.3 89 | - jupyterlab==4.1.8 90 | - jupyterlab-pygments==0.3.0 91 | - jupyterlab-server==2.27.1 92 | - jupyterlab-widgets==3.0.10 93 | - kiwisolver==1.4.5 94 | - llvmlite==0.41.1 95 | - markupsafe==2.1.5 96 | - matplotlib==3.7.5 97 | - matplotlib-inline==0.1.7 98 | - mistune==3.0.2 99 | - mpmath==1.3.0 100 | - msgpack==1.0.8 101 | - nbclient==0.10.0 102 | - nbconvert==7.16.4 103 | - nbformat==5.10.4 104 | - nest-asyncio==1.6.0 105 | - networkx==3.1 106 | - nltk==3.8.1 107 | - notebook==7.1.3 108 | - notebook-shim==0.2.4 109 | - numba==0.58.1 110 | - numexpr==2.8.6 111 | - numpy==1.24.4 112 | - nvidia-cublas-cu12==12.1.3.1 113 | - nvidia-cuda-cupti-cu12==12.1.105 114 | - nvidia-cuda-nvrtc-cu12==12.1.105 115 | - nvidia-cuda-runtime-cu12==12.1.105 116 | - nvidia-cudnn-cu12==8.9.2.26 117 | - nvidia-cufft-cu12==11.0.2.54 118 | - nvidia-curand-cu12==10.3.2.106 119 | - 
nvidia-cusolver-cu12==11.4.5.107 120 | - nvidia-cusparse-cu12==12.1.0.106 121 | - nvidia-nccl-cu12==2.19.3 122 | - nvidia-nvjitlink-cu12==12.4.99 123 | - nvidia-nvtx-cu12==12.1.105 124 | - opencv-python==4.9.0.80 125 | - overrides==7.7.0 126 | - packaging==24.0 127 | - pandas==2.0.3 128 | - pandocfilters==1.5.1 129 | - parso==0.8.4 130 | - pexpect==4.9.0 131 | - pickleshare==0.7.5 132 | - pillow==10.2.0 133 | - pkgutil-resolve-name==1.3.10 134 | - platformdirs==4.2.1 135 | - pot==0.9.3 136 | - prometheus-client==0.20.0 137 | - prompt-toolkit==3.0.43 138 | - protobuf==4.25.3 139 | - psutil==5.9.8 140 | - ptyprocess==0.7.0 141 | - pure-eval==0.2.2 142 | - py-cpuinfo==9.0.0 143 | - py3nvml==0.2.7 144 | - pycparser==2.22 145 | - pygments==2.17.2 146 | - pynndescent==0.5.12 147 | - pyparsing==3.1.2 148 | - python-dateutil==2.9.0.post0 149 | - python-json-logger==2.0.7 150 | - pytz==2024.1 151 | - pyyaml==6.0.1 152 | - pyzmq==26.0.3 153 | - qtconsole==5.5.1 154 | - qtpy==2.4.1 155 | - referencing==0.35.1 156 | - regex==2023.12.25 157 | - requests==2.31.0 158 | - rfc3339-validator==0.1.4 159 | - rfc3986-validator==0.1.1 160 | - robustness==1.2.1.post2 161 | - rpds-py==0.18.0 162 | - scikit-learn==1.3.2 163 | - scipy==1.10.1 164 | - seaborn==0.13.2 165 | - send2trash==1.8.3 166 | - sentry-sdk==1.44.0 167 | - setproctitle==1.3.3 168 | - six==1.16.0 169 | - smmap==5.0.1 170 | - sniffio==1.3.1 171 | - soupsieve==2.5 172 | - stack-data==0.6.3 173 | - sympy==1.12 174 | - tables==3.8.0 175 | - tensorboardx==2.6.2.2 176 | - terminado==0.18.1 177 | - threadpoolctl==3.4.0 178 | - tinycss2==1.3.0 179 | - tomli==2.0.1 180 | - torch==2.2.2 181 | - torchaudio==2.2.2 182 | - torchvision==0.17.2 183 | - tornado==6.4 184 | - tqdm==4.66.2 185 | - traitlets==5.14.3 186 | - triton==2.2.0 187 | - types-python-dateutil==2.9.0.20240316 188 | - typing-extensions==4.10.0 189 | - tzdata==2024.1 190 | - umap-learn==0.5.6 191 | - uri-template==1.3.0 192 | - urllib3==2.2.1 193 | - wandb==0.16.5 194 | - wcwidth==0.2.13 195 | - webcolors==1.13 196 | - webencodings==0.5.1 197 | - websocket-client==1.8.0 198 | - widgetsnbextension==4.0.10 199 | - xmltodict==0.13.0 200 | - zipp==3.18.1 201 | -------------------------------------------------------------------------------- /geom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/geom/__init__.py -------------------------------------------------------------------------------- /geom/euclidean.py: -------------------------------------------------------------------------------- 1 | """ Geometric utility functions, mostly for standard Euclidean operations.""" 2 | 3 | import torch 4 | 5 | MIN_NORM = 1e-15 6 | 7 | 8 | def orthonormal(Q): 9 | """Return orthonormal basis spanned by the vectors in Q. 
10 | 11 | Q: (..., k, d) k vectors of dimension d to orthonormalize 12 | """ 13 | k = Q.size(-2) 14 | _, _, v = torch.svd(Q, some=False) # Q = USV^T 15 | Q_ = v[:, :k] 16 | return Q_.transpose(-1, -2) # (k, d) rows are orthonormal basis for rows of Q 17 | 18 | 19 | def euc_reflection(x, a): 20 | """ 21 | Euclidean reflection (also hyperbolic) of x 22 | Along the geodesic that goes through a and the origin 23 | (straight line) 24 | 25 | NOTE: this should be generalized by reflect() 26 | """ 27 | xTa = torch.sum(x * a, dim=-1, keepdim=True) 28 | norm_a_sq = torch.sum(a ** 2, dim=-1, keepdim=True) 29 | proj = xTa * a / norm_a_sq.clamp_min(MIN_NORM) 30 | return 2 * proj - x 31 | 32 | 33 | def reflect(x, Q): 34 | """Reflect points (euclidean) with respect to the space spanned by the rows of Q. 35 | 36 | Q: (k, d) set of k d-dimensional vectors (must be orthogonal) 37 | """ 38 | ref = 2 * Q.transpose(0, 1) @ Q - torch.eye(x.shape[-1], device=x.device) 39 | return x @ ref 40 | -------------------------------------------------------------------------------- /geom/horo.py: -------------------------------------------------------------------------------- 1 | """Horocycle projection utils (Poincare model).""" 2 | 3 | import torch 4 | 5 | MIN_NORM = 1e-15 6 | 7 | 8 | def busemann(x, p, keepdim=True): 9 | """ 10 | x: (..., d) 11 | p: (..., d) 12 | 13 | Returns: (..., 1) if keepdim==True else (...) 14 | """ 15 | 16 | xnorm = x.norm(dim=-1, p=2, keepdim=True) 17 | pnorm = p.norm(dim=-1, p=2, keepdim=True) 18 | p = p / pnorm.clamp_min(MIN_NORM) 19 | num = torch.norm(p - x, dim=-1, keepdim=True) ** 2 20 | den = (1 - xnorm ** 2).clamp_min(MIN_NORM) 21 | ans = torch.log((num / den).clamp_min(MIN_NORM)) 22 | if not keepdim: 23 | ans = ans.squeeze(-1) 24 | return ans 25 | 26 | 27 | def circle_intersection_(r, R): 28 | """ Computes the intersection of two circles of radii r and R with distance 1 between their centers. 29 | 30 | Returns: 31 | x - distance from center of first circle 32 | h - height off the line connecting the two centers of the two intersection points 33 | """ 34 | 35 | x = (1.0 - R ** 2 + r ** 2) / 2.0 36 | s = (r + R + 1) / 2.0 37 | sq_h = (s * (s - r) * (s - R) * (s - 1)).clamp_min(MIN_NORM) 38 | h = torch.sqrt(sq_h) * 2.0 39 | return x, h 40 | 41 | 42 | def circle_intersection(c1, c2, r1, r2): 43 | """ Computes the intersections of two circles, each centered at ci with radius ri. 44 | 45 | c1, c2: (..., d) 46 | r1, r2: (...) 47 | """ 48 | 49 | d = torch.norm(c1 - c2, dim=-1) # (...) 50 | x, h = circle_intersection_(r1 / d.clamp_min(MIN_NORM), r2 / d.clamp_min(MIN_NORM)) # (...) 51 | x = x.unsqueeze(-1) 52 | h = h.unsqueeze(-1) 53 | center = x * c2 + (1 - x) * c1 # (..., d) 54 | radius = h * d # (...) 55 | 56 | # The intersection is a hypersphere of one lower dimension, intersected with the plane 57 | # orthogonal to the direction c1->c2 58 | # In general, you can compute this with a sort of higher dimensional cross product? 59 | # For now, only 2 dimensions 60 | 61 | ortho = c2 - c1 # (..., d) 62 | assert ortho.size(-1) == 2 63 | direction = torch.stack((-ortho[..., 1], ortho[..., 0]), dim=-1) 64 | direction = direction / torch.norm(direction, dim=-1, keepdim=True).clamp_min(MIN_NORM) 65 | return center + radius.unsqueeze(-1) * direction # , center - radius*direction 66 | 67 | 68 | def busemann_to_horocycle(p, t): 69 | """ Find the horocycle corresponding to the level set of the Busemann function to ideal point p with value t. 70 | 71 | p: (..., d) 72 | t: (...) 73 | 74 | Returns: 75 | c: (..., d) 76 | r: (...)
77 | """ 78 | # Busemann_p(x) = d means dist(0, x) = -d 79 | q = -torch.tanh(t / 2).unsqueeze(-1) * p 80 | c = (p + q) / 2.0 81 | r = torch.norm(p - q, dim=-1) / 2.0 82 | return c, r 83 | 84 | 85 | def sphere_intersection(c1, r1, c2, r2): 86 | """ Computes the intersections of a circle centered at ci of radius ri. 87 | 88 | c1, c2: (..., d) 89 | r1, r2: (...) 90 | 91 | Returns: 92 | center, radius such that the intersection of the two spheres is given by 93 | the intersection of the sphere (c, r) with the hyperplane orthogonal to the direction c1->c2 94 | """ 95 | 96 | d = torch.norm(c1 - c2, dim=-1) # (...) 97 | x, h = circle_intersection_(r1 / d.clamp_min(MIN_NORM), r2 / d.clamp_min(MIN_NORM)) # (...) 98 | x = x.unsqueeze(-1) 99 | center = x * c2 + (1 - x) * c1 # (..., d) 100 | radius = h * d # (...) 101 | return center, radius 102 | 103 | 104 | def sphere_intersections(c, r): 105 | """ Computes the intersection of k spheres in dimension d. 106 | 107 | c: list of centers (..., k, d) 108 | r: list of radii (..., k) 109 | 110 | Returns: 111 | center: (..., d) 112 | radius: (...) 113 | ortho_directions: (..., d, k-1) 114 | """ 115 | 116 | k = c.size(-2) 117 | assert k == r.size(-1) 118 | 119 | ortho_directions = [] 120 | center = c[..., 0, :] # (..., d) 121 | radius = r[..., 0] # (...) 122 | for i in range(1, k): 123 | center, radius = sphere_intersection(center, radius, c[..., i, :], r[..., i]) 124 | ortho_directions.append(c[..., i, :] - center) 125 | ortho_directions.append(torch.zeros_like(center)) # trick to handle the case k=1 126 | ortho_directions = torch.stack(ortho_directions, dim=-1) # (..., d, k-1) [last element is 0] 127 | return center, radius, ortho_directions 128 | 129 | 130 | # 2D projections 131 | def project_kd(p, x, keep_ambient=True): 132 | """ Project n points in dimension d onto 'direction' spanned by k ideal points 133 | p: (..., k, d) ideal points 134 | x: (..., n, d) points to project 135 | 136 | Returns: 137 | projection_1: (..., n, s) where s = d if keep_ambient==True otherwise s = k 138 | projection_2: same as projection_1. this is guaranteed to be the ideal point in the case k = 1 139 | p: the ideal points 140 | """ 141 | 142 | if len(p.shape) < 2: 143 | p = p.unsqueeze(0) 144 | if len(x.shape) < 2: 145 | x = x.unsqueeze(0) 146 | k = p.size(-2) 147 | d = x.size(-1) 148 | assert d == p.size(-1) 149 | busemann_distances = busemann(x.unsqueeze(-2), p.unsqueeze(-3), keepdim=False) # (..., n, k) 150 | c, r = busemann_to_horocycle(p.unsqueeze(-3), busemann_distances) # (..., n, k, d) (..., n, k) 151 | c, r, ortho = sphere_intersections(c, r) # (..., n, d) (..., n) (..., n, d, k-1) 152 | # we are looking for a vector spanned by the k ideal points, orthogonal to k-1 given vectors 153 | # i.e. 
x @ p @ ortho = 0 154 | if ortho is None: 155 | direction = torch.ones_like(busemann_distances) # (..., n, k) 156 | else: 157 | a = torch.matmul(p.unsqueeze(-3), ortho) # (..., n, k, k-1) = (..., n, k, d) @ (..., n, d, k-1) 158 | u, s, v = torch.svd(a, some=False) # a = u s v^T 159 | direction = u[..., -1] # (..., n, k) 160 | direction = direction @ p # (..., n, d) 161 | direction = direction / torch.norm(direction, dim=-1, keepdim=True).clamp_min(MIN_NORM) 162 | 163 | projection_1 = c - r.unsqueeze(-1) * direction 164 | projection_2 = c + r.unsqueeze(-1) * direction 165 | if not keep_ambient: 166 | _, _, v = torch.svd(p, some=False) # P = USV^T => PV = US so last d-k columns of PV are 0 167 | projection_1 = (projection_1 @ v)[..., :k] 168 | projection_2 = (projection_2 @ v)[..., :k] 169 | p = (p @ v)[..., :k] 170 | 171 | return projection_1, projection_2, p 172 | 173 | 174 | def project2d(p, q, x): 175 | # reconstruct p and q in 2D 176 | p_ = torch.stack([p.new_ones(p.shape[:-1]), p.new_zeros(p.shape[:-1])], dim=-1) 177 | cos = torch.sum(p * q, dim=-1) 178 | sin = torch.sqrt(1 - cos ** 2) 179 | q_ = torch.stack([cos, sin], dim=-1) 180 | bp = busemann(x, p).squeeze(-1) 181 | bq = busemann(x, q).squeeze(-1) 182 | c0, r0 = busemann_to_horocycle(p_, bp) 183 | c1, r1 = busemann_to_horocycle(q_, bq) 184 | reconstruction = circle_intersection(c0, c1, r0, r1) 185 | return reconstruction 186 | 187 | 188 | def horo_project_using_one_ideal(submanifold_ideals, x, custom_ideal_direction=None, keep_ambient=True): 189 | """Horospherical projection of x onto the submanifold spanned by the given ideal points, using a single ideal point. 190 | Args: 191 | submanifold_ideals: torch.tensor of shape (sub_dim, dim) 192 | x: torch.tensor of shape (batch_size, dim) 193 | custom_ideal_direction (optional): torch.tensor of shape (dim, ) 194 | keep_ambient: boolean 195 | 196 | Returns: 197 | if keep_ambient == True: 198 | torch.tensor of shape (batch_size, dim) 199 | else: 200 | torch.tensor of shape (batch_size, sub_dim): the rotated projections 201 | 202 | Note: 203 | custom_ideal_direction, if provided, must be a unit vector in the row span of submanifold_ideals 204 | submanifold_ideals must have independent rows 205 | The submanifold we are projecting onto always passes through the origin. 206 | """ 207 | 208 | if custom_ideal_direction is None: 209 | p = submanifold_ideals[0, :] 210 | p = p / torch.sqrt(p.dot(p)) 211 | else: 212 | p = custom_ideal_direction 213 | 214 | eucl_proj_coefs = torch.linalg.solve(submanifold_ideals @ submanifold_ideals.transpose(0, 1), 215 | submanifold_ideals @ x.transpose(0, 1)) # (sub_dim, batch_size); solves (S S^T) coefs = S x^T 216 | eucl_projs = eucl_proj_coefs.transpose(0, 1) @ submanifold_ideals # (batch_size, dim) 217 | 218 | t = torch.sum((p - x) * (p - x), dim=-1) / ( 219 | 2 * torch.sum((p - eucl_projs) * (p - x), dim=-1)) # shape (batch_size, ) 220 | t = t.unsqueeze(-1) 221 | 222 | output = 2 * t * eucl_projs + (1 - 2 * t) * p.unsqueeze(0) # shape (batch_size, dim) 223 | if keep_ambient: 224 | return output 225 | else: 226 | q, r = torch.linalg.qr(submanifold_ideals.transpose(0, 1)) # q.shape (dim, sub_dim) 227 | return output @ q 228 | 229 | 230 | def test_horo_project_one_ideal(): 231 | """ 232 | Sanity checks for horo_project_using_one_ideal: 233 | In this test, the first two input points are already in the submanifold, 234 | so they should project to themselves. The third input should not.
235 | """ 236 | 237 | print("Test Horo Projection with One Ideal:") 238 | print("------------------------------------") 239 | submanifold_ideals = torch.tensor([[1.0, 0, 0], [0.0, 1, 0]]) 240 | x = torch.tensor([[0.0, 0, 0], [0.7, 0.6, 0], [0.3, 0.4, 0.5]]) 241 | proj = horo_project_using_one_ideal(submanifold_ideals, x) 242 | print(x, proj) 243 | 244 | """ 245 | In the case sub_dim == 1, applying horo_project_using_one_ideal and then 246 | computing hyperbolic distance to the origin should give another implementation 247 | of the Busemann function, at least up to a sign convention 248 | """ 249 | import geom.poincare as poincare 250 | 251 | submanifold_ideals = torch.tensor([[1.0, 0, 0]]) 252 | x = torch.tensor([[0.0, 0, 0], [0.7, 0.6, 0], [0.3, 0.4, 0.5]]) 253 | proj = horo_project_using_one_ideal(submanifold_ideals, x, keep_ambient=False) 254 | print(poincare.distance0(proj)) 255 | print(busemann(x, submanifold_ideals)) 256 | 257 | 258 | if __name__ == "__main__": 259 | test_horo_project_one_ideal() 260 | -------------------------------------------------------------------------------- /geom/hyperboloid.py: -------------------------------------------------------------------------------- 1 | """Util functions for hyperboloid models 2 | Convention: The ambient Minkowski space has signature -1, 1, 1, ... 3 | i.e. the squared norm of (t,x,y,z) is -t^2 + x^2 + y^2 + z^2, 4 | And we are using the positive sheet, i.e. every point on the hyperboloid 5 | has positive first coordinate. 6 | """ 7 | import torch 8 | 9 | import geom.minkowski as minkowski 10 | import geom.poincare as poincare 11 | 12 | MIN_NORM = 1e-15 13 | 14 | 15 | def distance(x, y): 16 | """ 17 | Args: 18 | x, y: torch.tensor of the same shape (..., Minkowski_dim) 19 | 20 | Returns: 21 | torch.tensor of shape (..., ) 22 | """ 23 | # return torch.acosh(- minkowski.bilinear_pairing(x, y)) 24 | return torch.acosh(torch.clamp(- minkowski.bilinear_pairing(x, y), min=1.0)) 25 | 26 | 27 | def exp_unit_tangents(base_points, unit_tangents, distances): 28 | """Batched exponential map using the given base points, unit tangent directions, and distances 29 | 30 | Args: 31 | base_points, unit_tangents: torch.tensor of shape (..., Minkowski_dim) 32 | Each unit_tangents[j..., :] must have (Minkowski) squared norm 1 and is orthogonal to base_points[j..., :] 33 | distances: torch.tensor of shape (...) 34 | 35 | Returns: 36 | torch.tensor of shape (..., Minkowski_dim) 37 | """ 38 | distances = distances.unsqueeze(-1) 39 | return base_points * torch.cosh(distances) + unit_tangents * torch.sinh(distances) 40 | 41 | 42 | # def exp(base_points, tangents): 43 | # """Batched exponential map using the given base points and tangent vectors 44 | # 45 | # Args: 46 | # base_point, tangents: torch.tensor of shape (..., Minkowski_dim) 47 | # Each tangents[j..., :] must have squared norm > 0 and is orthogonal to base_points[j..., :] 48 | # 49 | # Returns: 50 | # torch.tensor of shape (..., Minkowski_dim) 51 | # """ 52 | # distances = torch.sqrt(minkowski.squared_norm(tangents)) # shape (...) 53 | # unit_tangets = tangents / distances.view(-1, 1) # shape (..., Minkowski_dim) 54 | # return exp_unit_tangents(base_point, unit_tangents, distances) 55 | 56 | 57 | def from_poincare(x, ideal=False): 58 | """Convert from Poincare ball model to hyperboloid model 59 | Args: 60 | x: torch.tensor of shape (..., dim) 61 | ideal: boolean. 
Should be True if the input vectors are ideal points, False otherwise 62 | 63 | Returns: 64 | torch.tensor of shape (..., dim+1) 65 | 66 | To do: 67 | Add some capping to make things numerically stable. This is only needed in the case ideal == False 68 | """ 69 | if ideal: 70 | t = torch.ones(x.shape[:-1], device=x.device).unsqueeze(-1) 71 | return torch.cat((t, x), dim=-1) 72 | else: 73 | eucl_squared_norm = (x * x).sum(dim=-1, keepdim=True) 74 | return torch.cat((1 + eucl_squared_norm, 2 * x), dim=-1) / (1 - eucl_squared_norm).clamp_min(MIN_NORM) 75 | 76 | 77 | def to_poincare(x, ideal=False): 78 | """Convert from hyperboloid model to Poincare ball model 79 | Args: 80 | x: torch.tensor of shape (..., Minkowski_dim), where Minkowski_dim >= 3 81 | ideal: boolean. Should be True if the input vectors are ideal points, False otherwise 82 | 83 | Returns: 84 | torch.tensor of shape (..., Minkowski_dim - 1) 85 | """ 86 | if ideal: 87 | return x[..., 1:] / (x[..., 0].unsqueeze(-1)).clamp_min(MIN_NORM) 88 | else: 89 | return x[..., 1:] / (1 + x[..., 0].unsqueeze(-1)).clamp_min(MIN_NORM) 90 | 91 | 92 | def decision_boundary_to_poincare(minkowski_normal_vec): 93 | """Convert the totally geodesic submanifold defined by the Minkowski normal vector to Poincare ball model 94 | (Here the Minkowski normal vector defines a linear subspace, which intersects the hyperboloid at our submanifold) 95 | 96 | Args: 97 | minkowski_normal_vec: torch.tensor of shape (Minkowski_dim, ) 98 | 99 | Returns: 100 | center: torch.tensor of shape (Minkowski_dim -1, ) 101 | radius: float 102 | 103 | Warning: 104 | minkowski_normal_vec must have positive squared norm 105 | minkowski_normal_vec[0] must be nonzero (otherwise the submanifold is a flat plane through the origin) 106 | """ 107 | x = minkowski_normal_vec 108 | # poincare_origin = [1,0,0,0,...], # shape (Minkowski_dim, ) 109 | poincare_origin = torch.zeros(minkowski_normal_vec.shape[0], device=minkowski_normal_vec.device) 110 | poincare_origin[0] = 1 111 | 112 | # shape (1, Minkowski_dim) 113 | poincare_origin_reflected = minkowski.reflection(minkowski_normal_vec, poincare_origin.unsqueeze(0)) 114 | 115 | # shape (Minkowski_dim-1, ) 116 | origin_reflected = to_poincare(poincare_origin_reflected).squeeze(0) 117 | center = poincare.reflection_center(origin_reflected) 118 | 119 | radius = torch.sqrt(torch.sum(center ** 2) - 1) 120 | 121 | return center, radius 122 | 123 | 124 | def orthogonal_projection(basis, x): 125 | """Compute the orthogonal projection of x onto the geodesic submanifold 126 | spanned by the given basis vectors (i.e. the intersection of the hyperboloid with 127 | the Euclidean linear subspace spanned by the basis vectors). 128 | 129 | Args: 130 | basis: torch.tensor of shape(num_basis, Minkowski_dim) 131 | x: torch.tensor of shape(batch_size, Minkowski_dim) 132 | 133 | Returns: 134 | torch.tensor of shape(batch_size, Minkowski_dim) 135 | 136 | Conditions: 137 | Each basis vector must have non-positive Minkowski squared norms. 138 | There must be at least 2 basis vectors. 139 | The basis vectors must be linearly independent. 140 | """ 141 | minkowski_proj = minkowski.orthogonal_projection(basis, x) # shape (batch_size, Minkowski_dim) 142 | squared_norms = minkowski.squared_norm(minkowski_proj) # shape (batch_size, ) 143 | return minkowski_proj / torch.sqrt(- squared_norms.unsqueeze(1)) 144 | 145 | 146 | def horo_projection(ideals, x): 147 | """Compute the projection based on horosphere intersections. 
148 |     The target submanifold has dimension num_ideals and is a geodesic submanifold passing through
149 |     the ideal points and (1,0,0,0,...), i.e. the point corresponding to the origin in the Poincare model.
150 | 
151 |     Args:
152 |         ideals: torch.tensor of shape (num_ideals, Minkowski_dim)
153 |             num_ideals must be STRICTLY between 1 and Minkowski_dim
154 |             ideal vectors must be independent
155 |             the geodesic submanifold spanned by ideals must not contain (1,0,0,...)
156 | 
157 |         x: torch.tensor of shape (batch_size, Minkowski_dim)
158 | 
159 | 
160 |     Returns:
161 |         torch.tensor of shape (batch_size, Minkowski_dim)
162 |     """
163 | 
164 |     # Compute orthogonal (geodesic) projection from x to the geodesic submanifold spanned by ideals
165 |     # We call this submanifold the "spine" because of the "open book" intuition
166 |     spine_ortho_proj = orthogonal_projection(ideals, x)  # shape (batch_size, Minkowski_dim)
167 |     spine_dist = distance(spine_ortho_proj, x)  # shape (batch_size, )
168 | 
169 |     # poincare_origin = [1,0,0,0,...], # shape (Minkowski_dim, )
170 |     poincare_origin = torch.zeros(x.shape[1], device=x.device)
171 |     poincare_origin[0] = 1
172 | 
173 |     # Find a tangent vector of the hyperboloid at spine_ortho_proj that is tangent to the target submanifold
174 |     # and orthogonal to the spine.
175 |     # This is done in a Gram-Schmidt way: Take the Euclidean vector pointing from spine_ortho_proj to poincare_origin,
176 |     # then subtract a projection part so that it is orthogonal to the spine and tangent to the hyperboloid
177 |     # Everything below has shape (batch_size, Minkowski_dim)
178 |     chords = poincare_origin - spine_ortho_proj
179 |     tangents = chords - minkowski.orthogonal_projection(ideals, chords)
180 |     unit_tangents = tangents / torch.sqrt(minkowski.squared_norm(tangents)).view(-1, 1)
181 | 
182 |     proj_1 = exp_unit_tangents(spine_ortho_proj, unit_tangents, spine_dist)
183 |     proj_2 = exp_unit_tangents(spine_ortho_proj, unit_tangents, -spine_dist)
184 | 
185 |     return proj_1, proj_2
186 | 
187 | 
188 | def mds(D, d):
189 |     """
190 |     Args:
191 |         D - (..., n, n) distance matrix
192 | 
193 |     Returns:
194 |         X - (..., n, d) hyperbolic embeddings
195 |     """
196 |     Y = -torch.cosh(D)
197 |     # print("Y:", Y)
198 |     eigenvals, eigenvecs = torch.linalg.eigh(Y)  # torch.symeig was removed from PyTorch; linalg.eigh also returns ascending eigenvalues
199 |     # print(Y.shape, eigenvals.shape, eigenvecs.shape)
200 |     # print(eigenvals, eigenvecs)
201 |     X = torch.sqrt(torch.clamp(eigenvals[..., -d:], min=0.)) * eigenvecs[..., -d:]  # top-d eigenpairs; [..., -d:] keeps any batch dims intact
202 |     # print("testing")
203 |     # print(X)
204 |     # print(Y @ X)
205 |     u = torch.sqrt(1 + torch.sum(X * X, dim=-1, keepdim=True))
206 |     M = torch.cat((u, X), dim=-1)
207 |     # print(minkowski.pairwise_bilinear_pairing(M, M))
208 |     return M
209 | 
210 | 
211 | def test():
212 |     ideal = torch.tensor([[1.0, 0, 0, 0], [0.0, 1, 0, 0]])
213 |     x = torch.tensor([[0.2, 0.3, 0.4, 0.5], [0.0, 0, 0, 0], [0.0, 0, 0, 0.7]])
214 |     loid_ideal, loid_x = from_poincare(ideal, True), from_poincare(x)
215 |     loid_p1, loid_p2 = horo_projection(loid_ideal, loid_x)
216 |     pr1, pr2 = to_poincare(loid_p1), to_poincare(loid_p2)
217 |     print(pr1)
218 |     print(pr2)
219 | 
220 |     # ideals = torch.tensor([[3.0,3.0,0.0], [5.0,-5.0,0.0]])
221 |     # x = torch.tensor([[5.0,0.0,math.sqrt(24)],[2.0,-math.sqrt(3), 0]])
222 |     # print(orthogonal_projection(ideals, x))
223 | 
224 |     ideals = torch.tensor([[1.0, 1.0, 0.0], [5.0, 3, 4]])
225 |     x = torch.tensor([[5.0, 0, 24 ** 0.5], [2.0, - 3 ** 0.5, 0]])
226 |     print(horo_projection(ideals, x))
227 | 
228 | 
229 | def test_mds(n=100, d=10):
230 |     X = torch.randn(n, d)
231 |     X = X / torch.norm(X, dim=-1, keepdim=True) * 0.9
232 |     X = from_poincare(X)
233 |     # print(X.shape)
234 |     D = distance(X.unsqueeze(-2), X.unsqueeze(-3))
235 |     # print(D.shape)
236 |     # print(D-D.transpose(0,1))
237 | 
238 |     X_ = mds(D, d)
239 |     # print(X_.shape)
240 |     D_ = distance(X_.unsqueeze(-2), X_.unsqueeze(-3))
241 |     print(D - D_)
242 | 
243 | 
244 | def test_projection():
245 |     """ Test that orthogonal projection agrees with the Poincare disk version. """
246 |     d = 5
247 |     # x = torch.randn(1, d) * 0.01
248 |     x = poincare.random_points((1, d))
249 |     # Q = torch.randn(2, d)
250 |     # Q = Q / torch.norm(Q, dim=-1, keepdim=True)
251 |     Q = poincare.random_ideals((2, d))
252 | 
253 |     # poincare projection
254 |     import geom.poincare as P
255 |     from geom.euclidean import orthonormal, reflect
256 |     Q = orthonormal(Q)
257 |     x_r = reflect(x, Q)  # Euclidean reflection from geom.euclidean (geom.poincare defines no reflect)
258 |     p = P.midpoint(x, x_r)
259 |     print(p)
260 | 
261 |     # hyperboloid projection
262 |     Q = torch.cat([Q, torch.zeros(1, d)], dim=0)
263 |     p_ = orthogonal_projection(from_poincare(Q, ideal=True), from_poincare(x))
264 |     print(to_poincare(p_))
265 | 
266 | 
267 | # Sanity checks
268 | if __name__ == "__main__":
269 |     # test()
270 |     # test_mds(n=100, d=10)
271 | 
272 |     poincare_origin = torch.zeros(3)
273 |     poincare_origin[0] = 1
274 |     print(from_poincare(poincare_origin, ideal=True))
275 |     print(to_poincare(from_poincare(poincare_origin, ideal=True), ideal=True))
276 | 
277 |     test_projection()
278 | 
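279 | 
280 | # NOTE: added sketch, not part of the original file. Round-trip and isometry checks:
281 | # to_poincare(from_poincare(x)) should recover x, and the hyperboloid distance should
282 | # match the Poincare ball distance for the same pair of points.
283 | def test_round_trip(n=4, d=3):
284 |     x = poincare.random_points((n, d))
285 |     x_loid = from_poincare(x)
286 |     assert torch.allclose(to_poincare(x_loid), x, atol=1e-5)
287 |     d_loid = distance(x_loid[:1], x_loid[1:2])
288 |     d_ball = poincare.distance(x[:1], x[1:2], keepdim=False)
289 |     print(d_loid, d_ball)  # these should agree up to numerical error
290 | 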
--------------------------------------------------------------------------------
/geom/minkowski.py:
--------------------------------------------------------------------------------
1 | """ Util functions for the Minkowski metric.
2 | 
3 | Note that functions for the hyperboloid model itself are in geom.hyperboloid
4 | 
5 | Most functions in this file have a bilinear_form argument that can generally be ignored.
6 | That argument is there just in case we need to use a non-standard norm/signature.
7 | """
8 | import torch
9 | 
10 | 
11 | def product(x, y):
12 |     eucl_pairing = torch.sum(x * y, dim=-1, keepdim=False)
13 |     return 2 * x[..., 0] * y[..., 0] - eucl_pairing  # note: equals -bilinear_pairing(x, y)
14 | 
15 | 
16 | def bilinear_pairing(x, y, bilinear_form=None):
17 |     """Compute the bilinear pairing (i.e. "dot product") of x and y using the given bilinear form.
18 |     If bilinear_form is not provided, use the default Minkowski form,
19 |     i.e. (x0, x1, x2) dot (y0, y1, y2) = -x0*y0 + x1*y1 + x2*y2
20 | 
21 |     Args:
22 |         x, y: torch.tensor of the same shape (..., dim), where dim >= 2
23 |         bilinear_form (optional): torch.tensor of shape (dim, dim)
24 | 
25 |     Returns:
26 |         torch.tensor of shape (...)
27 |     """
28 |     if bilinear_form is None:
29 |         eucl_pairing = torch.sum(x * y, dim=-1, keepdim=False)
30 |         return eucl_pairing - 2 * x[..., 0] * y[..., 0]
31 |     else:
32 |         pairing = torch.matmul(x.unsqueeze(-2), (y @ bilinear_form).unsqueeze(-1))  # shape (..., 1, 1)
33 |         return pairing.reshape(x.shape[:-1])
34 | 
35 | 
36 | def squared_norm(x, bilinear_form=None):
37 |     return bilinear_pairing(x, x, bilinear_form)
38 | 
39 | 
40 | def pairwise_bilinear_pairing(x, y, bilinear_form=None):
41 |     """Compute the pairwise bilinear pairings (i.e. "dot product") of two lists of vectors
42 |     with respect to the given bilinear form.
43 |     If bilinear_form is not provided, use the default Minkowski form,
44 |     i.e. (x0, x1, x2) dot (y0, y1, y2) = -x0*y0 + x1*y1 + x2*y2
45 | 
46 |     Args:
47 |         x: torch.tensor of shape (..., M, dim), where dim >= 2
48 |         y: torch.tensor of shape (..., N, dim), where dim >= 2
49 |         bilinear_form (optional): torch.tensor of shape (dim, dim).
50 | 
51 |     Returns:
52 |         torch.tensor of shape (..., M, N)
53 |     """
54 |     if bilinear_form is None:
55 |         return x @ y.transpose(-1, -2) - 2 * x[..., 0].unsqueeze(-1) * y[..., 0].unsqueeze(-2)  # broadcasted outer product; torch.ger was deprecated and assumed 2D inputs
56 |     else:
57 |         return x @ bilinear_form @ y.transpose(-1, -2)
58 | 
59 | 
60 | def orthogonal_projection(basis, x, bilinear_form=None):
61 |     """Compute the orthogonal projection of x onto the vector subspace spanned by basis.
62 |     Here orthogonality is defined using the given bilinear_form
63 |     If bilinear_form is not provided, use the default Minkowski form,
64 |     i.e. (x0, x1, x2) dot (y0, y1, y2) = -x0*y0 + x1*y1 + x2*y2
65 | 
66 |     Args:
67 |         basis: torch.tensor of shape (subspace_dim, dim), where dim >= 2
68 |         x: torch.tensor of shape (batch_size, dim), where dim >= 2
69 |         bilinear_form (optional): torch.tensor of shape (dim, dim).
70 | 
71 |     Returns:
72 |         torch.tensor of shape (batch_size, dim)
73 | 
74 |     Warning: Will not work if the linear subspace spanned by basis is tangent to the light cone.
75 |         (In that case, the orthogonal projection is not unique)
76 |     """
77 |     coefs = torch.linalg.solve(pairwise_bilinear_pairing(basis, basis, bilinear_form),
78 |                                pairwise_bilinear_pairing(basis, x, bilinear_form))  # torch.solve was removed from PyTorch; linalg.solve(A, B) solves A X = B
79 | 
80 |     return coefs.transpose(-1, -2) @ basis
81 | 
82 | 
83 | def reflection(subspace, x, subspace_given_by_normal=True, bilinear_form=None):
84 |     """Compute the reflection of x through a linear subspace (of dimension 1 less than the ambient space)
85 |     Here reflection is defined using the notion of orthogonality coming from the given bilinear_form
86 |     If bilinear_form is not provided, use the default Minkowski form,
87 |     i.e. (x0, x1, x2) dot (y0, y1, y2) = -x0*y0 + x1*y1 + x2*y2
88 | 
89 |     Args:
90 |         subspace: If subspace_given_by_normal:
91 |                       torch.tensor of shape (dim, ), representing a normal vector to the subspace
92 |                   Else:
93 |                       torch.tensor of shape (dim-1, dim), representing a basis of the subspace
94 |         x: torch.tensor of shape (batch_size, dim)
95 |         bilinear_form (optional): torch.tensor of shape (dim, dim).
96 | 
97 |     Returns:
98 |         torch.tensor of shape (batch_size, dim)
99 | 
100 |     Warning: Will not work if the linear subspace is tangent to the light cone.
101 |         (In that case, the reflection is not unique)
102 |     """
103 |     if subspace_given_by_normal:
104 |         return x - 2 * orthogonal_projection(subspace.unsqueeze(0), x, bilinear_form)
105 |     else:
106 |         return 2 * orthogonal_projection(subspace, x, bilinear_form) - x
107 | 
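108 | 
109 | # NOTE: added sketch, not part of the original file. A reflection through a spacelike
110 | # normal should be an involution and should preserve the Minkowski squared norm.
111 | def test_reflection():
112 |     normal = torch.tensor([0.2, 1.0, 0.5])  # squared norm -0.04 + 1 + 0.25 > 0 (spacelike)
113 |     x = torch.randn(4, 3)
114 |     x_r = reflection(normal, x)
115 |     assert torch.allclose(reflection(normal, x_r), x, atol=1e-5)
116 |     assert torch.allclose(squared_norm(x_r), squared_norm(x), atol=1e-5)
117 | 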
--------------------------------------------------------------------------------
/geom/nn.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | 
7 | from . import pmath
8 | 
9 | 
10 | class HyperbolicMLR(nn.Module):
11 |     r"""
12 |     Module which performs softmax classification
13 |     in Hyperbolic space.
14 |     """
15 | 
16 |     def __init__(self, ball_dim, n_classes, c):
17 |         super(HyperbolicMLR, self).__init__()
18 |         self.a_vals = nn.Parameter(torch.Tensor(n_classes, ball_dim))
19 |         self.p_vals = nn.Parameter(torch.Tensor(n_classes, ball_dim))
20 |         self.c = c
21 |         self.n_classes = n_classes
22 |         self.ball_dim = ball_dim
23 |         self.reset_parameters()
24 | 
25 |     def forward(self, x, c=None):
26 |         if c is None:
27 |             c = torch.as_tensor(self.c).type_as(x)
28 |         else:
29 |             c = torch.as_tensor(c).type_as(x)
30 |         p_vals_poincare = pmath.expmap0(self.p_vals, c=c)
31 |         conformal_factor = 1 - c * p_vals_poincare.pow(2).sum(dim=1, keepdim=True)
32 |         a_vals_poincare = self.a_vals * conformal_factor
33 |         logits = pmath._hyperbolic_softmax(x, a_vals_poincare, p_vals_poincare, c)
34 |         return logits
35 | 
36 |     def extra_repr(self):
37 |         return "Poincare ball dim={}, n_classes={}, c={}".format(
38 |             self.ball_dim, self.n_classes, self.c
39 |         )
40 | 
41 |     def reset_parameters(self):
42 |         init.kaiming_uniform_(self.a_vals, a=math.sqrt(5))
43 |         init.kaiming_uniform_(self.p_vals, a=math.sqrt(5))
44 | 
45 | class HypLinear(nn.Module):
46 |     def __init__(self, in_features, out_features, c, bias=True):
47 |         super(HypLinear, self).__init__()
48 |         self.in_features = in_features
49 |         self.out_features = out_features
50 |         self.c = c
51 |         self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
52 |         if bias:
53 |             self.bias = nn.Parameter(torch.Tensor(out_features))
54 |         else:
55 |             self.register_parameter("bias", None)
56 |         self.reset_parameters()
57 | 
58 |     def reset_parameters(self):
59 |         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
60 |         if self.bias is not None:
61 |             fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
62 |             bound = 1 / math.sqrt(fan_in)
63 |             init.uniform_(self.bias, -bound, bound)
64 | 
65 |     def forward(self, x, c=None):
66 |         if c is None:
67 |             c = self.c
68 |         mv = pmath.mobius_matvec(self.weight, x, c=c)
69 |         if self.bias is None:
70 |             return pmath.project(mv, c=c)
71 |         else:
72 |             bias = pmath.expmap0(self.bias, c=c)
73 |             return pmath.project(pmath.mobius_add(mv, bias, c=c), c=c)  # pass c explicitly; mobius_add otherwise silently defaults to c=1
74 | 
75 |     def extra_repr(self):
76 |         return "in_features={}, out_features={}, bias={}, c={}".format(
77 |             self.in_features, self.out_features, self.bias is not None, self.c
78 |         )
79 | 
80 | class HypDistanceLinearLayer(nn.Module):
81 |     def __init__(self, in_features, out_features, c, train_c=False, train_x=False, riemannian=False, clip_r=None):
82 |         super(HypDistanceLinearLayer, self).__init__()
83 |         self.in_features = in_features
84 |         self.out_features = out_features
85 |         self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features))
86 |         self.to_poincare = ToPoincare(c, train_c, train_x, in_features, riemannian, clip_r)
87 |         self.c = self.to_poincare.c
88 |         self.w_scale = 1
89 |         self.reset_parameters()
90 | 
91 |     def reset_parameters(self):
92 |         init.kaiming_uniform_(self.weight, a=math.sqrt(5))
93 | 
94 |     def forward(self, x, c=None):
95 |         # maxnorm = (1 - 1e-5) / (self.c ** 0.5)
96 |         # w_norm = torch.norm(self.weight, dim=-1, keepdim=True) + 1e-5
97 |         # weight = self.to_poincare(self.weight / w_norm * maxnorm * self.w_scale)
98 | 
99 |         # weight = F.normalize(self.weight, dim=-1)
100 |         weight = self.to_poincare(self.weight)
101 |         mv = pmath.dist_matrix(x, weight, c=self.c)
102 |         return -mv
103 | 
104 | def parent_order_penalty_cdist(parent, child, mrg, c=None):
105 |     """Penalty encouraging each parent to lie closer to the origin (smaller hyperbolic norm) than its children, by margin mrg."""
106 |     return torch.clip(pmath._dist0(parent, c=c, keepdim=True).t() - 
pmath._dist0(child, c=c, keepdim=True) + mrg, 0) + 1.0 107 | 108 | class HypHCLinearLayer(nn.Module): 109 | def __init__(self, in_features, out_features, c, train_c=False, train_x=False, riemannian=False, clip_r=None): 110 | super(HypHCLinearLayer, self).__init__() 111 | self.in_features = in_features 112 | self.out_features = out_features 113 | self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features)) 114 | self.to_poincare = ToPoincare(c, train_c, train_x, in_features, riemannian, clip_r) 115 | self.c = self.to_poincare.c 116 | self.gamma = 0.25 117 | self.reset_parameters() 118 | 119 | def reset_parameters(self): 120 | nn.init.orthogonal_(self.weight) 121 | 122 | def p_par_to_broadcast(self, x, weight): 123 | weight = self.to_poincare(self.weight) 124 | dist = pmath.dist_matrix(x, weight, c=self.c) 125 | res = dist * parent_order_penalty_cdist(weight, x, self.gamma, c=self.c) 126 | return res 127 | 128 | def forward(self, x, c=None): 129 | dist = self.p_par_to_broadcast(x, self.weight) 130 | return dist 131 | 132 | 133 | class ConcatPoincareLayer(nn.Module): 134 | def __init__(self, d1, d2, d_out, c): 135 | super(ConcatPoincareLayer, self).__init__() 136 | self.d1 = d1 137 | self.d2 = d2 138 | self.d_out = d_out 139 | 140 | self.l1 = HypLinear(d1, d_out, bias=False, c=c) 141 | self.l2 = HypLinear(d2, d_out, bias=False, c=c) 142 | self.c = c 143 | 144 | def forward(self, x1, x2, c=None): 145 | if c is None: 146 | c = self.c 147 | return pmath.mobius_add(self.l1(x1), self.l2(x2), c=c) 148 | 149 | def extra_repr(self): 150 | return "dims {} and {} ---> dim {}".format(self.d1, self.d2, self.d_out) 151 | 152 | 153 | class HyperbolicDistanceLayer(nn.Module): 154 | def __init__(self, c): 155 | super(HyperbolicDistanceLayer, self).__init__() 156 | self.c = c 157 | 158 | def forward(self, x1, x2, c=None): 159 | if c is None: 160 | c = self.c 161 | return pmath.dist(x1, x2, c=c, keepdim=True) 162 | 163 | def extra_repr(self): 164 | return "c={}".format(self.c) 165 | 166 | 167 | class ToPoincare(nn.Module): 168 | r""" 169 | Module which maps points in n-dim Euclidean space 170 | to n-dim Poincare ball 171 | Also implements clipping from https://arxiv.org/pdf/2107.11472.pdf 172 | """ 173 | 174 | def __init__(self, c, train_c=False, train_x=False, ball_dim=None, riemannian=True, clip_r=None): 175 | super(ToPoincare, self).__init__() 176 | if train_x: 177 | if ball_dim is None: 178 | raise ValueError( 179 | "if train_x=True, ball_dim has to be integer, got {}".format( 180 | ball_dim 181 | ) 182 | ) 183 | self.xp = nn.Parameter(torch.zeros((ball_dim,))) 184 | else: 185 | self.register_parameter("xp", None) 186 | 187 | if train_c: 188 | self.c = nn.Parameter(torch.Tensor([c,])) 189 | else: 190 | self.c = c 191 | 192 | self.train_x = train_x 193 | 194 | self.riemannian = pmath.RiemannianGradient 195 | self.riemannian.c = c 196 | 197 | self.clip_r = clip_r 198 | 199 | if riemannian: 200 | self.grad_fix = lambda x: self.riemannian.apply(x) 201 | else: 202 | self.grad_fix = lambda x: x 203 | 204 | def forward(self, x): 205 | if self.clip_r is not None: 206 | x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5 207 | fac = torch.minimum( 208 | torch.ones_like(x_norm), 209 | self.clip_r / x_norm 210 | ) 211 | x = x * fac 212 | 213 | if self.train_x: 214 | xp = pmath.project(pmath.expmap0(self.xp, c=self.c), c=self.c) 215 | return self.grad_fix(pmath.project(pmath.expmap(xp, x, c=self.c), c=self.c)) 216 | return self.grad_fix(pmath.project(pmath.expmap0(x, c=self.c), c=self.c)) 217 | 218 | def 
extra_repr(self): 219 | return "c={}, train_x={}".format(self.c, self.train_x) 220 | 221 | 222 | class FromPoincare(nn.Module): 223 | r""" 224 | Module which maps points in n-dim Poincare ball 225 | to n-dim Euclidean space 226 | """ 227 | 228 | def __init__(self, c, train_c=False, train_x=False, ball_dim=None): 229 | 230 | super(FromPoincare, self).__init__() 231 | 232 | if train_x: 233 | if ball_dim is None: 234 | raise ValueError( 235 | "if train_x=True, ball_dim has to be integer, got {}".format( 236 | ball_dim 237 | ) 238 | ) 239 | self.xp = nn.Parameter(torch.zeros((ball_dim,))) 240 | else: 241 | self.register_parameter("xp", None) 242 | 243 | if train_c: 244 | self.c = nn.Parameter(torch.Tensor([c,])) 245 | else: 246 | self.c = c 247 | 248 | self.train_c = train_c 249 | self.train_x = train_x 250 | 251 | def forward(self, x): 252 | if self.train_x: 253 | xp = pmath.project(pmath.expmap0(self.xp, c=self.c), c=self.c) 254 | return pmath.logmap(xp, x, c=self.c) 255 | return pmath.logmap0(x, c=self.c) 256 | 257 | def extra_repr(self): 258 | return "train_c={}, train_x={}".format(self.train_c, self.train_x) 259 | 260 | -------------------------------------------------------------------------------- /geom/pmath.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of various mathematical operations in the Poincare ball model of hyperbolic space. Some 3 | functions are based on the implementation in https://github.com/geoopt/geoopt (copyright by Maxim Kochurov). 4 | """ 5 | 6 | import numpy as np 7 | import torch 8 | from scipy.special import gamma 9 | 10 | 11 | def tanh(x, clamp=15): 12 | return x.clamp(-clamp, clamp).tanh() 13 | 14 | 15 | # + 16 | class Artanh(torch.autograd.Function): 17 | @staticmethod 18 | def forward(ctx, x): 19 | x = x.clamp(-1 + 1e-5, 1 - 1e-5) 20 | ctx.save_for_backward(x) 21 | res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5) 22 | return res 23 | 24 | @staticmethod 25 | def backward(ctx, grad_output): 26 | (input,) = ctx.saved_tensors 27 | return grad_output / (1 - input ** 2) 28 | 29 | 30 | class RiemannianGradient(torch.autograd.Function): 31 | 32 | c = 1 33 | 34 | @staticmethod 35 | def forward(ctx, x): 36 | ctx.save_for_backward(x) 37 | return x 38 | 39 | @staticmethod 40 | def backward(ctx, grad_output): 41 | (x,) = ctx.saved_tensors 42 | # x: B x d 43 | 44 | scale = (1 - RiemannianGradient.c * x.pow(2).sum(-1, keepdim=True)).pow(2) / 4 45 | return grad_output * scale 46 | 47 | 48 | class Arsinh(torch.autograd.Function): 49 | @staticmethod 50 | def forward(ctx, x): 51 | ctx.save_for_backward(x) 52 | return (x + torch.sqrt_(1 + x.pow(2))).clamp_min_(1e-5).log_() 53 | 54 | @staticmethod 55 | def backward(ctx, grad_output): 56 | (input,) = ctx.saved_tensors 57 | return grad_output / (1 + input ** 2) ** 0.5 58 | 59 | 60 | def artanh(x): 61 | return Artanh.apply(x) 62 | 63 | 64 | def arsinh(x): 65 | return Arsinh.apply(x) 66 | 67 | 68 | def arcosh(x, eps=1e-5): # pragma: no cover 69 | x = x.clamp(-1 + eps, 1 - eps) 70 | return torch.log(x + torch.sqrt(1 + x) * torch.sqrt(x - 1)) 71 | 72 | 73 | def project(x, *, c=1.0): 74 | r""" 75 | Safe projection on the manifold for numerical stability. This was mentioned in [1]_ 76 | Parameters 77 | ---------- 78 | x : tensor 79 | point on the Poincare ball 80 | c : float|tensor 81 | ball negative curvature 82 | Returns 83 | ------- 84 | tensor 85 | projected vector on the manifold 86 | References 87 | ---------- 88 | .. 
[1] Hyperbolic Neural Networks, NIPS2018 89 | https://arxiv.org/abs/1805.09112 90 | """ 91 | c = torch.as_tensor(c).type_as(x) 92 | return _project(x, c) 93 | 94 | 95 | def _project(x, c): 96 | norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5) 97 | maxnorm = (1 - 1e-3) / (c ** 0.5) 98 | cond = norm > maxnorm 99 | projected = x / norm * maxnorm 100 | return torch.where(cond, projected, x) 101 | 102 | 103 | def lambda_x(x, *, c=1.0, keepdim=False): 104 | r""" 105 | Compute the conformal factor :math:`\lambda^c_x` for a point on the ball 106 | .. math:: 107 | \lambda^c_x = \frac{1}{1 - c \|x\|_2^2} 108 | Parameters 109 | ---------- 110 | x : tensor 111 | point on the Poincare ball 112 | c : float|tensor 113 | ball negative curvature 114 | keepdim : bool 115 | retain the last dim? (default: false) 116 | Returns 117 | ------- 118 | tensor 119 | conformal factor 120 | """ 121 | c = torch.as_tensor(c).type_as(x) 122 | return _lambda_x(x, c, keepdim=keepdim) 123 | 124 | 125 | def _lambda_x(x, c, keepdim: bool = False): 126 | return 2 / (1 - c * x.pow(2).sum(-1, keepdim=keepdim)) 127 | 128 | 129 | def mobius_add(x, y, *, c=1.0): 130 | r""" 131 | Mobius addition is a special operation in a hyperbolic space. 132 | .. math:: 133 | x \oplus_c y = \frac{ 134 | (1 + 2 c \langle x, y\rangle + c \|y\|^2_2) x + (1 - c \|x\|_2^2) y 135 | }{ 136 | 1 + 2 c \langle x, y\rangle + c^2 \|x\|^2_2 \|y\|^2_2 137 | } 138 | In general this operation is not commutative: 139 | .. math:: 140 | x \oplus_c y \ne y \oplus_c x 141 | But in some cases this property holds: 142 | * zero vector case 143 | .. math:: 144 | \mathbf{0} \oplus_c x = x \oplus_c \mathbf{0} 145 | * zero negative curvature case that is same as Euclidean addition 146 | .. math:: 147 | x \oplus_0 y = y \oplus_0 x 148 | Another usefull property is so called left-cancellation law: 149 | .. math:: 150 | (-x) \oplus_c (x \oplus_c y) = y 151 | Parameters 152 | ---------- 153 | x : tensor 154 | point on the Poincare ball 155 | y : tensor 156 | point on the Poincare ball 157 | c : float|tensor 158 | ball negative curvature 159 | Returns 160 | ------- 161 | tensor 162 | the result of mobius addition 163 | """ 164 | c = torch.as_tensor(c).type_as(x) 165 | return _mobius_add(x, y, c) 166 | 167 | 168 | def _mobius_add(x, y, c): 169 | x2 = x.pow(2).sum(dim=-1, keepdim=True) 170 | y2 = y.pow(2).sum(dim=-1, keepdim=True) 171 | xy = (x * y).sum(dim=-1, keepdim=True) 172 | num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y 173 | denom = 1 + 2 * c * xy + c ** 2 * x2 * y2 174 | return num / (denom + 1e-5) 175 | 176 | 177 | def dist(x, y, *, c=1.0, keepdim=False): 178 | r""" 179 | Distance on the Poincare ball 180 | .. math:: 181 | d_c(x, y) = \frac{2}{\sqrt{c}}\tanh^{-1}(\sqrt{c}\|(-x)\oplus_c y\|_2) 182 | .. plot:: plots/extended/poincare/distance.py 183 | Parameters 184 | ---------- 185 | x : tensor 186 | point on poincare ball 187 | y : tensor 188 | point on poincare ball 189 | c : float|tensor 190 | ball negative curvature 191 | keepdim : bool 192 | retain the last dim? 
(default: false) 193 | Returns 194 | ------- 195 | tensor 196 | geodesic distance between :math:`x` and :math:`y` 197 | """ 198 | c = torch.as_tensor(c).type_as(x) 199 | return _dist(x, y, c, keepdim=keepdim) 200 | 201 | 202 | def _dist(x, y, c, keepdim: bool = False): 203 | sqrt_c = c ** 0.5 204 | dist_c = artanh(sqrt_c * _mobius_add(-x, y, c).norm(dim=-1, p=2, keepdim=keepdim)) 205 | return dist_c * 2 / sqrt_c 206 | 207 | 208 | def dist0(x, *, c=1.0, keepdim=False): 209 | r""" 210 | Distance on the Poincare ball to zero 211 | Parameters 212 | ---------- 213 | x : tensor 214 | point on poincare ball 215 | c : float|tensor 216 | ball negative curvature 217 | keepdim : bool 218 | retain the last dim? (default: false) 219 | Returns 220 | ------- 221 | tensor 222 | geodesic distance between :math:`x` and :math:`0` 223 | """ 224 | c = torch.as_tensor(c).type_as(x) 225 | return _dist0(x, c, keepdim=keepdim) 226 | 227 | 228 | def _dist0(x, c, keepdim: bool = False): 229 | sqrt_c = c ** 0.5 230 | dist_c = artanh(sqrt_c * x.norm(dim=-1, p=2, keepdim=keepdim)) 231 | return dist_c * 2 / sqrt_c 232 | 233 | 234 | def expmap(x, u, *, c=1.0): 235 | r""" 236 | Exponential map for Poincare ball model. This is tightly related with :func:`geodesic`. 237 | Intuitively Exponential map is a smooth constant travelling from starting point :math:`x` with speed :math:`u`. 238 | A bit more formally this is travelling along curve :math:`\gamma_{x, u}(t)` such that 239 | .. math:: 240 | \gamma_{x, u}(0) = x\\ 241 | \dot\gamma_{x, u}(0) = u\\ 242 | \|\dot\gamma_{x, u}(t)\|_{\gamma_{x, u}(t)} = \|u\|_x 243 | The existence of this curve relies on uniqueness of differential equation solution, that is local. 244 | For the Poincare ball model the solution is well defined globally and we have. 245 | .. math:: 246 | \operatorname{Exp}^c_x(u) = \gamma_{x, u}(1) = \\ 247 | x\oplus_c \tanh(\sqrt{c}/2 \|u\|_x) \frac{u}{\sqrt{c}\|u\|_2} 248 | Parameters 249 | ---------- 250 | x : tensor 251 | starting point on poincare ball 252 | u : tensor 253 | speed vector on poincare ball 254 | c : float|tensor 255 | ball negative curvature 256 | Returns 257 | ------- 258 | tensor 259 | :math:`\gamma_{x, u}(1)` end point 260 | """ 261 | c = torch.as_tensor(c).type_as(x) 262 | return _expmap(x, u, c) 263 | 264 | 265 | def _expmap(x, u, c): # pragma: no cover 266 | sqrt_c = c ** 0.5 267 | u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5) 268 | second_term = ( 269 | tanh(sqrt_c / 2 * _lambda_x(x, c, keepdim=True) * u_norm) 270 | * u 271 | / (sqrt_c * u_norm) 272 | ) 273 | gamma_1 = _mobius_add(x, second_term, c) 274 | return gamma_1 275 | 276 | 277 | def expmap0(u, *, c=1.0): 278 | r""" 279 | Exponential map for Poincare ball model from :math:`0`. 280 | .. math:: 281 | \operatorname{Exp}^c_0(u) = \tanh(\sqrt{c}/2 \|u\|_2) \frac{u}{\sqrt{c}\|u\|_2} 282 | Parameters 283 | ---------- 284 | u : tensor 285 | speed vector on poincare ball 286 | c : float|tensor 287 | ball negative curvature 288 | Returns 289 | ------- 290 | tensor 291 | :math:`\gamma_{0, u}(1)` end point 292 | """ 293 | c = torch.as_tensor(c).type_as(u) 294 | return _expmap0(u, c) 295 | 296 | 297 | def _expmap0(u, c): 298 | sqrt_c = c ** 0.5 299 | u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5) 300 | gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm) 301 | return gamma_1 302 | 303 | 304 | def logmap(x, y, *, c=1.0): 305 | r""" 306 | Logarithmic map for two points :math:`x` and :math:`y` on the manifold. 307 | .. 
math:: 308 | \operatorname{Log}^c_x(y) = \frac{2}{\sqrt{c}\lambda_x^c} \tanh^{-1}( 309 | \sqrt{c} \|(-x)\oplus_c y\|_2 310 | ) * \frac{(-x)\oplus_c y}{\|(-x)\oplus_c y\|_2} 311 | The result of Logarithmic map is a vector such that 312 | .. math:: 313 | y = \operatorname{Exp}^c_x(\operatorname{Log}^c_x(y)) 314 | Parameters 315 | ---------- 316 | x : tensor 317 | starting point on poincare ball 318 | y : tensor 319 | target point on poincare ball 320 | c : float|tensor 321 | ball negative curvature 322 | Returns 323 | ------- 324 | tensor 325 | tangent vector that transports :math:`x` to :math:`y` 326 | """ 327 | c = torch.as_tensor(c).type_as(x) 328 | return _logmap(x, y, c) 329 | 330 | 331 | def _logmap(x, y, c): # pragma: no cover 332 | sub = _mobius_add(-x, y, c) 333 | sub_norm = sub.norm(dim=-1, p=2, keepdim=True) 334 | lam = _lambda_x(x, c, keepdim=True) 335 | sqrt_c = c ** 0.5 336 | return 2 / sqrt_c / lam * artanh(sqrt_c * sub_norm) * sub / sub_norm 337 | 338 | 339 | def logmap0(y, *, c=1.0): 340 | r""" 341 | Logarithmic map for :math:`y` from :math:`0` on the manifold. 342 | .. math:: 343 | \operatorname{Log}^c_0(y) = \tanh^{-1}(\sqrt{c}\|y\|_2) \frac{y}{\|y\|_2} 344 | The result is such that 345 | .. math:: 346 | y = \operatorname{Exp}^c_0(\operatorname{Log}^c_0(y)) 347 | Parameters 348 | ---------- 349 | y : tensor 350 | target point on poincare ball 351 | c : float|tensor 352 | ball negative curvature 353 | Returns 354 | ------- 355 | tensor 356 | tangent vector that transports :math:`0` to :math:`y` 357 | """ 358 | c = torch.as_tensor(c).type_as(y) 359 | return _logmap0(y, c) 360 | 361 | 362 | def _logmap0(y, c): 363 | sqrt_c = c ** 0.5 364 | y_norm = torch.clamp_min(y.norm(dim=-1, p=2, keepdim=True), 1e-5) 365 | return y / y_norm / sqrt_c * artanh(sqrt_c * y_norm) 366 | 367 | 368 | def mobius_matvec(m, x, *, c=1.0): 369 | r""" 370 | Generalization for matrix-vector multiplication to hyperbolic space defined as 371 | .. 
math:: 372 | M \otimes_c x = (1/\sqrt{c}) \tanh\left( 373 | \frac{\|Mx\|_2}{\|x\|_2}\tanh^{-1}(\sqrt{c}\|x\|_2) 374 | \right)\frac{Mx}{\|Mx\|_2} 375 | Parameters 376 | ---------- 377 | m : tensor 378 | matrix for multiplication 379 | x : tensor 380 | point on poincare ball 381 | c : float|tensor 382 | negative ball curvature 383 | Returns 384 | ------- 385 | tensor 386 | Mobius matvec result 387 | """ 388 | c = torch.as_tensor(c).type_as(x) 389 | return _mobius_matvec(m, x, c) 390 | 391 | 392 | def _mobius_matvec(m, x, c): 393 | x_norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5) 394 | sqrt_c = c ** 0.5 395 | mx = x @ m.transpose(-1, -2) 396 | mx_norm = mx.norm(dim=-1, keepdim=True, p=2) 397 | res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c) 398 | cond = (mx == 0).prod(-1, keepdim=True, dtype=torch.uint8) 399 | res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device) 400 | res = torch.where(cond, res_0, res_c) 401 | return _project(res, c) 402 | 403 | 404 | def _tensor_dot(x, y): 405 | res = torch.einsum("ij,kj->ik", (x, y)) 406 | return res 407 | 408 | 409 | def _mobius_addition_batch(x, y, c): 410 | xy = _tensor_dot(x, y) # B x C 411 | x2 = x.pow(2).sum(-1, keepdim=True) # B x 1 412 | y2 = y.pow(2).sum(-1, keepdim=True) # C x 1 413 | num = 1 + 2 * c * xy + c * y2.permute(1, 0) # B x C 414 | num = num.unsqueeze(2) * x.unsqueeze(1) 415 | num = num + (1 - c * x2).unsqueeze(2) * y # B x C x D 416 | denom_part1 = 1 + 2 * c * xy # B x C 417 | denom_part2 = c ** 2 * x2 * y2.permute(1, 0) 418 | denom = denom_part1 + denom_part2 419 | res = num / (denom.unsqueeze(2) + 1e-5) 420 | return res 421 | 422 | 423 | def _hyperbolic_softmax(X, A, P, c): 424 | lambda_pkc = 2 / (1 - c * P.pow(2).sum(dim=1)) 425 | k = lambda_pkc * torch.norm(A, dim=1) / torch.sqrt(c) 426 | mob_add = _mobius_addition_batch(-P, X, c) 427 | num = 2 * torch.sqrt(c) * torch.sum(mob_add * A.unsqueeze(1), dim=-1) 428 | denom = torch.norm(A, dim=1, keepdim=True) * (1 - c * mob_add.pow(2).sum(dim=2)) 429 | logit = k.unsqueeze(1) * arsinh(num / denom) 430 | return logit.permute(1, 0) 431 | 432 | def p2k(x, c): 433 | denom = 1 + c * x.pow(2).sum(-1, keepdim=True) 434 | return 2 * x / denom 435 | 436 | 437 | def k2p(x, c): 438 | denom = 1 + torch.sqrt(1 - c * x.pow(2).sum(-1, keepdim=True)) 439 | return x / denom 440 | 441 | 442 | def lorenz_factor(x, *, c=1.0, dim=-1, keepdim=False): 443 | """ 444 | 445 | Parameters 446 | ---------- 447 | x : tensor 448 | point on Klein disk 449 | c : float 450 | negative curvature 451 | dim : int 452 | dimension to calculate Lorenz factor 453 | keepdim : bool 454 | retain the last dim? 
(default: false)
455 | 
456 |     Returns
457 |     -------
458 |     tensor
459 |         Lorenz factor
460 |     """
461 |     return 1 / torch.sqrt(1 - c * x.pow(2).sum(dim=dim, keepdim=keepdim))
462 | 
463 | 
464 | def poincare_mean(x, dim=0, c=1.0):
465 |     x = p2k(x, c)
466 |     lamb = lorenz_factor(x, c=c, keepdim=True)
467 |     mean = torch.sum(lamb * x, dim=dim, keepdim=True) / torch.sum(
468 |         lamb, dim=dim, keepdim=True
469 |     )
470 |     mean = k2p(mean, c)
471 |     return mean.squeeze(dim)
472 | 
473 | 
474 | def _dist_matrix(x, y, c):
475 |     sqrt_c = c ** 0.5
476 |     return (
477 |         2
478 |         / sqrt_c
479 |         * artanh(sqrt_c * torch.norm(_mobius_addition_batch(-x, y, c=c), dim=-1))
480 |     )
481 | 
482 | 
483 | def dist_matrix(x, y, c=1.0):
484 |     c = torch.as_tensor(c).type_as(x)
485 |     return _dist_matrix(x, y, c)
486 | 
487 | 
488 | def auto_select_c(d):
489 |     """
490 |     calculates the curvature c = 1 / R^2 of the Poincare ball, with the radius R chosen
491 |     such that the d-dimensional ball has constant volume equal to pi
492 |     """
493 |     dim2 = d / 2.0
494 |     R = gamma(dim2 + 1) / (np.pi ** (dim2 - 1))
495 |     R = R ** (1 / float(d))
496 |     c = 1 / (R ** 2)
497 |     return c
498 | 
--------------------------------------------------------------------------------
/geom/poincare.py:
--------------------------------------------------------------------------------
1 | """Poincare utils functions."""
2 | 
3 | import torch
4 | 
5 | import geom.euclidean as euclidean
6 | 
7 | MIN_NORM = 1e-15
8 | BALL_EPS = {torch.float32: 4e-3, torch.float64: 1e-5}
9 | 
10 | 
11 | def expmap0(u):
12 |     """Exponential map taken at the origin of the Poincare ball (curvature fixed to 1).
13 |     Args:
14 |         u: torch.Tensor of size B x d with tangent vectors at the origin
15 |         (unlike pmath.expmap0, this version takes no curvature argument)
16 |     Returns:
17 |         torch.Tensor of shape (B, d) with points on the Poincare ball
18 |     """
19 |     u_norm = u.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
20 |     gamma_1 = torch.tanh(u_norm) * u / u_norm
21 |     return project(gamma_1)
22 | 
23 | 
24 | def logmap0(y):
25 |     """Logarithmic map taken at the origin of the Poincare ball (curvature fixed to 1).
26 |     Args:
27 |         y: torch.Tensor of size B x d with points on the Poincare ball
28 |         (unlike pmath.logmap0, this version takes no curvature argument)
29 |     Returns:
30 |         torch.Tensor of size B x d with tangent vectors at the origin.
31 |     """
32 |     y_norm = y.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
33 |     return y / y_norm * torch.atanh(y_norm.clamp(-1 + 1e-15, 1 - 1e-15))
34 | 
35 | 
36 | def expmap(x, u):
37 |     u_norm = u.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
38 |     second_term = torch.tanh(lambda_(x) * u_norm / 2) * u / u_norm
39 |     gamma_1 = mobius_add(x, second_term)
40 |     return gamma_1
41 | 
42 | 
43 | def logmap(x, y):
44 |     sub = mobius_add(-x, y)
45 |     sub_norm = sub.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM).clamp_max(1 - 1e-15)
46 |     return 2 / lambda_(x) * torch.atanh(sub_norm) * sub / sub_norm
47 | 
48 | 
49 | def lambda_(x):
50 |     """Computes the conformal factor."""
51 |     x_sqnorm = torch.sum(x.data.pow(2), dim=-1, keepdim=True)
52 |     return 2 / (1. - x_sqnorm).clamp_min(MIN_NORM)
53 | 
54 | 
55 | def project(x):
56 |     """Project points into the open Poincare ball (clipping slightly inside the boundary).
57 |     Args:
58 |         x: torch.Tensor of size B x d with hyperbolic points
59 |     Returns:
60 |         torch.Tensor with projected hyperbolic points.
61 | """ 62 | norm = x.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM) 63 | eps = BALL_EPS[x.dtype] 64 | maxnorm = (1 - eps) 65 | cond = norm > maxnorm 66 | projected = x / norm * maxnorm 67 | return torch.where(cond, projected, x) 68 | 69 | 70 | def distance(x, y, keepdim=True): 71 | """Hyperbolic distance on the Poincare ball with curvature c. 72 | Args: 73 | x: torch.Tensor of size B x d with hyperbolic points 74 | y: torch.Tensor of size B x d with hyperbolic points 75 | Returns: torch,Tensor with hyperbolic distances, size B x 1 76 | """ 77 | pairwise_norm = mobius_add(-x, y).norm(dim=-1, p=2, keepdim=True) 78 | dist = 2.0 * torch.atanh(pairwise_norm.clamp(-1 + MIN_NORM, 1 - MIN_NORM)) 79 | if not keepdim: 80 | dist = dist.squeeze(-1) 81 | return dist 82 | 83 | 84 | def pairwise_distance(x, keepdim=False): 85 | """All pairs of hyperbolic distances (NxN matrix).""" 86 | return distance(x.unsqueeze(-2), x.unsqueeze(-3), keepdim=keepdim) 87 | 88 | 89 | def distance0(x, keepdim=True): 90 | """Computes hyperbolic distance between x and the origin.""" 91 | x_norm = x.norm(dim=-1, p=2, keepdim=True) 92 | d = 2 * torch.atanh(x_norm.clamp(-1 + 1e-15, 1 - 1e-15)) 93 | if not keepdim: 94 | d = d.squeeze(-1) 95 | return d 96 | 97 | 98 | def mobius_add(x, y): 99 | """Mobius addition.""" 100 | x2 = torch.sum(x * x, dim=-1, keepdim=True) 101 | y2 = torch.sum(y * y, dim=-1, keepdim=True) 102 | xy = torch.sum(x * y, dim=-1, keepdim=True) 103 | num = (1 + 2 * xy + y2) * x + (1 - x2) * y 104 | denom = 1 + 2 * xy + x2 * y2 105 | return num / denom.clamp_min(MIN_NORM) 106 | 107 | 108 | def mobius_mul(x, t): 109 | """Mobius multiplication.""" 110 | normx = x.norm(dim=-1, p=2, keepdim=True).clamp(min=MIN_NORM, max=1. - 1e-5) 111 | return torch.tanh(t * torch.atanh(normx)) * x / normx 112 | 113 | 114 | def midpoint(x, y): 115 | """Computes hyperbolic midpoint beween x and y.""" 116 | t1 = mobius_add(-x, y) 117 | t2 = mobius_mul(t1, 0.5) 118 | return mobius_add(x, t2) 119 | 120 | 121 | # Reflection (circle inversion of x through orthogonal circle centered at a) 122 | def isometric_transform(x, a): 123 | r2 = torch.sum(a ** 2, dim=-1, keepdim=True) - 1. 124 | u = x - a 125 | return r2 / torch.sum(u ** 2, dim=-1, keepdim=True) * u + a 126 | 127 | 128 | # center of inversion circle 129 | def reflection_center(mu): 130 | return mu / torch.sum(mu ** 2, dim=-1, keepdim=True) 131 | 132 | 133 | # Map x under the isometry (inversion) taking mu to origin 134 | def reflect_at_zero(x, mu): 135 | a = reflection_center(mu) 136 | return isometric_transform(x, a) 137 | 138 | 139 | def orthogonal_projection(x, Q, normalized=False): 140 | """ Orthogonally project x onto linear subspace (through the origin) spanned by rows of Q. """ 141 | if not normalized: 142 | Q = euclidean.orthonormal(Q) 143 | x_ = euclidean.reflect(x, Q) 144 | return midpoint(x, x_) 145 | 146 | 147 | def geodesic_between_ideals(ideals): 148 | """Return the center and radius of the Euclidean circle representing 149 | the geodesic joining two ideal points p = ideals[0] and q = ideals[1] 150 | 151 | Args: 152 | ideals: torch.tensor of shape (...,2,dim) 153 | Return: 154 | center: torch.tensor of shape (..., dim) 155 | radius: torch.tensor of shape (...) 156 | 157 | Note: raise an error if p = -q, i.e. if the geodesic between them is an Euclidean line 158 | """ 159 | p = ideals[..., 0, :] 160 | q = ideals[..., 1, :] 161 | norm_sum = (p + q).norm(dim=-1, p=2) # shape (...) 
162 | assert torch.all(norm_sum != 0) 163 | center = (p + q) / (1 + (p * q).sum(dim=-1, keepdim=True)) 164 | radius = (p - q).norm(dim=-1, p=2) / norm_sum 165 | return center, radius 166 | 167 | 168 | def random_points(size, std=1.0): 169 | tangents = torch.randn(*size) * std 170 | x = expmap0(tangents) 171 | return x 172 | 173 | 174 | def random_ideals(size): 175 | Q = torch.randn(*size) 176 | Q = Q / torch.norm(Q, dim=-1, keepdim=True) 177 | return Q 178 | 179 | 180 | def p2k(x, c): 181 | denom = 1 + c * x.pow(2).sum(-1, keepdim=True) 182 | return 2 * x / denom 183 | 184 | 185 | def k2p(x, c): 186 | denom = 1 + torch.sqrt(1 - c * x.pow(2).sum(-1, keepdim=True)) 187 | return x / denom 188 | 189 | 190 | def lorenz_factor(x, *, c=1.0, dim=-1, keepdim=False): 191 | """ 192 | 193 | Parameters 194 | ---------- 195 | x : tensor 196 | point on Klein disk 197 | c : float 198 | negative curvature 199 | dim : int 200 | dimension to calculate Lorenz factor 201 | keepdim : bool 202 | retain the last dim? (default: false) 203 | 204 | Returns 205 | ------- 206 | tensor 207 | Lorenz factor 208 | """ 209 | return 1 / torch.sqrt(1 - c * x.pow(2).sum(dim=dim, keepdim=keepdim)) 210 | 211 | 212 | def poincare_mean(x, dim=0, c=1.0): 213 | x = p2k(x, c) 214 | lamb = lorenz_factor(x, c=c, keepdim=True) 215 | mean = torch.sum(lamb * x, dim=dim, keepdim=True) / torch.sum( 216 | lamb, dim=dim, keepdim=True 217 | ) 218 | mean = k2p(mean, c) 219 | return mean.squeeze(dim) 220 | -------------------------------------------------------------------------------- /hierarchy/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/hierarchy/.DS_Store -------------------------------------------------------------------------------- /hierarchy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/hierarchy/__init__.py -------------------------------------------------------------------------------- /hierarchy/data.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | import numpy as np 3 | 4 | import networkx as nx 5 | import os 6 | import json 7 | 8 | class Hierarchy(): 9 | ''' 10 | Class representing a general ImageNet-style hierarchy, modified from 11 | https://github.com/MadryLab/robustness/blob/master/robustness/tools/breeds_helpers.py 12 | ''' 13 | def __init__(self, dataset_name, info_dir='/path/to/repository/', *args, **kw): 14 | """ 15 | Args: 16 | info_dir (str) : Base path to datasets and hierarchy information files. Contains a 17 | "class_hierarchy.txt" file with one edge per line, a 18 | "node_names.txt" mapping nodes to names, and "dataset_class_info.json". 
19 |         """
20 |         super(Hierarchy, self).__init__(*args, **kw)
21 | 
22 |         self.dataset_name = dataset_name
23 |         self.info_dir = info_dir
24 | 
25 |         REQUIRED_FILES = [f'{self.dataset_name}/dataset_class_info.json',
26 |                           f'{self.dataset_name}/class_hierarchy.txt',
27 |                           f'{self.dataset_name}/node_names.txt']
28 | 
29 |         for f in REQUIRED_FILES:
30 |             if not os.path.exists(os.path.join(self.info_dir, f)):
31 |                 self.generate_files()
32 | 
33 |         # Details about dataset class names (leaves), IDs
34 |         with open(os.path.join(self.info_dir, f"{self.dataset_name}/dataset_class_info.json")) as f:
35 |             class_info = json.load(f)
36 | 
37 |         # Hierarchy represented as edges between parent & child nodes.
38 |         with open(os.path.join(self.info_dir, f'{self.dataset_name}/class_hierarchy.txt')) as f:
39 |             edges = [l.strip().split() for l in f.readlines()]
40 | 
41 |         # Information (names, IDs) for intermediate nodes in hierarchy.
42 |         with open(os.path.join(self.info_dir, f'{self.dataset_name}/node_names.txt')) as f:
43 |             mapping = [l.strip().split('\t') for l in f.readlines()]
44 | 
45 | 
46 |         # Original dataset classes
47 |         self.LEAF_IDS = [c[1] for c in class_info]  # wnid
48 |         self.LEAF_ID_TO_NAME = {c[1]: c[2] for c in class_info}  # wnid : name
49 |         self.LEAF_ID_TO_NUM = {c[1]: c[0] for c in class_info}  # wnid : labelid
50 |         self.LEAF_NUM_TO_NAME = {c[0]: c[2] for c in class_info}  # labelid : name
51 | 
52 |         # Full hierarchy
53 |         self.HIER_NODE_NAME = {w[0]: w[1] for w in mapping}  # wnid : name
54 |         self.NAME_TO_NODE_ID = {w[1]: w[0] for w in mapping}  # name : wnid
55 |         self.graph = self._make_parent_graph(self.LEAF_IDS, edges)
56 | 
57 |         # make label mapping
58 |         self.label_map = self.get_label_mapping()
59 |         # generate tree distance
60 |         self.tree_dist = self.generate_tree_dist()
61 |         # leaf node names
62 |         self.leaf_names = [c[2] for c in class_info]
63 | 
64 | 
65 | 
66 |     @staticmethod
67 |     def _make_parent_graph(nodes, edges):
68 |         """
69 |         Obtain networkx representation of class hierarchy.
70 | 
71 |         Args:
72 |             nodes [str] : List of node names to traverse upwards.
73 |             edges [(str, str)] : Tuples of parent-child pairs.
74 | 
75 |         Return:
76 |             networkx representation of the graph.
77 | """ 78 | 79 | # create full graph 80 | full_graph_dir = {} 81 | for p, c in edges: 82 | if p not in full_graph_dir: 83 | full_graph_dir[p] = {c: 1} 84 | else: 85 | full_graph_dir[p].update({c: 1}) 86 | 87 | FG = nx.DiGraph(full_graph_dir) 88 | 89 | # perform backward BFS to get the relevant graph 90 | graph_dir = {} 91 | todo = [n for n in nodes if n in FG.nodes()] # skip nodes not in graph 92 | while todo: 93 | curr = todo 94 | todo = [] 95 | for w in curr: 96 | for p in FG.predecessors(w): 97 | if p not in graph_dir: 98 | graph_dir[p] = {w: 1} 99 | else: 100 | graph_dir[p].update({w: 1}) 101 | todo.append(p) 102 | todo = set(todo) 103 | 104 | return nx.DiGraph(graph_dir) 105 | 106 | def get_root_node(self): 107 | for node in self.graph.nodes(): 108 | if self.graph.in_degree(node) == 0: 109 | return node 110 | 111 | def get_ancestors(self, node): # leaf to root path 112 | return nx.shortest_path(self.graph, source=self.get_root_node(), target=node)[::-1] 113 | 114 | def find_leaf_nodes(self): 115 | return [node for node in self.graph.nodes() if self.graph.out_degree(node) == 0] 116 | 117 | def get_label_mapping(self): 118 | leaf_nodes = self.find_leaf_nodes() 119 | leaf_ancestors = {} 120 | max_path = -1 121 | for leaf in leaf_nodes: 122 | ancestors = self.get_ancestors(leaf) 123 | # here leaf is wnid 124 | # convert to label id 125 | if len(ancestors) > max_path: 126 | max_path = len(ancestors) 127 | leaf_ancestors[self.LEAF_ID_TO_NUM[leaf]] = ancestors 128 | 129 | label_map = np.empty((len(leaf_nodes), max_path),dtype=object) 130 | for leaf in leaf_nodes: 131 | true_path = leaf_ancestors[self.LEAF_ID_TO_NUM[leaf]] 132 | padded_path = [''] * (max_path - len(true_path)) + true_path 133 | label_map[self.LEAF_ID_TO_NUM[leaf]] = padded_path 134 | 135 | return label_map 136 | 137 | 138 | def generate_tree_dist(self): 139 | distances = {} 140 | all_distances = dict(nx.all_pairs_shortest_path_length(self.graph.to_undirected())) 141 | for node1, paths in all_distances.items(): 142 | for node2, dist in paths.items(): 143 | if node1 != node2: 144 | node_pair = (node1, node2) 145 | rev_pair = (node2, node1) 146 | if node_pair not in distances: 147 | distances[node_pair] = dist 148 | distances[rev_pair] = dist 149 | return distances 150 | 151 | 152 | def generate_files(self): 153 | # specify how to generate files for each dataset 154 | return None 155 | 156 | -------------------------------------------------------------------------------- /imagenet100/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/imagenet100/__init__.py -------------------------------------------------------------------------------- /imagenet100/class_list.txt: -------------------------------------------------------------------------------- 1 | n03877845 2 | n03000684 3 | n03110669 4 | n03710721 5 | n02825657 6 | n02113186 7 | n01817953 8 | n04239074 9 | n02002556 10 | n04356056 11 | n03187595 12 | n03355925 13 | n03125729 14 | n02058221 15 | n01580077 16 | n03016953 17 | n02843684 18 | n04371430 19 | n01944390 20 | n03887697 21 | n04037443 22 | n02493793 23 | n01518878 24 | n03840681 25 | n04179913 26 | n01871265 27 | n03866082 28 | n03180011 29 | n01910747 30 | n03388549 31 | n03908714 32 | n01855032 33 | n02134084 34 | n03400231 35 | n04483307 36 | n03721384 37 | n02033041 38 | n01775062 39 | n02808304 40 | n13052670 41 | n01601694 42 | n04136333 43 | n03272562 44 | n03895866 45 | 
n03995372 46 | n06785654 47 | n02111889 48 | n03447721 49 | n03666591 50 | n04376876 51 | n03929855 52 | n02128757 53 | n02326432 54 | n07614500 55 | n01695060 56 | n02484975 57 | n02105412 58 | n04090263 59 | n03127925 60 | n04550184 61 | n04606251 62 | n02488702 63 | n03404251 64 | n03633091 65 | n02091635 66 | n03457902 67 | n02233338 68 | n02483362 69 | n04461696 70 | n02871525 71 | n01689811 72 | n01498041 73 | n02107312 74 | n01632458 75 | n03394916 76 | n04147183 77 | n04418357 78 | n03218198 79 | n01917289 80 | n02102318 81 | n02088364 82 | n09835506 83 | n02095570 84 | n03982430 85 | n04041544 86 | n04562935 87 | n03933933 88 | n01843065 89 | n02128925 90 | n02480495 91 | n03425413 92 | n03935335 93 | n02971356 94 | n02124075 95 | n07714571 96 | n03133878 97 | n02097130 98 | n02113799 99 | n09399592 100 | n03594945 -------------------------------------------------------------------------------- /imagenet100/coarse_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": 0, 3 | "1": 1, 4 | "2": 2, 5 | "3": 3, 6 | "4": 4, 7 | "5": 5, 8 | "6": 6, 9 | "7": 7, 10 | "8": 8, 11 | "9": 9, 12 | "10": 10, 13 | "11": 11, 14 | "12": 12, 15 | "13": 13, 16 | "14": 14, 17 | "15": 15, 18 | "16": 16, 19 | "17": 17, 20 | "18": 18, 21 | "19": 18, 22 | "20": 19, 23 | "21": 20, 24 | "22": 21, 25 | "23": 22, 26 | "24": 23, 27 | "25": 24, 28 | "26": 25, 29 | "27": 26, 30 | "28": 27, 31 | "29": 28, 32 | "30": 28, 33 | "31": 29, 34 | "32": 30, 35 | "33": 31, 36 | "34": 32, 37 | "35": 33, 38 | "36": 34, 39 | "37": 34, 40 | "38": 35, 41 | "39": 36, 42 | "40": 37, 43 | "41": 37, 44 | "42": 38, 45 | "43": 39, 46 | "44": 40, 47 | "45": 42, 48 | "46": 44, 49 | "47": 46, 50 | "48": 39, 51 | "49": 47, 52 | "50": 48, 53 | "51": 49, 54 | "52": 50, 55 | "53": 51, 56 | "54": 52, 57 | "55": 53, 58 | "56": 44, 59 | "57": 54, 60 | "58": 55, 61 | "59": 56, 62 | "60": 57, 63 | "61": 58, 64 | "62": 59, 65 | "63": 60, 66 | "64": 41, 67 | "65": 61, 68 | "66": 57, 69 | "67": 45, 70 | "68": 62, 71 | "69": 63, 72 | "70": 36, 73 | "71": 64, 74 | "72": 65, 75 | "73": 66, 76 | "74": 67, 77 | "75": 68, 78 | "76": 69, 79 | "77": 70, 80 | "78": 59, 81 | "79": 71, 82 | "80": 72, 83 | "81": 62, 84 | "82": 73, 85 | "83": 74, 86 | "84": 75, 87 | "85": 76, 88 | "86": 61, 89 | "87": 77, 90 | "88": 78, 91 | "89": 79, 92 | "90": 80, 93 | "91": 43, 94 | "92": 81, 95 | "93": 82, 96 | "94": 83, 97 | "95": 84, 98 | "96": 85, 99 | "97": 86, 100 | "98": 87, 101 | "99": 88 102 | } -------------------------------------------------------------------------------- /imagenet100/data.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | from torchvision.datasets import ImageFolder 3 | from hierarchy.data import Hierarchy 4 | 5 | import json, os 6 | 7 | class ImageNet100(ImageFolder): 8 | def __init__(self, root, train = True, *args, **kw): 9 | if train: 10 | super(ImageNet100, self).__init__(root = os.path.join(root, 'train'), *args, **kw) 11 | else: 12 | super(ImageNet100, self).__init__(root = os.path.join(root, 'val'), *args, **kw) 13 | 14 | class HierarchyImageNet100(Hierarchy, ImageNet100): 15 | def __init__(self, *args, **kw): 16 | super(HierarchyImageNet100, self).__init__(dataset_name='imagenet100', *args, **kw) 17 | 18 | def generate_files(self): 19 | # Use class_hierarchy.txt, node_names.txt for original wordnet hierarchy in BREEDS. 
20 | # For 100 leaf nodes, use https://github.com/deeplearning-wisc/MCM/blob/main/data/ImageNet100/class_list.txt 21 | with open(os.path.join(self.info_dir, f'{self.dataset_name}/class_list.txt')) as f: 22 | leaf_names = [l.strip() for l in f.readlines()] 23 | 24 | with open(os.path.join(self.info_dir, f'{self.dataset_name}/node_names.txt')) as f: 25 | mapping = [l.strip().split('\t') for l in f.readlines()] 26 | HIER_NODE_NAME = {w[0]: w[1] for w in mapping} # wnid : name 27 | 28 | # dataset_class_info.json 29 | # a list, each entry is [int(label_id), wordnet id, label name] 30 | # all leaves 31 | leaf2labelid = {class_name : self.class_to_idx[class_name] for class_name in self.classes} 32 | data_class_info = [[leaf2labelid[leaf],leaf,HIER_NODE_NAME[leaf]] for leaf in leaf_names] 33 | with open(os.path.join(self.info_dir, f"{self.dataset_name}/dataset_class_info.json"), 'w') as file: 34 | json.dump(data_class_info, file) 35 | 36 | 37 | -------------------------------------------------------------------------------- /imagenet100/dataset_class_info.json: -------------------------------------------------------------------------------- 1 | [[69, "n03877845", "palace"], [44, "n03000684", "chain saw, chainsaw"], [46, "n03110669", "cornet, horn, trumpet, trump"], [65, "n03710721", "maillot, tank suit"], [40, "n02825657", "bell cote, bell cot"], [26, "n02113186", "Cardigan, Cardigan Welsh corgi"], [8, "n01817953", "African grey, African gray, Psittacus erithacus"], [84, "n04239074", "sliding door"], [15, "n02002556", "white stork, Ciconia ciconia"], [85, "n04356056", "sunglasses, dark glasses, shades"], [51, "n03187595", "dial telephone, dial phone"], [54, "n03355925", "flagpole, flagstaff"], [47, "n03125729", "cradle"], [17, "n02058221", "albatross, mollymawk"], [2, "n01580077", "jay"], [45, "n03016953", "chiffonier, commode"], [41, "n02843684", "birdhouse"], [86, "n04371430", "swimming trunks, bathing trunks"], [14, "n01944390", "snail"], [70, "n03887697", "paper towel"], [78, "n04037443", "racer, race car, racing car"], [38, "n02493793", "spider monkey, Ateles geoffroyi"], [1, "n01518878", "ostrich, Struthio camelus"], [67, "n03840681", "ocarina, sweet potato"], [83, "n04179913", "sewing machine"], [11, "n01871265", "tusker"], [68, "n03866082", "overskirt"], [50, "n03180011", "desktop computer"], [12, "n01910747", "jellyfish"], [55, "n03388549", "four-poster"], [72, "n03908714", "pencil sharpener"], [10, "n01855032", "red-breasted merganser, Mergus serrator"], [31, "n02134084", "ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus"], [57, "n03400231", "frying pan, frypan, skillet"], [90, "n04483307", "trimaran"], [66, "n03721384", "marimba, xylophone"], [16, "n02033041", "dowitcher"], [7, "n01775062", "wolf spider, hunting spider"], [39, "n02808304", "bath towel"], [99, "n13052670", "hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa"], [3, "n01601694", "water ouzel, dipper"], [81, "n04136333", "sarong"], [53, "n03272562", "electric locomotive"], [71, "n03895866", "passenger car, coach, carriage"], [77, "n03995372", "power drill"], [94, "n06785654", "crossword puzzle, crossword"], [25, "n02111889", "Samoyed, Samoyede"], [60, "n03447721", "gong, tam-tam"], [64, "n03666591", "lighter, light, igniter, ignitor"], [87, "n04376876", "syringe"], [73, "n03929855", "pickelhaube"], [29, "n02128757", "snow leopard, ounce, Panthera uncia"], [33, "n02326432", "hare"], [95, "n07614500", "ice cream, icecream"], [6, "n01695060", "Komodo dragon, Komodo lizard, dragon lizard, giant 
lizard, Varanus komodoensis"], [36, "n02484975", "guenon, guenon monkey"], [23, "n02105412", "kelpie"], [80, "n04090263", "rifle"], [48, "n03127925", "crate"], [91, "n04550184", "wardrobe, closet, press"], [93, "n04606251", "wreck"], [37, "n02488702", "colobus, colobus monkey"], [58, "n03404251", "fur coat"], [63, "n03633091", "ladle"], [19, "n02091635", "otterhound, otter hound"], [61, "n03457902", "greenhouse, nursery, glasshouse"], [32, "n02233338", "cockroach, roach"], [35, "n02483362", "gibbon, Hylobates lar"], [89, "n04461696", "tow truck, tow car, wrecker"], [42, "n02871525", "bookshop, bookstore, bookstall"], [5, "n01689811", "alligator lizard"], [0, "n01498041", "stingray"], [24, "n02107312", "miniature pinscher"], [4, "n01632458", "spotted salamander, Ambystoma maculatum"], [56, "n03394916", "French horn, horn"], [82, "n04147183", "schooner"], [88, "n04418357", "theater curtain, theatre curtain"], [52, "n03218198", "dogsled, dog sled, dog sleigh"], [13, "n01917289", "brain coral"], [22, "n02102318", "cocker spaniel, English cocker spaniel, cocker"], [18, "n02088364", "beagle"], [98, "n09835506", "ballplayer, baseball player"], [20, "n02095570", "Lakeland terrier"], [76, "n03982430", "pool table, billiard table, snooker table"], [79, "n04041544", "radio, wireless"], [92, "n04562935", "water tower"], [74, "n03933933", "pier"], [9, "n01843065", "jacamar"], [30, "n02128925", "jaguar, panther, Panthera onca, Felis onca"], [34, "n02480495", "orangutan, orang, orangutang, Pongo pygmaeus"], [59, "n03425413", "gas pump, gasoline pump, petrol pump, island dispenser"], [75, "n03935335", "piggy bank, penny bank"], [43, "n02971356", "carton"], [28, "n02124075", "Egyptian cat"], [96, "n07714571", "head cabbage"], [49, "n03133878", "Crock Pot"], [21, "n02097130", "giant schnauzer"], [27, "n02113799", "standard poodle"], [97, "n09399592", "promontory, headland, head, foreland"], [62, "n03594945", "jeep, landrover"]] -------------------------------------------------------------------------------- /imagenet100/hyperparameters.json: -------------------------------------------------------------------------------- 1 | { 2 | "SupCon":{ 3 | "epochs": 20, 4 | "optimizer": { 5 | "lr": 0.01, 6 | "momentum": 0.9, 7 | "weight_decay": 0.0001 8 | }, 9 | "lr_decay_rate" : 0.1 10 | }, 11 | "PoincareSupCon":{ 12 | "epochs": 20, 13 | "optimizer": { 14 | "lr": 0.01, 15 | "momentum": 0.9, 16 | "weight_decay": 0.0001 17 | }, 18 | "lr_decay_rate" : 0.1 19 | }, 20 | "ERM":{ 21 | "epochs": 200, 22 | "optimizer": { 23 | "lr": 0.1, 24 | "momentum": 0.9, 25 | "weight_decay": 0.0005 26 | }, 27 | "scheduler": { 28 | "step_size": 60, 29 | "gamma": 0.2 30 | } 31 | } 32 | } -------------------------------------------------------------------------------- /images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/images/.DS_Store -------------------------------------------------------------------------------- /images/algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/images/algorithm.png -------------------------------------------------------------------------------- /images/cifar10_viz.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/images/cifar10_viz.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from itertools import combinations 7 | from typing import * 8 | 9 | from geom import poincare, hyperboloid 10 | import geom.pmath as pmath 11 | import geom.nn as hypnn 12 | from geom.pmath import dist_matrix 13 | 14 | class CPCCLoss(nn.Module): 15 | ''' 16 | CPCC as a mini-batch regularizer. 17 | ''' 18 | def __init__(self, dataset, args): 19 | super(CPCCLoss, self).__init__() 20 | self.tree_dist = dataset.tree_dist 21 | self.label_map_str = dataset.label_map # fine label id => wnid from leaf to root 22 | if args.leaf_only: 23 | self.label_map_str = dataset.label_map[:,0].reshape(-1,1) 24 | self.label_map_int, self.str2int, self.int2str = self.map_strings_to_integers(self.label_map_str) 25 | self.empty_int = self.str2int.get('',-1) # integer id for empty string used for padding in label map 26 | self.tree_depth = len(self.label_map_str[0]) 27 | self.distance_type = args.cpcc_metric 28 | if self.distance_type == 'poincare_exp_c': 29 | self.to_hyperbolic = hypnn.ToPoincare(c=args.hyp_c, ball_dim=args.poincare_head_dim, riemannian=True, clip_r=args.clip_r, train_c=args.train_c) 30 | self.dist_f = lambda x, y: dist_matrix(x, y, c=args.hyp_c) 31 | 32 | def map_strings_to_integers(self, string_array): 33 | string_to_int = {} 34 | int_to_string = {} 35 | current_id = 0 # nonnegative 36 | 37 | int_array = np.zeros(string_array.shape, dtype=int) 38 | 39 | for i in range(string_array.shape[0]): 40 | for j in range(string_array.shape[1]): 41 | string = string_array[i, j] 42 | if string not in string_to_int: 43 | string_to_int[string] = current_id 44 | int_to_string[current_id] = string 45 | current_id += 1 46 | int_array[i, j] = string_to_int[string] 47 | 48 | return int_array, string_to_int, int_to_string 49 | 50 | def forward(self, representations, targets_fine): 51 | label_map = torch.tensor(self.label_map_int,device=targets_fine.device) 52 | # assume the tree has d levels 53 | # targets: B * d, d = 0 => leaf node, d = d-1 => root 54 | targets = label_map[targets_fine] 55 | all_unique_int = [torch.unique(targets[:, col][targets[:, col] != self.empty_int]) for col in range(targets.shape[1])] # unique node for each level 56 | all_unique_str = [] # flattened all unique string nodes in this batch 57 | target_mean_list = [] # d components from fine to coarse 58 | if self.distance_type == 'poincare_mean': 59 | representations_poincare = poincare.expmap0(representations) 60 | for col, unique_values in enumerate(all_unique_int): 61 | for val in unique_values: 62 | if self.distance_type == 'poincare_mean': 63 | column_mean = pmath.poincare_mean(torch.index_select(representations_poincare, 0, (targets[:, col] == val).nonzero(as_tuple=True)[0]), dim=0, c=1.0) 64 | else: 65 | column_mean = torch.mean(torch.index_select(representations, 0, (targets[:, col] == val).nonzero(as_tuple=True)[0]), 0) 66 | target_mean_list.append(column_mean) 67 | all_unique_str.append(self.int2str[val.item()]) 68 | sorted_sums = torch.stack(target_mean_list, 0) 69 | 70 | if self.distance_type == 'l2': 71 | pairwise_dist = F.pdist(sorted_sums, p=2.0) # get pairwise distance 72 | elif self.distance_type == 'nl2': 73 | # normalized 74 | 
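# ('nl2' compares only the norm profile: the pairwise distances below are
# taken between the L2 norms of the class-mean embeddings, not between the
# mean vectors themselves)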
all_norms = torch.norm(sorted_sums, dim=1, p=2).unsqueeze(-1) 75 | pairwise_dist = F.pdist(all_norms, p=2.0) 76 | elif self.distance_type == 'l1': 77 | pairwise_dist = F.pdist(sorted_sums, p=1.0) 78 | elif self.distance_type == 'poincare': 79 | # Project into the poincare ball with norm <= 1 - epsilon 80 | # https://www.tensorflow.org/addons/api_docs/python/tfa/layers/PoincareNormalize 81 | epsilon = 1e-5 82 | all_norms = torch.norm(sorted_sums, dim=1, p=2).unsqueeze(-1) 83 | normalized_sorted_sums = sorted_sums * (1 - epsilon) / all_norms 84 | all_normalized_norms = torch.norm(normalized_sorted_sums, dim=1, p=2) 85 | # |u-v|^2 86 | condensed_idx = torch.triu_indices(len(all_unique_str), len(all_unique_str), offset=1, device = sorted_sums.device) 87 | numerator_square = torch.sum((normalized_sorted_sums[None, :] - normalized_sorted_sums[:, None])**2, -1) 88 | numerator = numerator_square[condensed_idx[0],condensed_idx[1]] 89 | # (1 - |u|^2) * (1 - |v|^2) 90 | denominator_square = ((1 - all_normalized_norms**2).reshape(-1,1)) @ (1 - all_normalized_norms**2).reshape(1,-1) 91 | denominator = denominator_square[condensed_idx[0],condensed_idx[1]] 92 | pairwise_dist = torch.acosh(1 + 2 * (numerator/denominator)) 93 | elif self.distance_type == 'poincare_exp': 94 | sorted_sums_exp = poincare.expmap0(sorted_sums) 95 | condensed_idx = torch.triu_indices(len(all_unique_str), len(all_unique_str), offset=1, device = sorted_sums.device) 96 | pairwise_dists_poincare_matrix = poincare.pairwise_distance(sorted_sums_exp) 97 | pairwise_dist = pairwise_dists_poincare_matrix[condensed_idx[0],condensed_idx[1]] 98 | elif self.distance_type == 'poincare_exp_c': 99 | sorted_sums_exp = self.to_hyperbolic(sorted_sums) 100 | condensed_idx = torch.triu_indices(len(all_unique_str), len(all_unique_str), offset=1, device = sorted_sums.device) 101 | pairwise_dists_poincare_matrix = self.dist_f(sorted_sums_exp, sorted_sums_exp) 102 | pairwise_dist = pairwise_dists_poincare_matrix[condensed_idx[0],condensed_idx[1]] 103 | elif self.distance_type == 'poincare_mean': 104 | condensed_idx = torch.triu_indices(len(all_unique_str), len(all_unique_str), offset=1, device = sorted_sums.device) 105 | pairwise_dists_poincare_matrix = poincare.pairwise_distance(sorted_sums) 106 | pairwise_dist = pairwise_dists_poincare_matrix[condensed_idx[0],condensed_idx[1]] 107 | 108 | tree_pairwise_dist = self.dT(all_unique_str, pairwise_dist.device) 109 | 110 | res = 1 - torch.corrcoef(torch.stack([pairwise_dist, tree_pairwise_dist], 0))[0,1] # maximize cpcc 111 | if torch.isnan(res): 112 | return torch.tensor(1,device=pairwise_dist.device) 113 | else: 114 | return res 115 | 116 | def dT(self, all_node, device): 117 | tree_pairwise_dist = [] 118 | for i in range(len(all_node)): 119 | for j in range(i+1, len(all_node)): 120 | tree_pairwise_dist.append(self.tree_dist[(all_node[i], all_node[j])]) 121 | return torch.tensor(tree_pairwise_dist, device=device) 122 | 123 | """ 124 | Adapted from SupCon: https://github.com/HobbitLong/SupContrast/ 125 | """ 126 | 127 | class SupConLoss(nn.Module): 128 | """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf. 
129 | It also supports the unsupervised contrastive loss in SimCLR""" 130 | def __init__(self, temperature, contrast_mode='all', 131 | base_temperature=0.07): 132 | super(SupConLoss, self).__init__() 133 | self.temperature = temperature 134 | self.contrast_mode = contrast_mode 135 | self.base_temperature = base_temperature 136 | 137 | def forward(self, features, labels=None, mask=None): 138 | """Compute loss for model. If both `labels` and `mask` are None, 139 | it degenerates to SimCLR unsupervised loss: 140 | https://arxiv.org/pdf/2002.05709.pdf 141 | Args: 142 | features: hidden vector of shape [bsz, n_views, ...]. 143 | labels: ground truth of shape [bsz]. 144 | mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j 145 | has the same class as sample i. Can be asymmetric. 146 | Returns: 147 | A loss scalar. 148 | """ 149 | device = (torch.device('cuda') 150 | if features.is_cuda 151 | else torch.device('cpu')) 152 | 153 | if len(features.shape) < 3: 154 | raise ValueError('`features` needs to be [bsz, n_views, ...],' 155 | 'at least 3 dimensions are required') 156 | if len(features.shape) > 3: 157 | features = features.view(features.shape[0], features.shape[1], -1) 158 | 159 | batch_size = features.shape[0] 160 | if labels is not None and mask is not None: 161 | raise ValueError('Cannot define both `labels` and `mask`') 162 | elif labels is None and mask is None: 163 | mask = torch.eye(batch_size, dtype=torch.float32).to(device) 164 | elif labels is not None: 165 | labels = labels.contiguous().view(-1, 1) 166 | if labels.shape[0] != batch_size: 167 | raise ValueError('Num of labels does not match num of features') 168 | mask = torch.eq(labels, labels.T).float().to(device) 169 | else: 170 | mask = mask.float().to(device) 171 | 172 | contrast_count = features.shape[1] 173 | contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0) 174 | if self.contrast_mode == 'one': 175 | anchor_feature = features[:, 0] 176 | anchor_count = 1 177 | elif self.contrast_mode == 'all': 178 | anchor_feature = contrast_feature 179 | anchor_count = contrast_count 180 | else: 181 | raise ValueError('Unknown mode: {}'.format(self.contrast_mode)) 182 | 183 | # compute logits 184 | anchor_dot_contrast = torch.div( 185 | torch.matmul(anchor_feature, contrast_feature.T), 186 | self.temperature) 187 | # for numerical stability 188 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 189 | logits = anchor_dot_contrast - logits_max.detach() 190 | 191 | # tile mask 192 | mask = mask.repeat(anchor_count, contrast_count) 193 | # mask-out self-contrast cases 194 | logits_mask = torch.scatter( 195 | torch.ones_like(mask), 196 | 1, 197 | torch.arange(batch_size * anchor_count).view(-1, 1).to(device), 198 | 0 199 | ) 200 | mask = mask * logits_mask 201 | 202 | # compute log_prob 203 | exp_logits = torch.exp(logits) * logits_mask 204 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 205 | 206 | # compute mean of log-likelihood over positive 207 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 208 | 209 | # loss 210 | loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos 211 | loss = loss.view(anchor_count, batch_size).mean() 212 | 213 | return loss -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from 
torch.utils.data import DataLoader
6 | from torch.utils.data.distributed import DistributedSampler
7 | import torch.distributed as dist
8 | 
9 | import numpy as np
10 | import math
11 | 
12 | import faiss
13 | 
14 | from sklearn.neighbors import KNeighborsClassifier
15 | from sklearn.model_selection import cross_val_score
16 | 
17 | from datetime import datetime
18 | import argparse
19 | import os
20 | import json
21 | from typing import *
22 | import pickle
23 | 
24 | from model import init_model
25 | from data_loader import make_dataloader, TwoCropTransform
26 | from loss import CPCCLoss, SupConLoss
27 | from param import init_optim_schedule, adjust_learning_rate, warmup_learning_rate, load_params
28 | from utils import *
29 | 
30 | import geom.pmath as pmath
31 | from geom import poincare
32 | 
33 | 
34 | def get_different_loss(exp_name : str, model : nn.Module, data : Tensor,
35 | criterion : nn.Module, target : Tensor,
36 | args) -> Tuple[Tensor, Tensor]:
37 | '''
38 | Helper to compute the non-CPCC (base) loss; returns the (by default unnormalized) representation and the model loss.
39 | '''
40 | if exp_name == 'SupCon':
41 | bsz = target.shape[0]
42 | input_combined = torch.cat([data[0], data[1]], dim=0).cuda()
43 | target_combined = target.repeat(2).cuda()
44 | 
45 | if isinstance(model, nn.DataParallel) or (hasattr(args, 'world_size') and args.world_size > 1):
46 | model = model.module
47 | 
48 | penultimate = model.encoder(input_combined).squeeze()
49 | representation = penultimate[:bsz]
50 | 
51 | if args.normalize: # default: False or 0
52 | penultimate = F.normalize(penultimate, dim=1)
53 | 
54 | features = F.normalize(model.head(penultimate), dim=1) # result of proj head
55 | f1, f2 = torch.split(features, [bsz, bsz], dim=0) # f1 shape: [bsz, feat_dim]
56 | features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) # features shape: [bsz, 2, feat_dim]
57 | loss = criterion(features, target_combined[:bsz])
58 | 
59 | elif exp_name == 'ERM':
60 | representation, logits = model(data.cuda()) # move inputs to GPU here; the training loop only moves targets
61 | loss = criterion(logits, target)
62 | return representation, loss
63 | 
64 | def pretrain_objective(train_loader : DataLoader, val_loader : DataLoader, device : torch.device,
65 | save_dir : str, seed : int, CPCC : bool, exp_name : str, epochs : int,
66 | dataset_name : str, hyper) -> None:
67 | '''
68 | Pretrain session.
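Trains the model with the base objective (SupCon or cross-entropy for ERM), optionally adding the CPCC regularizer and the centering penalty, and saves the final checkpoint as seed{seed}.pth.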
69 | ''' 70 | 71 | def trainer(model, hyper, total_epochs, optimizer, scheduler, train_loader, val_loader): 72 | train_losses_base_hist = {} 73 | train_losses_cpcc_hist = {} 74 | for epoch in range(1,total_epochs+1): 75 | if isinstance(train_loader.sampler, DistributedSampler): 76 | train_loader.sampler.set_epoch(epoch) 77 | adjust_learning_rate(hyper, optimizer, epoch) 78 | t_start = datetime.now() # record the time for each epoch 79 | model.train() 80 | 81 | train_losses_base = [] 82 | train_losses_cpcc = [] 83 | 84 | for idx, (data, target) in enumerate(train_loader): 85 | warmup_learning_rate(hyper, epoch, idx, len(train_loader), optimizer) 86 | target = target.to(device) 87 | 88 | optimizer.zero_grad() 89 | 90 | representation, loss_base = get_different_loss(exp_name, model, data, criterion, target, args) 91 | 92 | loss_cpcc = criterion_cpcc(representation, target) 93 | train_losses_cpcc.append(loss_cpcc) 94 | train_losses_base.append(loss_base) 95 | if CPCC: 96 | loss = loss_base + lamb * loss_cpcc 97 | if args.center: 98 | if args.cpcc_metric == 'poincare_mean': 99 | if args.feature_dim < 512: # 64 100 | loss = loss_base + lamb * 0.5 * loss_cpcc + 0.005 * torch.norm(pmath.poincare_mean(poincare.expmap0(representation),dim=0, c=1.0)) 101 | else: 102 | loss = loss_base + lamb * loss_cpcc + 0.01 * torch.norm(pmath.poincare_mean(poincare.expmap0(representation),dim=0, c=1.0)) 103 | else: 104 | loss = loss_base + lamb * loss_cpcc + 0.01 * torch.norm(torch.mean(representation,0)) 105 | else: 106 | loss = loss_base 107 | 108 | loss.backward() 109 | optimizer.step() 110 | 111 | if scheduler is not None: 112 | scheduler.step() 113 | 114 | if is_rank_zero() and (idx % 10 == 1): 115 | print(f"Train Loss: {sum(train_losses_base)/len(train_losses_base):.4f}, " 116 | f"Train CPCC: {sum(train_losses_cpcc)/len(train_losses_cpcc):.4f}") 117 | 118 | 119 | t_end = datetime.now() 120 | t_delta = (t_end-t_start).total_seconds() 121 | 122 | if is_rank_zero(): 123 | print(f"Epoch {epoch} takes {t_delta} sec.") 124 | 125 | train_losses_base_hist[epoch] = (sum(train_losses_base)/len(train_losses_base)).item() 126 | train_losses_cpcc_hist[epoch] = (sum(train_losses_cpcc)/len(train_losses_cpcc)).item() 127 | 128 | pickle.dump(train_losses_base_hist, open(save_dir + f'/train_losses_base_hist.pkl', 'wb')) 129 | pickle.dump(train_losses_cpcc_hist, open(save_dir + f'/train_losses_cpcc_hist.pkl', 'wb')) 130 | 131 | 132 | log_dict = {f"train_losses_{exp_name}":sum(train_losses_base)/len(train_losses_base), 133 | f"train_losses_cpcc":sum(train_losses_cpcc)/len(train_losses_cpcc),} 134 | 135 | 136 | if epoch % args.save_freq == 0 and args.save_freq > 0: 137 | checkpoint = {'epoch': epoch, 138 | 'model': model.state_dict(), 139 | 'optimizer': optimizer.state_dict()} 140 | if scheduler is not None: 141 | checkpoint['lr_sched'] = scheduler 142 | torch.save(checkpoint, save_dir+f"/checkpoints/e{epoch}_seed{seed}.pth") 143 | 144 | return 145 | 146 | def base_eval(model, val_loader): 147 | model.eval() 148 | test_accs = [] 149 | test_losses_base = [] 150 | test_losses_cpcc = [] 151 | 152 | with torch.no_grad(): 153 | for item in val_loader: 154 | data = item[0].to(device) 155 | target = item[-1].to(device) 156 | 157 | representation, logits = model(data) 158 | loss_base = criterion(logits, target) 159 | loss_cpcc = criterion_cpcc(representation, target) 160 | 161 | prob = F.softmax(logits,dim=1) 162 | pred = prob.argmax(dim=1) 163 | acc = pred.eq(target).flatten().tolist() 164 | test_accs.extend(acc) 165 | 166 | 
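# losses are accumulated per batch as tensors and averaged over the
# validation set in the return statement below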
test_losses_base.append(loss_base) 167 | test_losses_cpcc.append(loss_cpcc) 168 | 169 | return sum(test_accs)/len(test_accs), sum(test_losses_base)/len(test_losses_base), sum(test_losses_cpcc)/len(test_losses_cpcc) 170 | 171 | def knn(model, val_loader): 172 | model.eval() 173 | 174 | features = [] 175 | labels = [] 176 | 177 | test_losses_base = [] 178 | test_losses_cpcc = [] 179 | 180 | with torch.no_grad(): 181 | for item in val_loader: 182 | data = item[0] 183 | target = item[-1].to(device) 184 | bsz = target.shape[0] 185 | target_combined = target.repeat(2) 186 | 187 | # compute output 188 | input_combined = torch.cat([data[0], data[1]], dim=0).cuda() 189 | penultimate = model.module.encoder(input_combined).squeeze() 190 | 191 | representation = penultimate[:bsz] 192 | output = F.normalize(representation, dim=1).data.cpu() 193 | features.append(output) 194 | labels.append(target) 195 | 196 | proj_features = F.normalize(model.module.head(penultimate), dim=1) 197 | f1, f2 = torch.split(proj_features, [bsz, bsz], dim=0) 198 | proj_features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1) # 199 | supcon_loss = criterion(proj_features, target_combined[:bsz]) 200 | test_losses_base.append(supcon_loss) 201 | 202 | representation, _ = model(data[0].to(device)) 203 | test_loss_cpcc = criterion_cpcc(representation, target) 204 | test_losses_cpcc.append(test_loss_cpcc) 205 | 206 | features = torch.cat(features).numpy() 207 | labels = torch.cat(labels).cpu().numpy() 208 | 209 | cls = KNeighborsClassifier(20, metric="cosine").fit(features, labels) 210 | acc = np.mean(cross_val_score(cls, features, labels)) 211 | 212 | 213 | return acc, sum(test_losses_base)/len(test_losses_base), sum(test_losses_cpcc)/len(test_losses_cpcc) 214 | 215 | if 'WORLD_SIZE' in os.environ and world_size > 0: 216 | dist.init_process_group() 217 | torch.set_float32_matmul_precision('high') 218 | 219 | dataset = train_loader.dataset 220 | num_train_batches = len(train_loader.dataset) // train_loader.batch_size + 1 221 | 222 | optim_config, scheduler_config = hyper['optimizer'], hyper['scheduler'] 223 | init_config = {"dataset":dataset_name, 224 | "exp_name":exp_name, 225 | "cpcc":CPCC, 226 | "_batch_size":train_loader.batch_size, 227 | "epochs":epochs, 228 | "_num_workers":train_loader.num_workers, 229 | "cpcc_lamb":lamb, 230 | "center":args.center, 231 | } 232 | if CPCC: 233 | init_config['cpcc_metric'] = args.cpcc_metric 234 | init_config['leaf_only'] = args.leaf_only 235 | if args.cpcc_metric == 'poincare_exp_c': 236 | init_config['hyp_c'] = args.hyp_c 237 | init_config['clip_r'] = args.clip_r 238 | init_config['poincare_head_dim'] = args.poincare_head_dim 239 | init_config['train_c'] = args.train_c 240 | 241 | if scheduler_config is None: 242 | config = {**init_config, **optim_config} 243 | else: 244 | config = {**init_config, **optim_config, **scheduler_config} 245 | 246 | if exp_name.startswith('SupCon'): 247 | criterion = SupConLoss(temperature = temperature) 248 | elif exp_name.startswith('ERM'): 249 | criterion = nn.CrossEntropyLoss() 250 | 251 | criterion_cpcc = CPCCLoss(dataset, args) 252 | 253 | if is_rank_zero(): 254 | with open(save_dir+'/config.json', 'w') as fp: 255 | json.dump(config, fp, sort_keys=True, indent=4) 256 | 257 | out_dir = save_dir+f"/seed{seed}.pth" 258 | if os.path.exists(out_dir): 259 | print("Skipped.") 260 | return 261 | 262 | model = init_model(device, args) 263 | 264 | if exp_name.startswith('SupCon'): 265 | # apply two crop to validation dataset to get the correct loss 266 | 
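# (SupCon expects two augmented views per image, so the evaluation loader
# must also yield [view1, view2] pairs, matching the training transform)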
val_loader.dataset.transform = TwoCropTransform(val_loader.dataset.transform) 267 | 268 | if train_loader.batch_size > 256 and (exp_name.startswith('SupCon')) : 269 | hyper["warm"] = True 270 | hyper["warmup_from"] = 0.01 271 | hyper["warm_epochs"] = 10 272 | optim_param = hyper['optimizer'] 273 | lr_decay_rate = hyper["lr_decay_rate"] 274 | eta_min = optim_param['lr'] * (lr_decay_rate ** 3) 275 | hyper["warmup_to"] = eta_min + (optim_param['lr'] - eta_min) * ( 276 | 1 + math.cos(math.pi * hyper["warm_epochs"] / epochs)) / 2 277 | 278 | optimizer, scheduler = init_optim_schedule(model, hyper, train_loader, exp_name, init_optimizer=None) 279 | trainer(model, hyper, epochs, optimizer, scheduler, train_loader, val_loader) 280 | 281 | if is_rank_zero(): 282 | torch.save(model.state_dict(), out_dir) # save the last checkpoint by convention 283 | 284 | # wandb.finish() 285 | if 'WORLD_SIZE' in os.environ and world_size > 0: 286 | ddp_cleanup() 287 | return 288 | 289 | def feature_extractor(dataloader : DataLoader, seed : int, epoch : int = -1): 290 | model = init_model(device, args) 291 | 292 | model_dict = model.state_dict() 293 | if epoch == -1: 294 | ckpt_dict = {k: v for k, v in torch.load(save_dir+f"/seed{seed}.pth").items()} 295 | else: 296 | ckpt_dict = {k: v for k, v in torch.load(save_dir+f"/e{epoch}_seed{seed}.pth").items()} 297 | model_dict.update(ckpt_dict) 298 | model.load_state_dict(model_dict) 299 | 300 | features = [] 301 | targets = [] 302 | model.eval() 303 | 304 | if isinstance(model, nn.DataParallel) or (hasattr(args, 'world_size') and args.world_size > 1): 305 | model = model.module 306 | 307 | with torch.no_grad(): 308 | for item in dataloader: 309 | data = item[0] 310 | target = item[-1] 311 | 312 | if isinstance(data, List): 313 | input_combined = data[0].cuda() 314 | else: # test dataset, type(data) = Tensor 315 | input_combined = data.cuda() 316 | 317 | penultimate = model.encoder(input_combined).squeeze() 318 | features.append(penultimate.cpu().detach().numpy()) 319 | 320 | target = target.to(device) 321 | targets.append(target.cpu().detach().numpy()) 322 | 323 | features = np.concatenate(features,axis=0) 324 | targets = np.concatenate(targets,axis=0) 325 | return features, targets 326 | 327 | def ood_detection(seeds : int, in_dataset_name : str, ood_dataset_name : str, 328 | exp_name : str, num_workers : int, batch_size : int): 329 | 330 | def load_ood_scores(ood_dataset_name, method_name, out): 331 | print(f'{method_name} Skipped.') 332 | out[method_name]['unnormalized']['fpr95'].append(result[ood_dataset_name][method_name]['unnormalized']['fpr95']) 333 | out[method_name]['unnormalized']['auroc'].append(result[ood_dataset_name][method_name]['unnormalized']['auroc']) 334 | out[method_name]['unnormalized']['aupr'].append(result[ood_dataset_name][method_name]['unnormalized']['aupr']) 335 | out[method_name]['normalized']['fpr95'].append(result[ood_dataset_name][method_name]['normalized']['fpr95']) 336 | out[method_name]['normalized']['auroc'].append(result[ood_dataset_name][method_name]['normalized']['auroc']) 337 | out[method_name]['normalized']['aupr'].append(result[ood_dataset_name][method_name]['normalized']['aupr']) 338 | return out 339 | 340 | in_train_loader = make_dataloader(exp_name, num_workers, batch_size, 'train', in_dataset_name) 341 | in_test_loader = make_dataloader(exp_name, num_workers, batch_size, 'test', in_dataset_name) 342 | out_test_loader = make_dataloader(exp_name, num_workers, batch_size, 'ood', in_dataset_name, ood_dataset_name = 
ood_dataset_name) 343 | 344 | print("OOD Dataset:", ood_dataset_name) 345 | out = {'Mahalanobis':{'normalized' : {'fpr95':[],'auroc':[],'aupr':[]}, 'unnormalized' : {'fpr95':[],'auroc':[],'aupr':[]}}, 346 | 'SSD':{'normalized' : {'fpr95':[],'auroc':[],'aupr':[]}, 'unnormalized' : {'fpr95':[],'auroc':[],'aupr':[]}}, 347 | 'KNN':{'normalized' : {'fpr95':[],'auroc':[],'aupr':[]}, 'unnormalized' : {'fpr95':[],'auroc':[],'aupr':[]}}, 348 | } 349 | 350 | if os.path.exists(save_dir + '/OOD.json'): 351 | with open(save_dir+'/OOD.json', 'r') as fp: 352 | result = json.load(fp) 353 | else: 354 | result = dict() 355 | 356 | for seed in range(seeds): 357 | # compute features 358 | save_location = save_dir + '/' + ood_dataset_name 359 | if not os.path.exists(save_location): 360 | os.makedirs(save_location) 361 | 362 | if not os.path.exists(save_location + f'/in_train_features_seed{seed}.pkl') or \ 363 | not os.path.exists(save_location + f'/in_train_labels_seed{seed}.pkl'): 364 | in_train_features, in_train_labels = feature_extractor(in_train_loader, seed) 365 | pickle.dump(in_train_features, open(save_location + f'/in_train_features_seed{seed}.pkl', 'wb')) 366 | pickle.dump(in_train_labels, open(save_location + f'/in_train_labels_seed{seed}.pkl', 'wb')) 367 | else: 368 | with open(save_location + f'/in_train_features_seed{seed}.pkl', 'rb') as fp: 369 | in_train_features = pickle.load(fp) 370 | with open(save_location + f'/in_train_labels_seed{seed}.pkl', 'rb') as fp: 371 | in_train_labels = pickle.load(fp) 372 | if not os.path.exists(save_location + f'/in_test_features_seed{seed}.pkl') or \ 373 | not os.path.exists(save_location + f'/in_test_labels_seed{seed}.pkl'): 374 | in_test_features, in_test_labels = feature_extractor(in_test_loader, seed) 375 | pickle.dump(in_test_features, open(save_location + f'/in_test_features_seed{seed}.pkl', 'wb')) 376 | pickle.dump(in_test_labels, open(save_location + f'/in_test_labels_seed{seed}.pkl', 'wb')) 377 | else: 378 | with open(save_location + f'/in_test_features_seed{seed}.pkl', 'rb') as fp: 379 | in_test_features = pickle.load(fp) 380 | with open(save_location + f'/in_test_labels_seed{seed}.pkl', 'rb') as fp: 381 | in_test_labels = pickle.load(fp) 382 | if not os.path.exists(save_location + f'/out_test_features_seed{seed}.pkl') or \ 383 | not os.path.exists(save_location + f'/out_test_labels_seed{seed}.pkl'): 384 | out_test_features, out_test_labels = feature_extractor(out_test_loader, seed) 385 | pickle.dump(out_test_features, open(save_location + f'/out_test_features_seed{seed}.pkl', 'wb')) 386 | pickle.dump(out_test_labels, open(save_location + f'/out_test_labels_seed{seed}.pkl', 'wb')) 387 | else: 388 | with open(save_location + f'/out_test_features_seed{seed}.pkl', 'rb') as fp: 389 | out_test_features = pickle.load(fp) 390 | with open(save_location + f'/out_test_labels_seed{seed}.pkl', 'rb') as fp: 391 | out_test_labels = pickle.load(fp) 392 | print("Features successfully loaded.") 393 | 394 | 395 | ftrain = np.copy(in_train_features) 396 | ftest = np.copy(in_test_features) 397 | food = np.copy(out_test_features) 398 | ftrain /= np.linalg.norm(ftrain, axis=-1, keepdims=True) + 1e-10 399 | ftest /= np.linalg.norm(ftest, axis=-1, keepdims=True) + 1e-10 400 | food /= np.linalg.norm(food, axis=-1, keepdims=True) + 1e-10 401 | 402 | m, s = np.mean(ftrain, axis=0, keepdims=True), np.std(ftrain, axis=0, keepdims=True) 403 | 404 | ftrain = (ftrain - m) / (s + 1e-10) 405 | ftest = (ftest - m) / (s + 1e-10) 406 | food = (food - m) / (s + 1e-10) 407 | 408 | clusters 
= 1 409 | fpr95_s, auroc_s, aupr_s = get_eval_results( 410 | np.copy(in_train_features), 411 | np.copy(in_test_features), 412 | np.copy(out_test_features), 413 | np.copy(in_train_labels), 414 | clusters, 415 | ) 416 | print("SSD: FPR95", fpr95_s, "AUROC:", auroc_s, "AUPR:", aupr_s) 417 | out['SSD']['unnormalized']['fpr95'].append(fpr95_s) 418 | out['SSD']['unnormalized']['auroc'].append(auroc_s) 419 | out['SSD']['unnormalized']['aupr'].append(aupr_s) 420 | 421 | clusters = 1 422 | fpr95_sn, auroc_sn, aupr_sn = get_eval_results( 423 | np.copy(ftrain), 424 | np.copy(ftest), 425 | np.copy(food), 426 | np.copy(in_train_labels), 427 | clusters, 428 | ) 429 | print("Normalized SSD: FPR95", fpr95_sn, "AUROC:", auroc_sn, "AUPR:", aupr_sn) 430 | out['SSD']['normalized']['fpr95'].append(fpr95_sn) 431 | out['SSD']['normalized']['auroc'].append(auroc_sn) 432 | out['SSD']['normalized']['aupr'].append(aupr_sn) 433 | 434 | if in_dataset_name == 'CIFAR10': 435 | K = 50 436 | elif in_dataset_name == 'CIFAR100': 437 | K = 200 438 | elif in_dataset_name == 'IMAGENET100': 439 | K = 200 440 | index = faiss.IndexFlatL2(in_train_features.shape[1]) 441 | index.add(in_train_features) 442 | D, _ = index.search(in_test_features, K) 443 | in_score = D[:,-1] 444 | 445 | D, _ = index.search(out_test_features,K) 446 | out_score = D[:,-1] 447 | 448 | auroc_k, aupr_k, fpr_k = get_measures(-in_score, -out_score) 449 | print("KNN: FPR95", fpr_k, "AUROC:", auroc_k, "AUPR:", aupr_k) 450 | out['KNN']['unnormalized']['fpr95'].append(fpr_k) 451 | out['KNN']['unnormalized']['auroc'].append(auroc_k) 452 | out['KNN']['unnormalized']['aupr'].append(aupr_k) 453 | 454 | index = faiss.IndexFlatL2(ftrain.shape[1]) 455 | index.add(ftrain) 456 | D, _ = index.search(ftest, K) 457 | in_score = D[:,-1] 458 | 459 | D, _ = index.search(food,K) 460 | out_score = D[:,-1] 461 | 462 | auroc_kn, aupr_kn, fpr_kn = get_measures(-in_score, -out_score) 463 | print("Normalized KNN: FPR95", fpr_kn, "AUROC:", auroc_kn, "AUPR:", aupr_kn) 464 | out['KNN']['normalized']['fpr95'].append(fpr_kn) 465 | out['KNN']['normalized']['auroc'].append(auroc_kn) 466 | out['KNN']['normalized']['aupr'].append(aupr_kn) 467 | 468 | n_cls = args.n_cls 469 | classwise_mean, precision = get_mean_prec(torch.tensor(in_train_features), torch.tensor(in_train_labels), n_cls) 470 | in_score_maha = get_Mahalanobis_score(torch.tensor(in_test_features).to(device), n_cls, classwise_mean, precision, in_dist = True) 471 | out_score_maha = get_Mahalanobis_score(torch.tensor(out_test_features).to(device), n_cls, classwise_mean, precision, in_dist = False) 472 | auroc_m, aupr_m, fpr_m = get_measures(-in_score_maha, -out_score_maha) 473 | print("Mahalanobis: FPR95", fpr_m, "AUROC:", auroc_m, "AUPR:", aupr_m) 474 | out['Mahalanobis']['unnormalized']['fpr95'].append(fpr_m) 475 | out['Mahalanobis']['unnormalized']['auroc'].append(auroc_m) 476 | out['Mahalanobis']['unnormalized']['aupr'].append(aupr_m) 477 | 478 | n_cls = args.n_cls 479 | classwise_mean, precision = get_mean_prec(torch.tensor(ftrain), torch.tensor(in_train_labels), n_cls) 480 | in_score_maha = get_Mahalanobis_score(torch.tensor(ftest).to(device), n_cls, classwise_mean, precision, in_dist = True) 481 | out_score_maha = get_Mahalanobis_score(torch.tensor(food).to(device), n_cls, classwise_mean, precision, in_dist = False) 482 | auroc_mn, aupr_mn, fpr_mn = get_measures(-in_score_maha, -out_score_maha) 483 | print("Normalized Mahalanobis: FPR95", fpr_mn, "AUROC:", auroc_mn, "AUPR:", aupr_mn) 484 | 
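# record this seed's normalized-feature Mahalanobis metrics; per-seed
# scores for every detector are averaged across seeds below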
out['Mahalanobis']['normalized']['fpr95'].append(fpr_mn)
485 | out['Mahalanobis']['normalized']['auroc'].append(auroc_mn)
486 | out['Mahalanobis']['normalized']['aupr'].append(aupr_mn)
487 | 
488 | for ood_scores in ['Mahalanobis','SSD','KNN']:
489 | for n in ['normalized','unnormalized']:
490 | for metric in ['fpr95','auroc','aupr']:
491 | out[ood_scores][n][metric] = np.mean(out[ood_scores][n][metric])
492 | 
493 | result[ood_dataset_name] = out
494 | with open(save_dir+'/OOD.json', 'w') as fp:
495 | json.dump(result, fp, indent=4)
496 | 
497 | return result
498 | 
499 | def main():
500 | 
501 | # Train
502 | for seed in range(seeds):
503 | seed_everything(seed)
504 | hyper = load_params(dataset_name, exp_name)
505 | epochs = hyper['epochs']
506 | train_loader = make_dataloader(exp_name, num_workers, batch_size, 'train', dataset_name)
507 | val_loader = make_dataloader(exp_name, num_workers, batch_size, 'test', dataset_name)
508 | args.n_cls = len(train_loader.dataset.leaf_names)
509 | pretrain_objective(train_loader, val_loader, device, save_dir, seed, cpcc, exp_name, epochs, dataset_name, hyper)
510 | 
511 | # Eval: ood
512 | if dataset_name in ['CIFAR100','CIFAR10']:
513 | ood_dataset_names = ['SVHN', 'Textures', 'Places365', 'LSUN','iSUN']
514 | elif dataset_name in ['IMAGENET100']:
515 | ood_dataset_names = ['iNaturalist','SUN', 'Places365', 'dtd']
516 | for ood_dataset_name in ood_dataset_names:
517 | ood_detection(seeds, dataset_name, ood_dataset_name, exp_name, num_workers, batch_size)
518 | 
519 | return
520 | 
521 | def is_ddp_initialized():
522 | return torch.distributed.is_initialized()
523 | 
524 | def is_rank_zero():
525 | # Check whether this is the rank 0 process, or whether DDP is not initialized (i.e., single GPU/CPU mode)
526 | if is_ddp_initialized():
527 | return torch.distributed.get_rank() == 0
528 | return True
529 | 
530 | 
531 | def ddp_setup():
532 | rank = int(os.environ['RANK'])
533 | world_size = int(os.environ['WORLD_SIZE'])
534 | 
535 | dist.init_process_group()
536 | 
537 | device_id = rank
538 | device = torch.device(f'cuda:{device_id}')
539 | torch.cuda.set_device(device_id)
540 | 
541 | return device, device_id, rank, world_size
542 | 
543 | def ddp_cleanup():
544 | dist.destroy_process_group()
545 | 
546 | 
547 | 
548 | if __name__ == '__main__':
549 | parser = argparse.ArgumentParser()
550 | parser.add_argument("--root", default="/path/to/data", type=str, help='directory where experiment results are saved')
551 | parser.add_argument("--timestamp", required=True, help=r'your unique experiment id, hint: datetime.now().strftime("%m%d%Y%H%M%S")')
552 | parser.add_argument("--exp_name", required=True, help='ERM | SupCon')
553 | parser.add_argument("--dataset", required=True, help='CIFAR100 | CIFAR10 | IMAGENET100')
554 | parser.add_argument("--model_name", type=str, help='resnet18 | resnet34 | resnet50')
555 | parser.add_argument("--save_freq", type=int, default=-1)
556 | 
557 | parser.add_argument("--cpcc", required=True, type=int, help='0/1')
558 | parser.add_argument("--cpcc_metric", default='poincare', type=str, help='distance metric in CPCC: l2 | nl2 | l1 | poincare | poincare_exp | poincare_exp_c | poincare_mean')
559 | parser.add_argument("--leaf_only", default=0, type=int, help='0 use all nodes, 1 use leaf nodes')
560 | parser.add_argument("--lamb", type=float, default=1, help='strength of CPCC regularization')
561 | parser.add_argument("--center", default=0, type=int, help='0/1, add the centering loss (penalize the norm of the batch mean) on top of CPCC')
562 | parser.add_argument("--poincare_low_dim_project", type=int, help='0/1')
563 | 
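# hyperbolic (Poincare ball) hyperparameters; these are only consumed when
# cpcc_metric == 'poincare_exp_c' (see CPCCLoss in loss.py)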
parser.add_argument("--poincare_head_dim", type=int, default=512) 564 | parser.add_argument("--hyp_c",type=float,default=1.0,help='curvature of the poincare ball') 565 | parser.add_argument("--clip_r",type=float,default=2.3,help='Clipping parameter of the poincare ball') 566 | parser.add_argument("--train_c",type=bool,default=False,help='Whether or not to train the curvature') 567 | parser.add_argument("--feature_dim", type=int, help='change feature dimension of encoder') 568 | 569 | parser.add_argument("--num_workers", type=int, default=12) 570 | parser.add_argument("--batch_size", type=int, default=512) 571 | parser.add_argument("--seeds", type=int,default=5) 572 | 573 | parser.add_argument('--warm', action='store_true', 574 | help='warm-up for large batch training') 575 | 576 | parser.add_argument('--temp', type=float, default=0.1, 577 | help='temperature for loss function') 578 | 579 | parser.add_argument('--normalize', type=int, default=0, 580 | help='normalize feat embeddings setting, see comments in loss.py') 581 | 582 | parser.set_defaults(bottleneck=True) 583 | parser.set_defaults(augment=True) 584 | 585 | parser.add_argument("--local-rank", type=int, default=0) # DDP 586 | parser.add_argument('--encoder_manifold', type=str, choices=['euclidean','poincare']) 587 | 588 | 589 | args = parser.parse_args() 590 | timestamp = args.timestamp 591 | exp_name = args.exp_name 592 | dataset_name = args.dataset 593 | cpcc = args.cpcc 594 | 595 | num_workers = args.num_workers 596 | batch_size = args.batch_size 597 | seeds = args.seeds 598 | lamb = args.lamb 599 | 600 | temperature = args.temp 601 | 602 | root = f'{args.root}/hypstructure/{dataset_name}' 603 | save_dir = root + '/' + timestamp 604 | if not os.path.exists(save_dir): 605 | os.makedirs(save_dir) 606 | os.makedirs(save_dir + '/checkpoints') 607 | 608 | 609 | # Check if the script is running in a DDP environment 610 | if 'WORLD_SIZE' in os.environ: 611 | # Assuming DDP is initialized outside this function if required 612 | world_size = int(os.environ.get('WORLD_SIZE',0)) 613 | args.world_size = world_size 614 | if world_size > 1: # More than one process implies DDP 615 | local_rank = int(os.getenv('LOCAL_RANK', 0)) 616 | torch.cuda.set_device(local_rank) 617 | device = torch.device(f"cuda:{local_rank}") 618 | rank = int(os.environ['RANK']) 619 | args.rank = rank 620 | print(f"Running in DDP mode on device: {device}, rank: {rank}/{world_size}") 621 | else: 622 | # Default to single GPU/CPU training 623 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 624 | print(f"Running in single GPU/CPU mode on device: {device}") 625 | 626 | main() 627 | -------------------------------------------------------------------------------- /main_evaluate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import json 4 | import pickle 5 | import argparse 6 | import numpy as np 7 | from sklearn.metrics import accuracy_score 8 | from sklearn.neighbors import KNeighborsClassifier 9 | 10 | def load_coarse_map(dataset): 11 | if dataset == 'CIFAR10': 12 | return np.loadtxt('./cifar10/coarse_map.txt', dtype=int) 13 | elif dataset == 'CIFAR100': 14 | return np.loadtxt('./cifar100/coarse_map.txt', dtype=int) 15 | elif dataset == 'IMAGENET100': 16 | with open('./imagenet100/coarse_map.json', 'r') as f: 17 | coarse_map = json.load(f) 18 | return {int(k): v for k, v in coarse_map.items()} 19 | else: 20 | raise ValueError(f"Unsupported dataset: {dataset}") 21 | 22 | def 
main(dataset, model_save_location): 23 | if dataset == 'CIFAR10': 24 | ood_dataset = 'SVHN' 25 | elif dataset == 'CIFAR100': 26 | ood_dataset = 'SVHN' 27 | elif dataset == 'IMAGENET100': 28 | ood_dataset = 'iNaturalist' 29 | else: 30 | raise ValueError(f"Unsupported dataset: {dataset}") 31 | 32 | coarse_map = load_coarse_map(dataset) 33 | features_path = os.path.join(model_save_location, ood_dataset) 34 | 35 | # Load data 36 | in_train_features = pickle.load(open(os.path.join(features_path, 'in_train_features_seed0.pkl'), 'rb')) 37 | in_train_fine_labels = pickle.load(open(os.path.join(features_path, 'in_train_labels_seed0.pkl'), 'rb')) 38 | in_train_coarse_labels = [coarse_map[i] for i in in_train_fine_labels] 39 | 40 | in_test_features = pickle.load(open(os.path.join(features_path, 'in_test_features_seed0.pkl'), 'rb')) 41 | in_test_fine_labels = pickle.load(open(os.path.join(features_path, 'in_test_labels_seed0.pkl'), 'rb')) 42 | in_test_coarse_labels = [coarse_map[i] for i in in_test_fine_labels] 43 | 44 | # Coarse accuracy 45 | cls = KNeighborsClassifier(50, metric="cosine").fit(in_train_features, in_train_coarse_labels) 46 | pred = cls.predict(in_test_features) 47 | acc = accuracy_score(in_test_coarse_labels, pred) 48 | print("Coarse Accuracy:", acc) 49 | 50 | # Fine accuracy 51 | cls = KNeighborsClassifier(50, metric="cosine").fit(in_train_features, in_train_fine_labels) 52 | pred = cls.predict(in_test_features) 53 | acc = accuracy_score(in_test_fine_labels, pred) 54 | print("Fine Accuracy:", acc) 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser(description="Run evaluation of Fine and Coarse Classification accuracies.") 58 | parser.add_argument('--dataset', type=str, required=True, help="Dataset to use (CIFAR10, CIFAR100, IMAGENET100)") 59 | parser.add_argument('--model_save_location', type=str, required=True, help="Path to the model save location") 60 | 61 | args = parser.parse_args() 62 | main(args.dataset, args.model_save_location) 63 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.models as models 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | 7 | from typing import * 8 | 9 | def init_model(device : torch.device, args): 10 | ''' 11 | Load the correct model for each dataset. 
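Wraps the network in DistributedDataParallel when args.world_size > 1.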
12 | '''
13 | if args.exp_name == 'SupCon':
14 | if hasattr(args, 'world_size') and args.world_size > 1:
15 | model = DDP(SupConResNet(args).to(device))
16 | else:
17 | model = SupConResNet(args).to(device)
18 | elif args.exp_name == 'ERM':
19 | if hasattr(args, 'world_size') and args.world_size > 1:
20 | model = DDP(SupCEResNet(args).to(device)) # move to device before wrapping in DDP, as in the SupCon branch
21 | else:
22 | model = SupCEResNet(args).to(device)
23 | model = torch.compile(model)
24 | return model
25 | 
26 | class BasicBlock(nn.Module):
27 | expansion = 1
28 | 
29 | def __init__(self, in_planes, planes, stride=1, is_last=False):
30 | super(BasicBlock, self).__init__()
31 | self.is_last = is_last
32 | self.conv1 = nn.Conv2d(
33 | in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
34 | )
35 | self.bn1 = nn.BatchNorm2d(planes)
36 | self.conv2 = nn.Conv2d(
37 | planes, planes, kernel_size=3, stride=1, padding=1, bias=False
38 | )
39 | self.bn2 = nn.BatchNorm2d(planes)
40 | 
41 | self.shortcut = nn.Sequential()
42 | if stride != 1 or in_planes != self.expansion * planes:
43 | self.shortcut = nn.Sequential(
44 | nn.Conv2d(
45 | in_planes,
46 | self.expansion * planes,
47 | kernel_size=1,
48 | stride=stride,
49 | bias=False,
50 | ),
51 | nn.BatchNorm2d(self.expansion * planes),
52 | )
53 | 
54 | def forward(self, x):
55 | out = F.relu(self.bn1(self.conv1(x)))
56 | out = self.bn2(self.conv2(out))
57 | out += self.shortcut(x)
58 | preact = out
59 | out = F.relu(out)
60 | if self.is_last:
61 | return out, preact
62 | else:
63 | return out
64 | 
65 | 
66 | class Bottleneck(nn.Module):
67 | expansion = 4
68 | 
69 | def __init__(self, in_planes, planes, stride=1, is_last=False):
70 | super(Bottleneck, self).__init__()
71 | self.is_last = is_last
72 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
73 | self.bn1 = nn.BatchNorm2d(planes)
74 | self.conv2 = nn.Conv2d(
75 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
76 | )
77 | self.bn2 = nn.BatchNorm2d(planes)
78 | self.conv3 = nn.Conv2d(
79 | planes, self.expansion * planes, kernel_size=1, bias=False
80 | )
81 | self.bn3 = nn.BatchNorm2d(self.expansion * planes)
82 | 
83 | self.shortcut = nn.Sequential()
84 | if stride != 1 or in_planes != self.expansion * planes:
85 | self.shortcut = nn.Sequential(
86 | nn.Conv2d(
87 | in_planes,
88 | self.expansion * planes,
89 | kernel_size=1,
90 | stride=stride,
91 | bias=False,
92 | ),
93 | nn.BatchNorm2d(self.expansion * planes),
94 | )
95 | 
96 | def forward(self, x):
97 | out = F.relu(self.bn1(self.conv1(x)))
98 | out = F.relu(self.bn2(self.conv2(out)))
99 | out = self.bn3(self.conv3(out))
100 | out += self.shortcut(x)
101 | preact = out
102 | out = F.relu(out)
103 | if self.is_last:
104 | return out, preact
105 | else:
106 | return out
107 | 
108 | 
109 | class ResNet(nn.Module):
110 | def __init__(self, block, num_blocks, in_channel=3, zero_init_residual=False):
111 | super(ResNet, self).__init__()
112 | self.in_planes = 64
113 | 
114 | self.conv1 = nn.Conv2d(
115 | in_channel, 64, kernel_size=3, stride=1, padding=1, bias=False
116 | )
117 | self.bn1 = nn.BatchNorm2d(64)
118 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
119 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
120 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
121 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
122 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
123 | 
124 | for m in self.modules():
125 | if isinstance(m, nn.Conv2d):
126 | 
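# He (Kaiming) initialization for convolution weights; BatchNorm/GroupNorm
# layers are initialized to weight=1, bias=0 in the branch below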
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 127 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 128 | nn.init.constant_(m.weight, 1) 129 | nn.init.constant_(m.bias, 0) 130 | 131 | # Zero-initialize the last BN in each residual branch, 132 | # so that the residual branch starts with zeros, and each residual block behaves 133 | # like an identity. This improves the model by 0.2~0.3% according to: 134 | # https://arxiv.org/abs/1706.02677 135 | if zero_init_residual: 136 | for m in self.modules(): 137 | if isinstance(m, Bottleneck): 138 | nn.init.constant_(m.bn3.weight, 0) 139 | elif isinstance(m, BasicBlock): 140 | nn.init.constant_(m.bn2.weight, 0) 141 | 142 | def _make_layer(self, block, planes, num_blocks, stride): 143 | strides = [stride] + [1] * (num_blocks - 1) 144 | layers = [] 145 | for i in range(num_blocks): 146 | stride = strides[i] 147 | layers.append(block(self.in_planes, planes, stride)) 148 | self.in_planes = planes * block.expansion 149 | return nn.Sequential(*layers) 150 | 151 | def forward(self, x, layer=100): 152 | out = F.relu(self.bn1(self.conv1(x))) 153 | out = self.layer1(out) 154 | out = self.layer2(out) 155 | out = self.layer3(out) 156 | out = self.layer4(out) 157 | out = self.avgpool(out) 158 | out = torch.flatten(out, 1) 159 | return out 160 | 161 | 162 | class SupCEResNet(nn.Module): 163 | """encoder + classifier""" 164 | def __init__(self, args): 165 | super(SupCEResNet, self).__init__() 166 | model_fun, dim_in = model_dict[args.model_name] 167 | self.args = args 168 | self.encoder = model_fun() 169 | self.fc = nn.Linear(dim_in, args.n_cls) 170 | self.normalize = args.normalize 171 | 172 | def forward(self, x): 173 | features = self.encoder(x) 174 | if self.normalize: 175 | features = F.normalize(features, dim=1) 176 | return features, self.fc(features) 177 | 178 | class SupConResNet(nn.Module): 179 | """encoder + head""" 180 | def __init__(self, args, multiplier = 1): 181 | super(SupConResNet, self).__init__() 182 | model_fun, dim_in = model_dict[args.model_name] 183 | self.args = args 184 | if args.dataset == 'IMAGENET100': 185 | # fine-tune 186 | if args.model_name == 'resnet18': 187 | model = models.resnet18(pretrained=True) 188 | elif args.model_name == 'resnet34': 189 | model = models.resnet34(pretrained=True) 190 | elif args.model_name == 'resnet50': 191 | model = models.resnet50(pretrained=True) 192 | elif args.model_name == 'resnet101': 193 | model = models.resnet101(pretrained=True) 194 | for name, p in model.named_parameters(): 195 | if not name.startswith('layer4'): 196 | p.requires_grad = False 197 | modules = list(model.children())[:-1] # remove last linear layer 198 | self.encoder = nn.Sequential(*modules) 199 | else: 200 | self.encoder = model_fun() 201 | self.fc = nn.Linear(dim_in, args.n_cls) 202 | self.multiplier = multiplier 203 | self.head = nn.Sequential( 204 | nn.Linear(dim_in, dim_in), 205 | nn.ReLU(inplace=True), 206 | nn.Linear(dim_in, 128) 207 | ) 208 | 209 | 210 | def forward(self, x): 211 | feat = self.encoder(x).squeeze() 212 | if self.args.normalize == 0: # official SupCon codebase 213 | # proj-L2 214 | unnorm_features = self.head(feat) 215 | features = F.normalize(unnorm_features, dim=1) 216 | elif self.args.normalize == 1: 217 | # feat-L2-proj-L2 218 | feat = F.normalize(feat, dim=1) # Following paper setting, normalize twice 219 | unnorm_features = self.head(feat) 220 | features = F.normalize(unnorm_features, dim=1) 221 | elif self.args.normalize == 2: 222 | # proj-max 223 | unnorm_features = 
self.head(feat) 224 | norms = torch.norm(unnorm_features, dim=1) 225 | features = unnorm_features/torch.max(norms) # learn on unit disk for projection head 226 | elif self.args.normalize == 3: 227 | # none 228 | unnorm_features = self.head(feat) 229 | features = unnorm_features 230 | elif self.args.normalize == 4: 231 | # feat-L2 232 | feat = F.normalize(feat, dim=1) 233 | unnorm_features = self.head(feat) 234 | features = unnorm_features 235 | elif self.args.normalize == 5: 236 | # feat-L2-proj-max 237 | feat = F.normalize(feat, dim=1) 238 | unnorm_features = self.head(feat) 239 | norms = torch.norm(unnorm_features, dim=1) 240 | features = unnorm_features/torch.max(norms) 241 | elif self.args.normalize == 6: 242 | # feat-max 243 | feat_norms = torch.norm(feat, dim=1) 244 | feat = feat/torch.max(feat_norms) 245 | unnorm_features = self.head(feat) 246 | features = unnorm_features 247 | elif self.args.normalize == 7: 248 | # feat-max-proj-L2 249 | feat_norms = torch.norm(feat, dim=1) 250 | feat = feat/torch.max(feat_norms) 251 | unnorm_features = self.head(feat) 252 | features = F.normalize(unnorm_features, dim=1) 253 | elif self.args.normalize == 8: 254 | # feat-max-proj-max 255 | feat_norms = torch.norm(feat, dim=1) 256 | feat = feat/torch.max(feat_norms) 257 | unnorm_features = self.head(feat) 258 | norms = torch.norm(unnorm_features, dim=1) 259 | features = unnorm_features/torch.max(norms) 260 | 261 | return feat, features # first is unnormalized feat, second is projected+normalize result 262 | 263 | class LinearClassifier(nn.Module): 264 | """Linear classifier""" 265 | def __init__(self, name='resnet50', num_classes=10): 266 | super(LinearClassifier, self).__init__() 267 | _, feat_dim = model_dict[name] 268 | self.fc = nn.Linear(feat_dim, num_classes) 269 | 270 | def forward(self, features): 271 | return self.fc(features) 272 | 273 | class MLPClassifier(nn.Module): 274 | """MLP classifier""" 275 | def __init__(self, name='resnet50', num_classes=10): 276 | super(MLPClassifier, self).__init__() 277 | _, feat_dim = model_dict[name] 278 | hidden_dim = 128 279 | self.mlp = nn.Sequential(nn.Linear(feat_dim, hidden_dim), 280 | nn.ReLU(), 281 | nn.Linear(hidden_dim, num_classes)) 282 | 283 | def forward(self, features): 284 | return self.mlp(features) 285 | 286 | 287 | 288 | 289 | def resnet18(**kwargs): 290 | return ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 291 | 292 | 293 | def resnet34(**kwargs): 294 | return ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 295 | 296 | 297 | def resnet50(**kwargs): 298 | return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 299 | 300 | 301 | def resnet101(**kwargs): 302 | return ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 303 | 304 | 305 | model_dict = { 306 | 'resnet18': [resnet18, 512], 307 | 'resnet34': [resnet34, 512], 308 | 'resnet50': [resnet50, 2048], 309 | 'resnet101': [resnet101, 2048], 310 | } 311 | -------------------------------------------------------------------------------- /param.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.optim import SGD 3 | from torch.optim.lr_scheduler import StepLR 4 | from torch.utils.data import DataLoader 5 | import math 6 | import json 7 | 8 | def adjust_learning_rate(hyper, optimizer, epoch): 9 | optim_param = hyper['optimizer'] 10 | lr = optim_param['lr'] 11 | eta_min = lr * (hyper["lr_decay_rate"] ** 3) 12 | lr = eta_min + (lr - eta_min) * ( 13 | 1 + math.cos(math.pi * epoch / hyper["epochs"])) / 2 14 | 15 | for param_group in 
optimizer.param_groups:
16 | param_group['lr'] = lr
17 | 
18 | 
19 | def warmup_learning_rate(hyper, epoch, batch_id, total_batches, optimizer):
20 | if hyper.get("warm", False) and epoch <= hyper["warm_epochs"]: # 'warm' is only injected for large-batch SupCon runs in main.py; a missing key means no warmup
21 | p = (batch_id + (epoch - 1) * total_batches) / \
22 | (hyper["warm_epochs"] * total_batches)
23 | lr = hyper["warmup_from"] + p * (hyper["warmup_to"] - hyper["warmup_from"])
24 | 
25 | for param_group in optimizer.param_groups:
26 | param_group['lr'] = lr
27 | 
28 | def init_optim_schedule(model : nn.Module, params : dict, train_loader : DataLoader, exp_name : str, init_optimizer = None):
29 | optim_param = params['optimizer']
30 | if init_optimizer is None:
31 | optimizer = SGD(model.parameters(), **optim_param)
32 | else:
33 | optimizer = init_optimizer
34 | for p in optimizer.param_groups:
35 | p["lr"] = optim_param['lr']
36 | p["initial_lr"] = optim_param['lr']
37 | if exp_name.startswith('SupCon'):
38 | scheduler = None
39 | elif exp_name.startswith('ERM'):
40 | schedule_param = params['scheduler']
41 | schedule_param['step_size'] *= len(train_loader)
42 | scheduler = StepLR(optimizer, **schedule_param)
43 | return optimizer, scheduler
44 | 
45 | def load_params(dataset_name : str, exp_name : str) -> dict:
46 | reset = {}
47 | with open(f'./{dataset_name.lower()}/hyperparameters.json', 'r') as fp:
48 | params = json.load(fp)
49 | reset['epochs'] = params[exp_name]['epochs']
50 | reset['optimizer'] = params[exp_name]['optimizer']
51 | reset['lr_decay_rate'] = params[exp_name].get('lr_decay_rate', 0.1) # some configs (e.g. imagenet100 ERM) omit this key; 0.1 matches the SupCon setting
52 | if 'scheduler' in params[exp_name]:
53 | reset['scheduler'] = params[exp_name]['scheduler']
54 | else:
55 | reset['scheduler'] = None
56 | return reset
-------------------------------------------------------------------------------- /svhn/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/svhn/__init__.py
-------------------------------------------------------------------------------- /svhn/data.py: --------------------------------------------------------------------------------
1 | from typing import *
2 | import numpy as np
3 | import os
4 | 
5 | import torch.utils.data as data
6 | from PIL import Image
7 | import os
8 | import os.path
9 | import numpy as np
10 | 
11 | 
12 | class SVHN(data.Dataset):
13 | url = ""
14 | filename = ""
15 | file_md5 = ""
16 | 
17 | split_list = {
18 | 'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
19 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
20 | 'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
21 | "selected_test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"], # use select_svhn_data.py to pre-process svhn, different from official torchvision code
22 | 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
23 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"],
24 | 'train_and_extra': [
25 | ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
26 | "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
27 | ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
28 | "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]]}
29 | 
30 | def __init__(self, root, split='train',
31 | transform=None, target_transform=None, download=False):
32 | self.root = root
33 | self.transform = transform
34 | self.target_transform = target_transform
35 | self.split = split # training set or test set or extra set
36 | 
37 | if self.split not in self.split_list:
38 | 
-------------------------------------------------------------------------------- /svhn/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/uiuctml/HypStructure/3c11ea961da48d7c570d973a198e40e6534268be/svhn/__init__.py
-------------------------------------------------------------------------------- /svhn/data.py: --------------------------------------------------------------------------------
1 | from typing import *
2 | import numpy as np
3 | import os
4 | 
5 | import torch.utils.data as data
6 | from PIL import Image
7 | import os
8 | import os.path
9 | import numpy as np
10 | 
11 | 
12 | class SVHN(data.Dataset):
13 |     url = ""
14 |     filename = ""
15 |     file_md5 = ""
16 | 
17 |     split_list = {
18 |         'train': ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
19 |                   "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
20 |         'test': ["http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
21 |                  "selected_test_32x32.mat", "eb5a983be6a315427106f1b164d9cef3"],  # pre-process SVHN with select_svhn_data.py first; this differs from the official torchvision code
22 |         'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
23 |                   "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"],
24 |         'train_and_extra': [
25 |             ["http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
26 |              "train_32x32.mat", "e26dedcc434d2e4c54c9b2d4a06d8373"],
27 |             ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
28 |              "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]]}
29 | 
30 |     def __init__(self, root, split='train',
31 |                  transform=None, target_transform=None, download=False):
32 |         self.root = root
33 |         self.transform = transform
34 |         self.target_transform = target_transform
35 |         self.split = split  # training set or test set or extra set
36 | 
37 |         if self.split not in self.split_list:
38 |             raise ValueError('Wrong split entered! Please use split="train", '
39 |                              'split="extra", split="test", '
40 |                              'or split="train_and_extra"')
41 | 
42 |         if self.split == "train_and_extra":
43 |             self.url = self.split_list[split][0][0]
44 |             self.filename = self.split_list[split][0][1]
45 |             self.file_md5 = self.split_list[split][0][2]
46 |         else:
47 |             self.url = self.split_list[split][0]
48 |             self.filename = self.split_list[split][1]
49 |             self.file_md5 = self.split_list[split][2]
50 | 
51 |         # import here rather than at top of file because this is
52 |         # an optional dependency for torchvision
53 |         import scipy.io as sio
54 | 
55 |         # load the .mat file as numpy arrays
56 |         loaded_mat = sio.loadmat(os.path.join(root, self.filename))
57 | 
58 |         if self.split == "test":
59 |             self.data = loaded_mat['X']
60 |             self.targets = loaded_mat['y']
61 |             # Note label 10 == 0 so modulo operator required
62 |             self.targets = (self.targets % 10).squeeze()  # convert to zero-based indexing
63 |             self.data = np.transpose(self.data, (3, 2, 0, 1))
64 |         else:
65 |             self.data = loaded_mat['X']
66 |             self.targets = loaded_mat['y']
67 | 
68 |             if self.split == "train_and_extra":
69 |                 extra_filename = self.split_list[split][1][1]
70 |                 loaded_mat = sio.loadmat(os.path.join(root, extra_filename))
71 |                 self.data = np.concatenate([self.data,
72 |                                             loaded_mat['X']], axis=3)
73 |                 self.targets = np.vstack((self.targets,
74 |                                           loaded_mat['y']))
75 |             # Note label 10 == 0 so modulo operator required
76 |             self.targets = (self.targets % 10).squeeze()  # convert to zero-based indexing
77 |             self.data = np.transpose(self.data, (3, 2, 0, 1))
78 | 
79 |     def __getitem__(self, index):
80 |         if self.split == "test":
81 |             img, target = self.data[index], self.targets[index]
82 |         else:
83 |             img, target = self.data[index], self.targets[index]
84 | 
85 |         # doing this so that it is consistent with all other datasets
86 |         # to return a PIL Image
87 |         img = Image.fromarray(np.transpose(img, (1, 2, 0)))
88 | 
89 |         if self.transform is not None:
90 |             img = self.transform(img)
91 | 
92 |         if self.target_transform is not None:
93 |             target = self.target_transform(target)
94 | 
95 |         return img, int(target)
96 | 
97 |     def __len__(self):
98 |         if self.split == "test":
99 |             return len(self.data)
100 |         else:
101 |             return len(self.data)
102 | 
-------------------------------------------------------------------------------- /svhn/select_svhn_data.py: --------------------------------------------------------------------------------
1 | import scipy.io as sio
2 | import os
3 | import numpy as np
4 | 
5 | root = '/path/to/svhn/dataset/'
6 | filename = 'test_32x32.mat'
7 | 
8 | loaded_mat = sio.loadmat(os.path.join(root, filename))
9 | 
10 | data = loaded_mat['X']
11 | targets = loaded_mat['y']
12 | 
13 | data = np.transpose(data, (3, 0, 1, 2))
14 | 
15 | selected_data = []
16 | selected_targets = []
17 | count = np.zeros(11)  # SVHN labels run from 1 to 10, so index up to 10
18 | 
19 | for i, y in enumerate(targets):
20 |     if count[y[0]] < 1000:  # keep at most 1000 images per class
21 |         selected_data.append(data[i])
22 |         selected_targets.append(y)
23 |         count[y[0]] += 1
24 | 
25 | selected_data = np.array(selected_data)
26 | selected_targets = np.array(selected_targets)
27 | 
28 | selected_data = np.transpose(selected_data, (1, 2, 3, 0))
29 | 
30 | save_mat = {'X': selected_data, 'y': selected_targets}
31 | 
32 | save_filename = 'selected_test_32x32.mat'
33 | sio.savemat(os.path.join(root, save_filename), save_mat)
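Note: a minimal sketch of loading the class-balanced SVHN test split defined above as an OOD set. It assumes `select_svhn_data.py` has already written `selected_test_32x32.mat` into the root directory; the path below is a placeholder:

```
# Sketch: load the class-balanced SVHN test split above as an OOD set.
# Assumes selected_test_32x32.mat already exists under `root` (placeholder path).
import torchvision.transforms as T
from torch.utils.data import DataLoader

transform = T.Compose([T.ToTensor()])
ood_set = SVHN(root="./ood_datasets/svhn", split="test", transform=transform)
ood_loader = DataLoader(ood_set, batch_size=256, shuffle=False, num_workers=4)
```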
-------------------------------------------------------------------------------- /utils.py: --------------------------------------------------------------------------------
1 | import os
2 | import random
3 | 
4 | import torch
5 | import numpy as np
6 | import faiss
7 | import sklearn.metrics as skm
8 | from sklearn.covariance import ledoit_wolf
9 | from sklearn.neighbors import KNeighborsClassifier
10 | 
11 | from geom.frechet import Frechet
12 | from geom.poincare import expmap0, logmap
13 | 
14 | def seed_everything(seed : int) -> None:
15 |     '''
16 |     Seed everything for reproducibility.
17 |     Args:
18 |         seed : any integer
19 |     '''
20 |     random.seed(seed)
21 |     os.environ['PYTHONHASHSEED'] = str(seed)
22 |     np.random.seed(seed)
23 |     torch.manual_seed(seed)
24 |     torch.cuda.manual_seed(seed)
25 | 
26 | 
27 | def get_scores_one_cluster(ftrain, ftest, food, shrunkcov=False):
28 |     if shrunkcov:
29 |         print("Using Ledoit-Wolf covariance estimator.")
30 |         cov = lambda x: ledoit_wolf(x)[0]
31 |     else:
32 |         cov = lambda x: np.cov(x.T, bias=True)
33 | 
34 |     # TODO: simplify these equations
35 |     dtest = np.sum(
36 |         (ftest - np.mean(ftrain, axis=0, keepdims=True))
37 |         * (
38 |             np.linalg.pinv(cov(ftrain)).dot(
39 |                 (ftest - np.mean(ftrain, axis=0, keepdims=True)).T
40 |             )
41 |         ).T,
42 |         axis=-1,
43 |     )
44 | 
45 |     dood = np.sum(
46 |         (food - np.mean(ftrain, axis=0, keepdims=True))
47 |         * (
48 |             np.linalg.pinv(cov(ftrain)).dot(
49 |                 (food - np.mean(ftrain, axis=0, keepdims=True)).T
50 |             )
51 |         ).T,
52 |         axis=-1,
53 |     )
54 | 
55 |     return dtest, dood
56 | 
57 | def get_scores_multi_cluster(ftrain, ftest, food, ypred):
58 |     xc = [ftrain[ypred == i] for i in np.unique(ypred)]
59 | 
60 |     din = [
61 |         np.sum(
62 |             (ftest - np.mean(x, axis=0, keepdims=True))
63 |             * (
64 |                 np.linalg.pinv(np.cov(x.T, bias=True)).dot(
65 |                     (ftest - np.mean(x, axis=0, keepdims=True)).T
66 |                 )
67 |             ).T,
68 |             axis=-1,
69 |         )
70 |         for x in xc
71 |     ]
72 |     dood = [
73 |         np.sum(
74 |             (food - np.mean(x, axis=0, keepdims=True))
75 |             * (
76 |                 np.linalg.pinv(np.cov(x.T, bias=True)).dot(
77 |                     (food - np.mean(x, axis=0, keepdims=True)).T
78 |                 )
79 |             ).T,
80 |             axis=-1,
81 |         )
82 |         for x in xc
83 |     ]
84 | 
85 |     din = np.min(din, axis=0)
86 |     dood = np.min(dood, axis=0)
87 | 
88 |     return din, dood
89 | 
90 | 
91 | def get_scores(ftrain, ftest, food, labelstrain, clusters):
92 |     if clusters == 1:
93 |         return get_scores_one_cluster(ftrain, ftest, food)
94 |     else:
95 |         ypred = labelstrain
96 |         return get_scores_multi_cluster(ftrain, ftest, food, ypred)
97 | 
98 | def get_clusters(ftrain, nclusters):
99 |     kmeans = faiss.Kmeans(
100 |         ftrain.shape[1], nclusters, niter=100, verbose=False, gpu=False
101 |     )
102 |     kmeans.train(np.random.permutation(ftrain))
103 |     _, ypred = kmeans.assign(ftrain)
104 |     return ypred
105 | 
106 | 
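Note: a minimal sketch of the one-cluster Mahalanobis-style scoring above on synthetic features. With `clusters > 1`, `get_scores` instead routes through `get_scores_multi_cluster`, conditioning on training labels (or `get_clusters` pseudo-labels):

```
# Sketch: one-cluster scoring on synthetic features. The mean-shifted "OOD"
# set should receive larger distances than the held-out ID set.
import numpy as np

rng = np.random.default_rng(0)
ftrain = rng.normal(size=(1000, 64))
ftest = rng.normal(size=(200, 64))
food = rng.normal(loc=3.0, size=(200, 64))    # mean-shifted, i.e. "OOD"

dtest, dood = get_scores_one_cluster(ftrain, ftest, food)
print(dtest.mean(), dood.mean())              # dood.mean() should be much larger
```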
107 | def get_eval_results(ftrain, ftest, food, labelstrain, clusters):
108 |     """
109 |     Compute OOD detection metrics (FPR95, AUROC, AUPR) from ID and OOD features.
110 |     """
111 |     dtest, dood = get_scores(ftrain, ftest, food, labelstrain, clusters)
112 | 
113 |     fpr95 = get_fpr(dtest, dood)
114 |     auroc, aupr = get_roc_sklearn(dtest, dood), get_pr_sklearn(dtest, dood)
115 |     return fpr95, auroc, aupr
116 | 
117 | 
118 | def get_roc_sklearn(xin, xood):
119 |     labels = [0] * len(xin) + [1] * len(xood)
120 |     data = np.concatenate((xin, xood))
121 |     auroc = skm.roc_auc_score(labels, data)
122 |     return auroc
123 | 
124 | 
125 | def get_pr_sklearn(xin, xood):
126 |     labels = [0] * len(xin) + [1] * len(xood)
127 |     data = np.concatenate((xin, xood))
128 |     aupr = skm.average_precision_score(labels, data)
129 |     return aupr
130 | 
131 | 
132 | def get_fpr(xin, xood):
133 |     return np.sum(xood < np.percentile(xin, 95)) / len(xood)
134 | 
135 | 
136 | #### Utils from KNN-OOD
137 | def fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95, pos_label=None):
138 |     classes = np.unique(y_true)
139 |     if (pos_label is None and
140 |             not (np.array_equal(classes, [0, 1]) or
141 |                  np.array_equal(classes, [-1, 1]) or
142 |                  np.array_equal(classes, [0]) or
143 |                  np.array_equal(classes, [-1]) or
144 |                  np.array_equal(classes, [1]))):
145 |         raise ValueError("Data is not binary and pos_label is not specified")
146 |     elif pos_label is None:
147 |         pos_label = 1.
148 | 
149 |     # make y_true a boolean vector
150 |     y_true = (y_true == pos_label)
151 | 
152 |     # sort scores and corresponding truth values
153 |     desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
154 |     y_score = y_score[desc_score_indices]
155 |     y_true = y_true[desc_score_indices]
156 | 
157 |     # y_score typically has many tied values. Here we extract
158 |     # the indices associated with the distinct values. We also
159 |     # concatenate a value for the end of the curve.
160 |     distinct_value_indices = np.where(np.diff(y_score))[0]
161 |     threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]
162 | 
163 |     # accumulate the true positives with decreasing threshold
164 |     tps = stable_cumsum(y_true)[threshold_idxs]
165 |     fps = 1 + threshold_idxs - tps  # add one because of zero-based indexing
166 | 
167 |     thresholds = y_score[threshold_idxs]
168 | 
169 |     recall = tps / tps[-1]
170 | 
171 |     last_ind = tps.searchsorted(tps[-1])
172 |     sl = slice(last_ind, None, -1)  # [last_ind::-1]
173 |     recall, fps, tps, thresholds = np.r_[recall[sl], 1], np.r_[fps[sl], 0], np.r_[tps[sl], 0], thresholds[sl]
174 | 
175 |     cutoff = np.argmin(np.abs(recall - recall_level))
176 | 
177 |     return fps[cutoff] / (np.sum(np.logical_not(y_true)))
178 | 
179 | 
180 | def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
181 |     """Use high precision for cumsum and check that the final value matches the sum.
182 |     Parameters
183 |     ----------
184 |     arr : array-like
185 |         To be cumulatively summed as flat
186 |     rtol : float
187 |         Relative tolerance, see ``np.allclose``
188 |     atol : float
189 |         Absolute tolerance, see ``np.allclose``
190 |     """
191 |     out = np.cumsum(arr, dtype=np.float64)
192 |     expected = np.sum(arr, dtype=np.float64)
193 |     if not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
194 |         raise RuntimeError('cumsum was found to be unstable: '
195 |                            'its last element does not correspond to sum')
196 |     return out
197 | 
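Note: a minimal sketch of turning ID/OOD score arrays into the three reported metrics with the helpers above. Larger scores indicate OOD here, matching the distances that `get_eval_results` feeds into these functions:

```
# Sketch: compute FPR95 / AUROC / AUPR from synthetic ID and OOD score arrays,
# where larger scores mean "more out-of-distribution".
import numpy as np

rng = np.random.default_rng(0)
scores_id = rng.normal(0.0, 1.0, size=500)
scores_ood = rng.normal(2.0, 1.0, size=500)

fpr95 = get_fpr(scores_id, scores_ood)
auroc = get_roc_sklearn(scores_id, scores_ood)
aupr = get_pr_sklearn(scores_id, scores_ood)
print(f"FPR95={fpr95:.3f} AUROC={auroc:.3f} AUPR={aupr:.3f}")
```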
198 | def get_measures(_pos, _neg, recall_level=0.95):
199 |     pos = np.array(_pos[:]).reshape((-1, 1))
200 |     neg = np.array(_neg[:]).reshape((-1, 1))
201 |     examples = np.squeeze(np.vstack((pos, neg)))
202 |     labels = np.zeros(len(examples), dtype=np.int32)
203 |     labels[:len(pos)] += 1
204 | 
205 |     auroc = skm.roc_auc_score(labels, examples)
206 |     aupr = skm.average_precision_score(labels, examples)
207 |     fpr = fpr_and_fdr_at_recall(labels, examples, recall_level)
208 | 
209 |     return auroc, aupr, fpr
210 | 
211 | 
212 | def get_mean_prec(train_features, train_labels, n_cls):
213 |     '''
214 |     Used for the Mahalanobis score: calculate class-wise means and the shared inverse covariance matrix.
215 |     '''
216 |     classwise_mean = torch.empty(n_cls, train_features.shape[1], device = 'cuda')
217 |     classwise_cov = torch.empty(n_cls, train_features.shape[1], train_features.shape[1], device = 'cuda')
218 |     all_features = torch.zeros((0, train_features.shape[1]), device = 'cuda')
219 |     classwise_idx = {}
220 |     all_features = train_features
221 | 
222 |     targets = train_labels
223 |     for class_id in range(n_cls):
224 |         classwise_idx[class_id] = np.where(targets == class_id)[0]
225 | 
226 |     for cls in range(n_cls):
227 |         classwise_mean[cls] = torch.mean(all_features[classwise_idx[cls]].float(), dim = 0)
228 |         classwise_cov[cls] = torch.cov(all_features[classwise_idx[cls]].float().T)
229 | 
230 |     tied_cov = torch.sum(classwise_cov, dim=0) / train_features.shape[1]  # constant rescaling only; pinv scales all scores uniformly, so rankings are unchanged
231 |     precision = torch.linalg.pinv(tied_cov).float()
232 |     precision = precision.to(classwise_mean.device)
233 |     return classwise_mean, precision
234 | 
235 | def get_Mahalanobis_score(test_features, n_cls, classwise_mean, precision, in_dist = True):
236 |     '''
237 |     Compute the Mahalanobis confidence score on an input dataset.
238 |     '''
239 |     Mahalanobis_score_all = []
240 |     with torch.no_grad():
241 |         for i in range(n_cls):
242 |             class_mean = classwise_mean[i]
243 |             zero_f = test_features - class_mean
244 |             Mahalanobis_dist = -torch.mm(torch.mm(zero_f, precision), zero_f.t()).diag()
245 |             if i == 0:
246 |                 Mahalanobis_score = Mahalanobis_dist.view(-1, 1)
247 |             else:
248 |                 Mahalanobis_score = torch.cat((Mahalanobis_score, Mahalanobis_dist.view(-1, 1)), 1)
249 |         Mahalanobis_score, _ = torch.max(Mahalanobis_score, dim=1)
250 |         Mahalanobis_score_all.extend(-Mahalanobis_score.cpu().numpy())
251 | 
252 |     return np.asarray(Mahalanobis_score_all, dtype=np.float32)
253 | 
254 | 
255 | def exponential_map(x, c=1.0):
256 |     """
257 |     Exponential map for the Poincare ball model.
258 | 
259 |     Parameters:
260 |         x (np.ndarray): Input vector in Euclidean space.
261 |         c (float): Curvature parameter of hyperbolic space.
262 | 
263 |     Returns:
264 |         np.ndarray: Mapped vector in hyperbolic space.
265 |     """
266 |     norm_x = np.linalg.norm(x, axis=-1, keepdims=True)
267 |     return np.tanh(np.sqrt(c) * norm_x) * x / (np.sqrt(c) * norm_x)
268 | 
269 | def embed_in_poincare_ball(X, c=1.0):
270 |     """
271 |     Embeds points in the Poincare ball.
272 | 
273 |     Parameters:
274 |         X (np.ndarray): [N x d] dimensional matrix of Euclidean vectors.
275 |         c (float): Curvature parameter of hyperbolic space.
276 | 
277 |     Returns:
278 |         np.ndarray: [N x d] dimensional matrix of vectors in hyperbolic space.
279 |     """
280 |     return np.array([exponential_map(x, c=c) for x in X])
281 | 
282 | def distance_from_origin(Y):
283 |     """
284 |     Computes the distance from the origin for each point in hyperbolic space.
285 | 
286 |     Parameters:
287 |         Y (np.ndarray): [N x d] dimensional matrix of vectors in hyperbolic space.
288 | 
289 |     Returns:
290 |         np.ndarray: Distance of each point from the origin.
291 |     """
292 |     return np.arccosh(1 + 2 * np.sum(np.square(Y), axis=-1) / (1 - np.sum(np.square(Y), axis=-1)))
293 | 
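Note: a minimal sketch exercising the Poincaré helpers above. Euclidean vectors are mapped into the unit ball via the exponential map at the origin, and their hyperbolic norms are then read off with `distance_from_origin`:

```
# Sketch: map Euclidean vectors into the Poincare ball, then compute each
# point's hyperbolic distance from the origin.
import numpy as np

X = np.random.default_rng(0).normal(size=(5, 16))
Y = embed_in_poincare_ball(X, c=1.0)     # exponential map at the origin
print(np.linalg.norm(Y, axis=-1))        # Euclidean norms: all strictly < 1
print(distance_from_origin(Y))           # hyperbolic norms: unbounded
```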
294 | def distance_from_origin_stable(Y, epsilon=1e-7):
295 |     """
296 |     Computes the distance from the origin for each point in hyperbolic space, with numerical stability.
297 | 
298 |     Parameters:
299 |         Y (np.ndarray): [N x d] dimensional matrix of vectors in hyperbolic space.
300 |         epsilon (float): Small value to ensure numerical stability.
301 | 
302 |     Returns:
303 |         np.ndarray: Distance of each point from the origin.
304 |     """
305 |     # Ensure the norm of Y is less than 1 to avoid numerical issues
306 |     norm_Y = np.linalg.norm(Y, axis=-1)
307 |     Y = Y / (1 + epsilon) * np.where(norm_Y >= 1, (1 - epsilon) / norm_Y, 1).reshape(-1, 1)
308 | 
309 |     # Compute the distance with an added epsilon for numerical stability
310 |     return np.arccosh(1 + 2 * np.sum(np.square(Y), axis=-1) / ((1 - np.sum(np.square(Y), axis=-1)) + epsilon))
311 | 
312 | 
313 | def clamp(value, eps=1e-10):
314 |     if value >= 1:
315 |         return 1 - eps
316 |     else:
317 |         return value
318 | 
319 | def mobius_addition(z, y, c):
320 |     norm_y = np.linalg.norm(y)
321 |     norm_z = np.linalg.norm(z)
322 |     numerator = (1 + 2*c*np.dot(z, y) + c*norm_y**2)*z + (1 - c*norm_z**2)*y
323 |     denominator = 1 + 2*c*np.dot(z, y) + c**2*norm_z**2*norm_y**2
324 |     return numerator / denominator
325 | 
326 | def poincare_logarithm_map(z, y, c=1.0):
327 |     zy = mobius_addition(-z, y, c)
328 |     norm_zy = clamp(np.linalg.norm(zy))
329 |     norm_z = clamp(np.linalg.norm(z))
330 |     lambda_c = 2/(1 - c*norm_z**2)
331 |     return 2/(np.sqrt(c)*lambda_c)*np.arctanh(np.sqrt(c)*norm_zy)*(zy/norm_zy)
--------------------------------------------------------------------------------
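Note: a small, self-contained sanity check for the Möbius helpers above, using toy points and c = 1. The Riemannian norm of the logarithm map at z, i.e. the conformal factor lambda_z times its Euclidean norm, should equal the hyperbolic distance 2 * arctanh(||(-z) (+) y||):

```
# Sanity-check sketch for mobius_addition / poincare_logarithm_map (c = 1).
import numpy as np

z = np.array([0.1, 0.2])
y = np.array([-0.3, 0.05])
v = poincare_logarithm_map(z, y, c=1.0)       # tangent vector at z toward y
lam = 2 / (1 - np.linalg.norm(z) ** 2)        # conformal factor at z
print(lam * np.linalg.norm(v))                # hyperbolic distance d(z, y)
print(2 * np.arctanh(np.linalg.norm(mobius_addition(-z, y, 1.0))))  # should match
```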