├── images ├── toy_example.png └── curation_pipeline.png ├── requirements.txt ├── vis ├── __init__.py └── generalized_kmeans_1d.py ├── __init__.py ├── src ├── __init__.py ├── utils.py ├── clusters.py ├── hierarchical_sampling.py ├── hierarchical_kmeans_gpu.py ├── dist_comm.py ├── kmeans_gpu.py └── distributed_kmeans_gpu.py ├── scripts ├── __init__.py ├── run_hierarchical_sampling.py ├── split_clusters.py ├── hierarchical_kmeans_launcher.py └── run_distributed_kmeans.py ├── setup.py ├── configs ├── 2levels_random_embeddings.yaml └── 4levels_web_based_images.yaml ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── README.md └── LICENSE /images/toy_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ssl-data-curation/HEAD/images/toy_example.png -------------------------------------------------------------------------------- /images/curation_pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/ssl-data-curation/HEAD/images/curation_pipeline.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu118 2 | torch==2.2 3 | matplotlib==3.8.2 4 | scipy==1.11.4 5 | numpy==1.24.4 6 | omegaconf 7 | scikit-learn>=1.5.0 8 | tqdm 9 | ipykernel 10 | -------------------------------------------------------------------------------- /vis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from setuptools import find_packages, setup 8 | 9 | 10 | setup( 11 | name="ssl_data_curation", 12 | packages=find_packages(), 13 | ) 14 | -------------------------------------------------------------------------------- /configs/2levels_random_embeddings.yaml: -------------------------------------------------------------------------------- 1 | # Number of levels in hierarchical k-means. 2 | n_levels: 2 3 | # Number of updates of centroids in the main k-means loop. 4 | n_iters: 50 5 | # Number of clusters in each level of hierarchical k-means. 6 | n_clusters: 7 | - 5000 8 | - 1000 9 | # If > 1, run the level in two steps. First, k-means is executed once. 10 | # Then, each obtained cluster is splitted into "n_split" smaller clusters, 11 | # which are considered final and used in the subsequent level. 12 | n_splits: 13 | - 1 14 | - 1 15 | # Number of resampling steps in each level. 16 | n_resampling_steps: 17 | - 10 18 | - 10 19 | # Number of data points sampled from each cluster in the resampling steps. 20 | # It is roughly half the average cluster size in each level. 21 | sample_size: 22 | - 10 23 | - 3 24 | # Specified if running only on a subset of the data pool. 25 | # For example, we extract embeddings for all images in the data pool, 26 | # but run the curation pipeline only on a deduplicated subset. 27 | subset_indices_path: null 28 | checkpoint_period: 1000 29 | dtype: float64 30 | high_precision: float64 31 | ngpus_per_node: 32 | - 2 33 | - 2 34 | nnodes: 35 | - 1 36 | - 1 37 | ncpus_per_gpu: 10 38 | sampling_strategy: c 39 | slurm_partition: null 40 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to this repository 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://bugbounty.meta.com/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to this repository, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /configs/4levels_web_based_images.yaml: -------------------------------------------------------------------------------- 1 | # Number of levels in hierarchical k-means. 2 | n_levels: 4 3 | # Number of updates of centroids in the main k-means loop. 4 | n_iters: 50 5 | # Number of clusters in each level of hierarchical k-means. 
6 | # For efficiency in the first level, we run first a k-means 7 | # with 100k clusters, then split each cluster into 100 8 | # smaller ones to have 10M clusters. 9 | n_clusters: 10 | - 100_000 11 | - 500_000 12 | - 50_000 13 | - 10_000 14 | # If > 1, run the level in two steps. First, k-means is executed once. 15 | # Then, each obtained cluster is splitted into "n_split" smaller clusters, 16 | # which are considered final and used in the subsequent level. 17 | n_splits: 18 | - 100 19 | - 1 20 | - 1 21 | - 1 22 | # Number of resampling steps in each level. 23 | # For efficiency, we do not use resampling in the first level. 24 | n_resampling_steps: 25 | - 1 26 | - 10 27 | - 10 28 | - 10 29 | # Number of data points sampled from each cluster in the resampling steps. 30 | # It is roughly half the average cluster size in each level. 31 | sample_size: 32 | - 1 33 | - 10 34 | - 5 35 | - 3 36 | # Specified if running only on a subset of the data pool. 37 | # For example, we extract embeddings for all images in the data pool, 38 | # but run the curation pipeline only on a deduplicated subset. 39 | subset_indices_path: null 40 | checkpoint_period: 10_000 41 | dtype: float64 42 | high_precision: float64 43 | ngpus_per_node: 44 | - 8 45 | - 8 46 | - 8 47 | - 8 48 | nnodes: 49 | - 16 50 | - 2 51 | - 1 52 | - 1 53 | ncpus_per_gpu: 10 54 | sampling_strategy: c 55 | slurm_partition: null 56 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | import logging 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import torch 13 | 14 | 15 | def create_clusters_from_cluster_assignment( 16 | cluster_assignment: np.array, 17 | num_clusters: int, 18 | return_object_array: bool = True, 19 | ): 20 | """ 21 | Build clusters from cluster assignment. 22 | """ 23 | ID = np.argsort(cluster_assignment) 24 | sorted_cluster_assigment = cluster_assignment[ID] 25 | index_split = np.searchsorted(sorted_cluster_assigment, list(range(num_clusters))) 26 | clusters = np.split(ID, index_split[1:]) 27 | if return_object_array: 28 | return np.array(clusters, dtype=object) 29 | else: 30 | return clusters 31 | 32 | 33 | def find_all_checkpoints(save_dir, pattern): 34 | """ 35 | Parameters: 36 | pattern: str 37 | checkpoint name format _%d., 38 | e.g., kmpp_checkpoint_%d.pth 39 | """ 40 | save_dir = Path(save_dir) 41 | ckpt_list = [str(el.stem) for el in save_dir.glob(pattern.replace("%d", "*"))] 42 | ckpt_list = [int(el.split("_")[-1]) for el in ckpt_list] 43 | ckpt_list = sorted(ckpt_list) 44 | return [Path(save_dir, pattern % el) for el in ckpt_list] 45 | 46 | 47 | def get_last_valid_checkpoint(save_dir, pattern): 48 | """ 49 | Find path to the last checkpoint. 
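    A checkpoint is considered valid only if it can actually be loaded
    (torch.load for ".pth" patterns, np.load for ".npy" patterns); the most
    recent loadable checkpoint is returned, or None if none loads.

    Illustrative call (hypothetical directory, using the pattern format
    documented in find_all_checkpoints above):

        last = get_last_valid_checkpoint("experiment_dir", "kmpp_checkpoint_%d.pth")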
50 | """ 51 | ckpt_list = find_all_checkpoints(save_dir, pattern) 52 | for ckpt_path in ckpt_list[::-1]: 53 | try: 54 | if ".pth" in pattern: 55 | _ = torch.load(ckpt_path, map_location="cpu") 56 | elif ".npy" in pattern: 57 | _ = np.load(ckpt_path) 58 | else: 59 | raise ValueError("Pattern not recognized!") 60 | return ckpt_path 61 | except Exception: 62 | continue 63 | return None 64 | 65 | 66 | def _delete_old_checkpoint( 67 | save_dir, current_iter, checkpoint_period, max_num_checkpoints, pattern 68 | ): 69 | Path( 70 | save_dir, pattern % (current_iter - checkpoint_period * max_num_checkpoints) 71 | ).unlink(missing_ok=True) 72 | 73 | 74 | def setup_logging( 75 | *, 76 | name: str = None, 77 | level: int = logging.INFO, 78 | capture_warnings: bool = True, 79 | ) -> None: 80 | """ 81 | Basic setting for logger. 82 | """ 83 | logging.captureWarnings(capture_warnings) 84 | 85 | logger = logging.getLogger(name) 86 | logger.setLevel(level) 87 | 88 | if logger.hasHandlers(): 89 | return 90 | 91 | fmt_prefix = ( 92 | "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " 93 | ) 94 | fmt_message = "%(message)s" 95 | fmt = fmt_prefix + fmt_message 96 | datefmt = "%Y%m%d %H:%M:%S" 97 | formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) 98 | 99 | handler = logging.StreamHandler(sys.stdout) 100 | handler.setLevel(level) 101 | handler.setFormatter(formatter) 102 | 103 | logger.propagate = False 104 | logger.addHandler(handler) 105 | return 106 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /scripts/run_hierarchical_sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | from argparse import ArgumentParser 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | 13 | from src.clusters import HierarchicalCluster 14 | from src.utils import setup_logging 15 | from src.hierarchical_sampling import hierarchical_sampling 16 | 17 | logger = logging.getLogger("hkmeans") 18 | 19 | if __name__ == "__main__": 20 | parser = ArgumentParser() 21 | parser.add_argument("--save", action="store_true") 22 | parser.add_argument("--clustering_path", "-clus", type=str, required=True) 23 | parser.add_argument( 24 | "--target_size", 25 | type=int, 26 | required=True, 27 | help="Target size of the sampled set" 28 | ) 29 | parser.add_argument( 30 | "--multiplier", 31 | "-m", 32 | type=int, 33 | default=1, 34 | help="Maximum number of times an image is selected" 35 | ) 36 | parser.add_argument( 37 | "--sampling_strategy", 38 | "-ss", 39 | type=str, 40 | default="r", 41 | help='"r" for random, "c" for closest', 42 | ) 43 | parser.add_argument( 44 | "--sort_indices", 45 | action="store_true", 46 | help="If true, sort indices in increasing order", 47 | ) 48 | parser.add_argument( 49 | "--name_suffix", 50 | type=str, 51 | default="", 52 | help="Suffix to add to the indice file name", 53 | ) 54 | parser.add_argument( 55 | "--valid_indices_path", 56 | type=str, 57 | default=None, 58 | help=( 59 | "Path to .npy file containing valid indices of the base dataset. " 60 | "The clustering is computed only on these valid images." 61 | ), 62 | ) 63 | parser.add_argument( 64 | "--cluster_fname", 65 | type=str, 66 | default="sorted_clusters.npy", 67 | help="name of files containing clusters", 68 | ) 69 | parser.add_argument("--save_dir_name", type=str, default="curated_datasets") 70 | 71 | args = parser.parse_args() 72 | args.clustering_path = Path(args.clustering_path).resolve() 73 | setup_logging() 74 | logger.info(f"args: {args}") 75 | 76 | cl = HierarchicalCluster.from_file( 77 | cluster_path=args.clustering_path, 78 | cluster_fname=args.cluster_fname 79 | ) 80 | 81 | sampled_indices = hierarchical_sampling( 82 | cl, 83 | args.target_size, 84 | args.multiplier, 85 | args.sampling_strategy, 86 | ) 87 | if args.valid_indices_path is not None: 88 | valid_indices = np.load(args.valid_indices_path) 89 | assert len(valid_indices) == np.sum( 90 | [len(el) for el in cl.clusters[1]] 91 | ), "Number of images is not equal to valid_indices size" 92 | sampled_indices = valid_indices[sampled_indices] 93 | 94 | if args.sort_indices: 95 | sampled_indices = np.sort(sampled_indices) 96 | 97 | num_images = len(sampled_indices) 98 | logger.info(f"Number of selected data points: {num_images}") 99 | 100 | save_indices_path = Path( 101 | args.clustering_path, 102 | args.save_dir_name, 103 | f'{cl.n_levels}{args.sampling_strategy}_mul{args.multiplier}_' 104 | f'{args.target_size}_balanced_selection.npy' 105 | ) 106 | if len(args.name_suffix) > 0: 107 | save_indices_path = Path( 108 | str(save_indices_path).replace(".npy", f"_{args.name_suffix}.npy") 109 | ) 110 | logger.info(f"Indices will be saved to {str(save_indices_path.resolve())}") 111 | if args.save: 112 | Path(args.clustering_path, args.save_dir_name).mkdir(exist_ok=True) 113 | np.save(save_indices_path, sampled_indices) 114 | logger.info("Indices are saved!") 115 | -------------------------------------------------------------------------------- /src/clusters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | from pathlib import Path 9 | import pickle 10 | from typing import Dict, List 11 | 12 | import numpy as np 13 | 14 | 15 | logger = logging.getLogger("hkmeans") 16 | 17 | 18 | def load_clusters_from_file(fpath): 19 | """ 20 | Utility to load clusters fromj different file formats. 21 | """ 22 | if Path(fpath).suffix == ".pkl": 23 | with open(fpath, "rb") as f: 24 | return np.array(pickle.load(f), dtype=object) 25 | else: 26 | return np.load(Path(fpath), allow_pickle=True) 27 | 28 | class HierarchicalCluster: 29 | """ 30 | Class representing a hierarchy of clusters returned by hierarchical k-means. 31 | """ 32 | def __init__(self): 33 | self.cluster_path = None 34 | self.n_levels = None 35 | self.cluster_fname = None 36 | self.is_loaded = False 37 | self.is_processed = False 38 | self.n_clusters = {} 39 | self.clusters = {} 40 | self.flat_clusters = {} 41 | self.clusters_size = {} 42 | self.flat_clusters_size = {} 43 | self.size_order = {} 44 | self.flat_size_order = {} 45 | 46 | def load_clusters_from_file(self): 47 | for level in range(1, 1 + self.n_levels): 48 | self.clusters[level] = load_clusters_from_file( 49 | Path( 50 | self.cluster_path, 51 | f"level{level}", 52 | self.cluster_fname 53 | ) 54 | ) 55 | self.n_clusters[level] = len(self.clusters[level]) 56 | self.is_loaded = True 57 | 58 | def process_clusters(self): 59 | if not self.is_loaded: 60 | raise RuntimeError("Clusters must be loaded before being processed") 61 | logger.info("Computing flat clusters") 62 | self.flat_clusters[1] = self.clusters[1] 63 | for level in range(2, 1 + self.n_levels): 64 | current_non_flat = self.clusters[level] 65 | prev_flat = self.flat_clusters[level - 1] 66 | self.flat_clusters[level] = np.array( 67 | [ 68 | np.concatenate([prev_flat[el] for el in clus]) 69 | if len(clus) > 0 else np.array([]) 70 | for clus in current_non_flat 71 | ], 72 | dtype=object, 73 | ) 74 | 75 | logger.info("Computing cluster length") 76 | for level, clus in self.clusters.items(): 77 | self.clusters_size[level] = np.array([len(el) for el in clus]) 78 | 79 | for level, clus in self.flat_clusters.items(): 80 | self.flat_clusters_size[level] = np.array([len(el) for el in clus]) 81 | 82 | logger.info("Sorting clusters by length") 83 | for level, clsize in self.clusters_size.items(): 84 | self.size_order[level] = np.argsort(clsize)[::-1] 85 | 86 | for level, flat_clsize in self.flat_clusters_size.items(): 87 | self.flat_size_order[level] = np.argsort(flat_clsize)[::-1] 88 | 89 | self.is_processed = True 90 | 91 | @staticmethod 92 | def from_file( 93 | cluster_path, 94 | cluster_fname="sorted_clusters.npy", 95 | ): 96 | """ 97 | Method for reading hierarchical clusters from files 98 | """ 99 | logger.info("Loading hierarchical clusters from file.") 100 | cl = HierarchicalCluster() 101 | cl.cluster_path = cluster_path 102 | cl.cluster_fname = cluster_fname 103 | cl.n_levels = 0 104 | while True: 105 | if Path(cl.cluster_path, f"level{cl.n_levels + 1}").exists(): 106 | cl.n_levels += 1 107 | else: 108 | break 109 | cl.load_clusters_from_file() 110 | cl.process_clusters() 111 | return cl 112 | 113 | @staticmethod 114 | def from_dict(clusters: List[Dict]): 115 | """ 116 | Read hierarchical clusters from a list of dictionaries. 117 | 118 | Parameters: 119 | clusters: List[Dict] 120 | Each element is a dictionary containing a field name "clusters". 
121 | An example is the output of hierarchical_kmeans_gpu.hierarchical_kmeans 122 | 123 | Return: 124 | A instance of HierarchicalCluster. 125 | """ 126 | logger.info("Loading hierarchical clusters from dictionaries.") 127 | cl = HierarchicalCluster() 128 | cl.n_levels = len(clusters) 129 | for level in range(1, 1 + cl.n_levels): 130 | cl.clusters[level] = clusters[level - 1]["clusters"] 131 | cl.n_clusters[level] = len(cl.clusters[level]) 132 | cl.is_loaded = True 133 | cl.process_clusters() 134 | return cl 135 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Automatic Data Curation for Self-Supervised Learning: A Clustering-Based Approach 2 | 3 | **[FAIR at Meta](https://ai.facebook.com/research/)** 4 | 5 | *Huy V. Vo, 6 | Vasil Khalidov, 7 | Timothée Darcet, 8 | Théo Moutakanni, 9 | Nikita Smetanin, 10 | Marc Szafraniec, 11 | Hugo Touvron, 12 | Camille Couprie, 13 | Maxime Oquab, 14 | Armand Joulin, 15 | Hervé Jégou, 16 | Patrick Labatut, 17 | Piotr Bojanowski* 18 | 19 | PyTorch implementation for the data curation pipeline with hierarchical k-means. For more detail, see the paper **[Automatic Data Curation for Self-Supervised Learning: A Clustering-Based Approach](https://arxiv.org/abs/2405.15613)**. 20 | 21 |
22 |   <img alt="data curation pipeline" src="images/curation_pipeline.png"> 23 | </div>
24 | 25 | ## Contents 26 | - [Installation](#installation) 27 | - [Running hierarchical k-means](#running-hierarchical-k-means) 28 | * [On small data](#on-small-data) 29 | * [On large data](#on-large-data) 30 | - [Notebook](#notebook) 31 | - [Contributing](#contributing) 32 | - [License](#license) 33 | - [Citation](#citation) 34 | 35 | ## Installation 36 | ``` 37 | git clone git@github.com:facebookresearch/ssl-data-curation.git 38 | cd ssl-data-curation 39 | conda create -n ssl-data-curation python=3.10 40 | conda activate ssl-data-curation 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | ## Running hierarchical k-means 45 | ### On small data 46 | We provide below an example of a 2-level hierarchical k-means on a small toy random dataset. We first run hierarchical k-means on the toy dataset then sample 1000 points from it with hierarchical sampling. A visualisation is provided in [vis/notebook.ipynb](vis/notebook.ipynb). 47 | ``` 48 | import torch 49 | import numpy as np 50 | 51 | from src.clusters import HierarchicalCluster 52 | from src import ( 53 | hierarchical_kmeans_gpu as hkmg, 54 | hierarchical_sampling as hs 55 | ) 56 | 57 | def make_ring(n, rmin, rmax): 58 | r = np.random.rand(n) * (rmax - rmin) + rmin 59 | alpha = np.random.rand(n) * 2 * np.pi 60 | return np.vstack([r * np.cos(alpha), r * np.sin(alpha)]).T 61 | 62 | data = np.concatenate([ 63 | make_ring(20000, 0.7, 1.0) + np.array([-2.2, 1.]), 64 | make_ring(200, 0.7, 1.0) + np.array([0., 1.]), 65 | make_ring(1000, 0.7, 1.0) + np.array([2.2, 1.]), 66 | make_ring(500, 0.7, 1.0) + np.array([-1.2, 0.2]), 67 | make_ring(8000, 0.7, 1.0) + np.array([1.2, 0.2]), 68 | ]) 69 | 70 | clusters = hkmg.hierarchical_kmeans_with_resampling( 71 | data=torch.tensor(data, device="cuda", dtype=torch.float32), 72 | n_clusters=[1000, 300], 73 | n_levels=2, 74 | sample_sizes=[15, 2], 75 | verbose=False, 76 | ) 77 | 78 | cl = HierarchicalCluster.from_dict(clusters) 79 | sampled_indices = hs.hierarchical_sampling(cl, target_size=1000) 80 | ``` 81 | 82 |
83 |   <img alt="data curation pipeline" src="images/toy_example.png"> 84 | </div>
85 | 86 | ### On large data 87 | To launch hierarchical k-means on large data, we need to prepare a config file. We provide below an example illustrating how to launch a 2-level hierarchical k-means on random embeddings with config in [configs/2levels_random_embeddings.yaml](configs/2levels_random_embeddings.yaml). 88 | ``` 89 | # Prepare the experiment 90 | cd ssl-data-curation 91 | mkdir -p data 92 | cd scripts 93 | python -c 'import numpy as np; np.save( "../data/100k_random.npy", np.random.randn(100000,256))' 94 | python hierarchical_kmeans_launcher.py \ 95 | --exp_dir ../data/2levels_random_embeddings \ 96 | --embeddings_path ../data/100k_random.npy \ 97 | --config_file ../configs/2levels_random_embeddings.yaml 98 | 99 | cd ../data/2levels_random_embeddings 100 | # Launch with slurm 101 | bash launcher.sh 102 | # Launch locally if only 1 node is used 103 | # bash local_launcher.sh 104 | 105 | cd ssl-data-curation/scripts 106 | # Sampled indices will be saved in ssl-data-curation/data/2levels_random_embeddings/curated_datasets 107 | PYTHONPATH=.. python run_hierarchical_sampling.py \ 108 | --clustering_path ../data/2levels_random_embeddings \ 109 | --target_size 20000 \ 110 | --save 111 | ``` 112 | 113 | We also provide the config used for our web-based image data pool in [configs/4levels_web_based_images.yaml](configs/4levels_web_based_images.yaml). 114 | 115 | ## Notebook 116 | We provide a [notebook](vis/notebook.ipynb) to reproduce visualizations in the paper and show additional examples. 117 | 118 | ## Contributing 119 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). 120 | 121 | ## License 122 | This code is CC-BY-NC 4.0 licensed, as found in [LICENSE](LICENSE). 123 | 124 | ## Citation 125 | If you find our work useful, please consider giving a star and a citation: 126 | ``` 127 | @article{vo2024automatic, 128 | title={Automatic Data Curation for Self-Supervised Learning: A Clustering-Based Approach}, 129 | author={Vo, Huy V. and Khalidov, Vasil and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Smetanin, Nikita and Szafraniec, Marc and Touvron, Hugo and Couprie, Camille and Oquab, Maxime and Joulin, Armand and Jégou, Hervé and Labatut, Patrick and Bojanowski, Piotr}, 130 | journal={arXiv:2405.15613}, 131 | year={2024}, 132 | } 133 | ``` -------------------------------------------------------------------------------- /scripts/split_clusters.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import sys 8 | from argparse import ArgumentParser 9 | import logging 10 | from pathlib import Path 11 | from tqdm import tqdm 12 | 13 | import numpy as np 14 | import torch 15 | 16 | from src.utils import setup_logging 17 | 18 | from src.dist_comm import ( 19 | enable_distributed, 20 | get_global_rank, 21 | get_global_size, 22 | is_main_process, 23 | synchronize, 24 | ) 25 | from src import distributed_kmeans_gpu as dkmg, kmeans_gpu as kmg 26 | 27 | 28 | logger = logging.getLogger("hkmeans") 29 | 30 | 31 | def split_clusters( 32 | data_path, 33 | subset_indices_path, 34 | clusters_path, 35 | n_splits, 36 | n_iters, 37 | dtype, 38 | high_precision, 39 | save_path, 40 | device="cuda", 41 | use_torchrun=False, 42 | checkpoint_period=10, 43 | verbose=False, 44 | ): 45 | enable_distributed( 46 | use_torchrun=use_torchrun, 47 | overwrite=True, 48 | ) 49 | X = np.load(data_path, mmap_mode="r") 50 | if subset_indices_path is not None: 51 | logger.info(f"Using subset with indices in {subset_indices_path}") 52 | subset_indices = np.load(subset_indices_path) 53 | X = dkmg.ExtendedNumpyMemMap(X, subset_indices) 54 | clusters = np.load(clusters_path, allow_pickle=True) 55 | n_clusters = len(clusters) 56 | 57 | part_indices = dkmg.get_part_indices(n_clusters, get_global_size()) 58 | rank = get_global_rank() 59 | 60 | # load checkpoints if exist 61 | if Path(save_path, f"split_checkpoint_{rank}.npy").exists(): 62 | ckpt = np.load( 63 | Path(save_path, f"split_checkpoint_{rank}.npy"), allow_pickle=True 64 | ).item() 65 | small_centroids = list(ckpt["small_centroids"]) 66 | small_clusters = list(ckpt["small_clusters"]) 67 | last_index = ckpt["last_index"] 68 | assert last_index - part_indices[rank] + 1 == len(small_centroids) 69 | else: 70 | small_centroids = [] 71 | small_clusters = [] 72 | last_index = part_indices[rank] - 1 73 | 74 | # run kmeans++ on clusters 75 | for cluster_idx in tqdm( 76 | range(last_index + 1, part_indices[rank + 1]), 77 | desc="Splitting pre-clusters", 78 | file=sys.stdout, 79 | bar_format="{l_bar}{bar}{r_bar}", 80 | ): 81 | if verbose: 82 | logger.info(f"Processing cluster {cluster_idx}") 83 | point_indices = np.sort(clusters[cluster_idx]) 84 | if len(point_indices) > 0: 85 | point_feats = torch.tensor(X[point_indices], device=device, dtype=dtype) 86 | _small_centroids, _small_clusters, _, _ = kmg.kmeans( 87 | point_feats, 88 | min(n_splits, len(point_indices)), 89 | n_iters, 90 | chunk_size=-1, 91 | init_method="kmeans++", 92 | dist="l2", 93 | high_precision=high_precision, 94 | ) 95 | 96 | _small_clusters = kmg.sort_cluster_by_distance( 97 | point_feats, 98 | _small_centroids, 99 | _small_clusters, 100 | device="cuda", 101 | dtype=dtype, 102 | ) 103 | _small_clusters = [point_indices[el.astype(int)] for el in _small_clusters] 104 | 105 | non_empty_clusters = [len(el) > 0 for el in _small_clusters] 106 | _small_clusters = [el for el in _small_clusters if len(el) > 0] 107 | _small_centroids = _small_centroids[non_empty_clusters] 108 | 109 | small_centroids.append(_small_centroids.cpu().numpy()) 110 | small_clusters += _small_clusters 111 | 112 | del point_feats 113 | if( 114 | cluster_idx % checkpoint_period == 0 or 115 | cluster_idx == part_indices[rank + 1] - 1 116 | ): 117 | np.save( 118 | Path(save_path, f"split_checkpoint_{rank}.npy"), 119 | { 120 | "small_centroids": small_centroids, 121 | "small_clusters": small_clusters, 122 | "last_index": cluster_idx, 123 | }, 124 | ) 125 | synchronize() 126 | logger.info("Gathering clusters") 127 | if is_main_process(): 128 | 
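        # Only rank 0 performs the merge: it re-reads every rank's
        # split_checkpoint_<rank>.npy, checks that the number of centroids and
        # clusters match, concatenates them in rank order, then writes the
        # combined centroids.npy and sorted_clusters.npy and removes the
        # per-rank checkpoints.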
centroids = [] 129 | clusters = [] 130 | for i in tqdm( 131 | range(get_global_size()), 132 | desc="Gathering splitted clusters", 133 | file=sys.stdout, 134 | bar_format="{l_bar}{bar}{r_bar}", 135 | ): 136 | split_data = np.load( 137 | Path(save_path, f"split_checkpoint_{i}.npy"), 138 | allow_pickle=True 139 | ).item() 140 | small_centroids = np.concatenate(split_data["small_centroids"]) 141 | small_clusters = split_data["small_clusters"] 142 | assert( 143 | len(small_centroids) == len(small_clusters) 144 | ), f"Inconsistent shape in split_checkpoint_{i}.npy" 145 | assert split_data["last_index"] == part_indices[i + 1] - 1 146 | centroids.append(small_centroids) 147 | clusters += small_clusters 148 | centroids = np.concatenate(centroids) 149 | clusters = np.array(clusters, dtype=object) 150 | 151 | logger.info("Saving centroids and clusters") 152 | np.save(Path(save_path, "centroids.npy"), centroids) 153 | np.save(Path(save_path, "sorted_clusters.npy"), clusters) 154 | logger.info("Cleaning checkpoints") 155 | for i in range(get_global_size()): 156 | Path(save_path, f"split_checkpoint_{i}.npy").unlink(missing_ok=True) 157 | logger.info("Finished split_clusters!") 158 | 159 | if __name__ == "__main__": 160 | parser = ArgumentParser() 161 | parser.add_argument("--data_path", type=str, required=True) 162 | parser.add_argument("--subset_indices_path", type=str, default=None) 163 | parser.add_argument("--clusters_path", type=str, required=True) 164 | parser.add_argument("--n_splits", type=int, required=True) 165 | parser.add_argument("--n_iters", type=int, required=True) 166 | parser.add_argument("--dtype", type=str, default="float32") 167 | parser.add_argument("--high_precision", type=str, default="float32") 168 | parser.add_argument("--save_path", type=str, required=True) 169 | parser.add_argument("--use_torchrun", action="store_true") 170 | 171 | args = parser.parse_args() 172 | setup_logging() 173 | 174 | def parse_dtype(dtype): 175 | if dtype == "float32": 176 | return torch.float32 177 | elif dtype == "float64": 178 | return torch.float64 179 | elif dtype == "float16": 180 | return torch.float16 181 | else: 182 | raise ValueError(f"Value of args.dtype ({args.dtype}) not regconised") 183 | 184 | args.dtype = parse_dtype(args.dtype) 185 | args.high_precision = parse_dtype(args.high_precision) 186 | 187 | split_clusters( 188 | args.data_path, 189 | args.subset_indices_path, 190 | args.clusters_path, 191 | args.n_splits, 192 | args.n_iters, 193 | args.dtype, 194 | args.high_precision, 195 | args.save_path, 196 | "cuda", 197 | args.use_torchrun, 198 | ) 199 | -------------------------------------------------------------------------------- /src/hierarchical_sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import sys 8 | import logging 9 | import random 10 | 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from src.clusters import HierarchicalCluster 15 | 16 | 17 | logger = logging.getLogger("hkmeans") 18 | 19 | def random_selection(clusters, valid_clusters, num_per_cluster): 20 | """ 21 | Parameters: 22 | clusters: (num_cluster, ) np.array 23 | clusters[i] contain indices of points in cluster i 24 | valid_clusters: list or np.array 25 | indices of clusters that are considered 26 | num_per_cluster: int 27 | number of points selected from each cluster 28 | 29 | Returns: 30 | array containing indices of selected points 31 | """ 32 | num_clusters = len(clusters) 33 | selected = [[]] * num_clusters 34 | for cluster_id in tqdm( 35 | valid_clusters, 36 | desc="Random sampling from clusters", 37 | file=sys.stdout, 38 | bar_format="{l_bar}{bar}{r_bar}", 39 | ): 40 | selected[cluster_id] = random.sample( 41 | list(clusters[cluster_id]), min(num_per_cluster, len(clusters[cluster_id])) 42 | ) 43 | return np.concatenate(selected).astype(np.int64) 44 | 45 | 46 | def closest_to_centroid_selection(sorted_clusters, valid_clusters, num_per_cluster): 47 | """ 48 | Parameters: 49 | sorted_clusters: (num_cluster, ) np.array 50 | clusters[i] contain indices of points in cluster i 51 | indices in clusters[i] are sorted in increasing distance from the centroid i 52 | valid_clusters: list or np.array 53 | indices of clusters that are considered 54 | num_per_cluster: int, number of points selected from each cluster 55 | 56 | Returns: 57 | array containing indices of selected points 58 | """ 59 | num_clusters = len(sorted_clusters) 60 | selected = [[]] * num_clusters 61 | for cluster_id in tqdm( 62 | valid_clusters, 63 | desc="Closest-to-centroid sampling from clusters", 64 | file=sys.stdout, 65 | bar_format="{l_bar}{bar}{r_bar}", 66 | ): 67 | selected[cluster_id] = sorted_clusters[cluster_id][:num_per_cluster] 68 | return np.concatenate(selected).astype(np.int64) 69 | 70 | 71 | def _find_best_cut_left(arr, target): 72 | """ 73 | Find integers x such that sum(min(x, arr)) best approximates target 74 | """ 75 | if target < 0: 76 | raise ValueError(f"target {target} must be non-negative!") 77 | if np.min(arr) < 0: 78 | raise ValueError("arr has negative elements!") 79 | if np.sum(arr) <= target: 80 | return np.max(arr) 81 | left = 0 82 | right = np.max(arr) 83 | while right - left > 1: 84 | mid = (left + right) // 2 85 | sum_with_mid = np.sum(np.minimum(mid, arr)) 86 | if sum_with_mid > target: 87 | right = mid 88 | elif sum_with_mid < target: 89 | left = mid 90 | else: 91 | return mid 92 | if np.sum(np.minimum(right, arr)) <= target: 93 | return right 94 | return left 95 | 96 | 97 | def find_subcluster_target_size( 98 | subcluster_sizes, 99 | target_size, 100 | multiplier, 101 | ): 102 | """ 103 | Given the target number of points to sample from a clusters, 104 | find number of points to sample from its subclusters. 
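    A small worked example (numbers chosen purely for illustration):

        >>> find_subcluster_target_size(np.array([5, 2, 8]), target_size=10, multiplier=1)
        array([4, 2, 4])

    Here the best cut is 4, since np.minimum(4, sizes).sum() == 10 == target_size,
    so each subcluster contributes at most 4 points.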
105 | """ 106 | if isinstance(subcluster_sizes, np.ndarray): 107 | arr = subcluster_sizes * multiplier 108 | else: 109 | arr = np.array(subcluster_sizes) * multiplier 110 | best_cut_left = _find_best_cut_left(arr, target_size) 111 | if best_cut_left == np.max(arr): 112 | return arr 113 | else: 114 | subcluster_target_sizes = np.minimum(best_cut_left, arr) 115 | remainder = target_size - subcluster_target_sizes.sum() 116 | candidates = np.where(arr > best_cut_left)[0] 117 | subcluster_target_sizes[np.random.choice(candidates, remainder, replace=False)] = best_cut_left + 1 118 | assert subcluster_target_sizes.sum() == target_size 119 | assert np.all(subcluster_target_sizes <= arr) 120 | return subcluster_target_sizes 121 | 122 | 123 | def recursive_hierarchical_sampling( 124 | clusters: HierarchicalCluster, 125 | level: int, 126 | target_size: int, 127 | cl_index: int, 128 | multiplier: int, 129 | sampling_strategy: str = "r", 130 | ): 131 | """ 132 | Given a target number of points to sample from a cluster, return 133 | the a set of sampled points. 134 | """ 135 | if level == 1: 136 | current_cluster = clusters.clusters[1][cl_index] 137 | current_cluster_size = clusters.clusters_size[1][cl_index] 138 | if current_cluster_size * multiplier <= target_size: 139 | return np.tile(current_cluster, multiplier) 140 | else: 141 | n_replicates = target_size // current_cluster_size 142 | replicates = np.tile(current_cluster, n_replicates) 143 | remaining_target = target_size - n_replicates * current_cluster_size 144 | if sampling_strategy == "r": # random 145 | remaining_samples = np.random.choice( 146 | current_cluster, 147 | remaining_target, 148 | replace=False, 149 | ) 150 | elif sampling_strategy == "c": # "closest" 151 | remaining_samples = current_cluster[:remaining_target] 152 | else: 153 | raise ValueError(f"sampling_strategy={sampling_strategy} is not supported") 154 | return np.concatenate([replicates, remaining_samples]) 155 | else: 156 | subcl_indices = clusters.clusters[level][cl_index] 157 | subcluster_sizes = clusters.flat_clusters_size[level - 1][subcl_indices] 158 | subcluster_target_sizes = find_subcluster_target_size( 159 | subcluster_sizes, 160 | target_size, 161 | multiplier, 162 | ) 163 | samples = [] 164 | for i, subcl_index in enumerate(subcl_indices): 165 | samples.append( 166 | recursive_hierarchical_sampling( 167 | clusters, 168 | level - 1, 169 | subcluster_target_sizes[i], 170 | subcl_index, 171 | multiplier, 172 | sampling_strategy, 173 | ) 174 | ) 175 | return np.concatenate(samples) 176 | 177 | 178 | def hierarchical_sampling( 179 | clusters: HierarchicalCluster, 180 | target_size: int, 181 | multiplier: int = 1, 182 | sampling_strategy: str = "r", 183 | ): 184 | """ 185 | Method for sample hierarchically from a hierarchy of clusters. 
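    Minimal usage sketch (mirroring the toy example in the repository README,
    where cl is a processed HierarchicalCluster, e.g. built with
    HierarchicalCluster.from_dict or HierarchicalCluster.from_file):

        sampled_indices = hierarchical_sampling(cl, target_size=1000)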
186 | """ 187 | if (not clusters.is_loaded) or (not clusters.is_processed): 188 | raise RuntimeError("HierarchicalCluster is not loaded or processed.") 189 | n_levels = clusters.n_levels 190 | cluster_target_sizes = find_subcluster_target_size( 191 | clusters.flat_clusters_size[n_levels], 192 | target_size, 193 | multiplier, 194 | ) 195 | samples = [] 196 | for cl_index in tqdm( 197 | range(len(clusters.clusters[n_levels])), 198 | desc="Hierarchical sampling from clusters", 199 | file=sys.stdout, 200 | bar_format="{l_bar}{bar}{r_bar}", 201 | ): 202 | samples.append( 203 | recursive_hierarchical_sampling( 204 | clusters, 205 | n_levels, 206 | cluster_target_sizes[cl_index], 207 | cl_index, 208 | multiplier, 209 | sampling_strategy, 210 | ) 211 | ) 212 | samples = np.concatenate(samples) 213 | return samples 214 | -------------------------------------------------------------------------------- /src/hierarchical_kmeans_gpu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | import logging 9 | from tqdm import tqdm 10 | 11 | import torch 12 | import numpy as np 13 | 14 | from . import kmeans_gpu as kmg 15 | 16 | 17 | logger = logging.getLogger("hkmeans") 18 | MEMORY_LIMIT = 1e8 19 | 20 | 21 | def hierarchical_kmeans( 22 | data, 23 | n_clusters, 24 | n_levels, 25 | init_method="kmeans++", 26 | num_init=1, 27 | verbose=True 28 | ): 29 | """ 30 | Run hierarchical k-means on data without resampling steps. 31 | 32 | Parameters: 33 | data: 2-D numpy array 34 | Data embeddings. 35 | n_clusters: List[int] 36 | Number of clusters for each level of hierarchical k-means 37 | n_levels: int 38 | Number of levels in hierarchical k-means. 39 | init_method: str, default = "k-means++" 40 | Initialization method for k-means centroids. 41 | Options are "k-means" and "random". 42 | num_init: int, default=1 43 | Number of re-initialization for each k-means run. 44 | 45 | Returns: 46 | List[dict], clustering results for each level of hierarchical k-means, 47 | including 48 | centroids: 2-D numpy array 49 | Centroids of clusters. 50 | assigment: 1-D numpy array 51 | Mapping from data points to cluster indices. 52 | clusters: array of array 53 | pot: float 54 | K-means potential. 
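    Usage sketch (an illustrative addition; data is passed as a torch tensor
    on GPU, following the toy example in the repository README):

        import torch
        data = torch.randn(10_000, 64, device="cuda")
        res = hierarchical_kmeans(data, n_clusters=[500, 50], n_levels=2)
        top_level_centroids = res[-1]["centroids"]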
55 | """ 56 | assert len(n_clusters) == n_levels 57 | logger.info(f"{n_levels}-level hierarchical kmeans") 58 | res = [] 59 | for kmid in range(n_levels): 60 | logger.info(f"Level {kmid+1}") 61 | if kmid == 0: 62 | X = data 63 | else: 64 | X = res[kmid - 1]["centroids"] 65 | chunk_size = min(X.shape[0], int(MEMORY_LIMIT / n_clusters[kmid])) 66 | centroids, clusters, cluster_assignment, pot = kmg.kmeans( 67 | X, 68 | n_clusters=n_clusters[kmid], 69 | n_iters=50, 70 | chunk_size=chunk_size, 71 | num_init=num_init, 72 | init_method=init_method, 73 | dist="l2", 74 | high_precision=torch.float64, 75 | random_state=None, 76 | verbose=verbose 77 | ) 78 | res.append( 79 | { 80 | "centroids": centroids, 81 | "assignment": cluster_assignment, 82 | "clusters": clusters, 83 | "pot": pot, 84 | } 85 | ) 86 | return res 87 | 88 | 89 | def hierarchical_kmeans_with_resampling( 90 | data, 91 | n_clusters, 92 | n_levels, 93 | sample_sizes, 94 | n_resamples=10, 95 | init_method="kmeans++", 96 | num_init=1, 97 | sample_strategy="closest", 98 | verbose=True, 99 | ): 100 | """ 101 | Run hierarchical k-means on data without resampling steps. 102 | 103 | Parameters: 104 | data: 2-D numpy array 105 | Data embeddings. 106 | n_clusters: List[int] 107 | Number of clusters for each level of hierarchical k-means 108 | n_levels: int 109 | Number of levels in hierarchical k-means. 110 | sample_size: List[int] 111 | Number of points to sample from each cluster in resampling steps. 112 | n_resamples: int 113 | Number of resampling steps in each level. 114 | init_method: str, default = "k-means++" 115 | Initialization method for k-means centroids. 116 | Options are "k-means" and "random". 117 | num_init: int, default=1 118 | Number of re-initialization for each k-means run. 119 | sampling_strategy: str, default = "closest" 120 | How to sample points from clusters in resampling steps. 121 | Options are "closest" and "random". 122 | 123 | Returns: 124 | List[dict], clustering results for each level of hierarchical k-means, 125 | including 126 | centroids: 2-D numpy array 127 | Centroids of clusters. 128 | assigment: 1-D numpy array 129 | Mapping from data points to cluster indices. 130 | clusters: array of array 131 | pot: float 132 | K-means potential. 
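    Usage sketch (taken from the toy example in the repository README):

        clusters = hierarchical_kmeans_with_resampling(
            data=torch.tensor(data, device="cuda", dtype=torch.float32),
            n_clusters=[1000, 300],
            n_levels=2,
            sample_sizes=[15, 2],
            verbose=False,
        )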
133 | """ 134 | assert len(n_clusters) == n_levels 135 | assert len(sample_sizes) == n_levels 136 | logger.info(f"{n_levels}-level hierarchical kmeans") 137 | res = [] 138 | for kmid in range(n_levels): 139 | logger.info(f"Level {kmid+1}") 140 | logger.info("Initial kmeans") 141 | if kmid == 0: 142 | X = data 143 | else: 144 | X = res[kmid - 1]["centroids"] 145 | chunk_size = min(X.shape[0], int(MEMORY_LIMIT / n_clusters[kmid])) 146 | logger.info("Running the initial k-means") 147 | centroids, clusters, cluster_assignment, _ = kmg.kmeans( 148 | X, 149 | n_clusters=n_clusters[kmid], 150 | n_iters=50, 151 | chunk_size=chunk_size, 152 | num_init=num_init, 153 | init_method=init_method, 154 | dist="l2", 155 | high_precision=torch.float64, 156 | random_state=None, 157 | verbose=verbose, 158 | ) 159 | logger.info("Resampling-kmeans") 160 | if sample_sizes[kmid] > 1: 161 | _sample_size = sample_sizes[kmid] 162 | for _ in tqdm( 163 | range(n_resamples), 164 | desc="Hierarchical k-means resampling steps", 165 | file=sys.stdout, 166 | bar_format="{l_bar}{bar}{r_bar}", 167 | ): 168 | if sample_strategy == "closest": 169 | sorted_clusters = [ 170 | _cluster[ 171 | torch.argsort( 172 | torch.cdist(X[_cluster], centroids[i, None]) 173 | .flatten() 174 | ) 175 | .cpu() 176 | .numpy() 177 | ] 178 | for i, _cluster in enumerate(clusters) 179 | ] 180 | sampled_points = torch.concat( 181 | [ 182 | X[_cluster[: _sample_size]] 183 | for _cluster in sorted_clusters 184 | ] 185 | ) 186 | elif sample_strategy == "random": 187 | sampled_points = torch.concat( 188 | [ 189 | X[ 190 | np.random.choice( 191 | _cluster, 192 | min(len(_cluster), _sample_size), 193 | replace=False 194 | ) 195 | ] 196 | for _cluster in clusters 197 | ] 198 | ) 199 | else: 200 | raise ValueError( 201 | f"sample_strategy={sample_strategy} not supported!" 202 | ) 203 | chunk_size = min( 204 | sampled_points.shape[0], 205 | int(MEMORY_LIMIT / n_clusters[kmid]) 206 | ) 207 | centroids, _, _, _ = kmg.kmeans( 208 | sampled_points, 209 | n_clusters=n_clusters[kmid], 210 | n_iters=50, 211 | chunk_size=chunk_size, 212 | num_init=num_init, 213 | init_method=init_method, 214 | dist="l2", 215 | high_precision=torch.float64, 216 | random_state=None, 217 | verbose=False 218 | ) 219 | cluster_assignment = kmg.assign_clusters( 220 | centroids, 221 | X, 222 | "l2", 223 | chunk_size=chunk_size, 224 | verbose=False 225 | ).cpu().numpy() 226 | clusters = kmg.create_clusters_from_cluster_assignment( 227 | cluster_assignment, 228 | n_clusters[kmid] 229 | ) 230 | res.append( 231 | { 232 | "centroids": centroids, 233 | "assignment": cluster_assignment, 234 | "clusters": clusters, 235 | "pot": -1, 236 | } 237 | ) 238 | return res 239 | -------------------------------------------------------------------------------- /vis/generalized_kmeans_1d.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import sys 8 | 9 | import numpy as np 10 | import torch 11 | from tqdm import tqdm 12 | from sklearn.utils import check_random_state 13 | 14 | from src import kmeans_gpu as kmg 15 | from src.utils import create_clusters_from_cluster_assignment 16 | 17 | 18 | def l2_squared_power(x, xi, n): 19 | """ 20 | Compute L_2 ^ (2 * n) distance 21 | """ 22 | return (x - xi) ** (2 * n) 23 | 24 | 25 | def l2_squared_power_der(x, xi, n): 26 | """ 27 | Compute the derivative of L_2 ^ (2 * n) distance 28 | """ 29 | return 2 * n * (x - xi) ** (2 * n - 1) 30 | 31 | 32 | def l2_squared_power_der2(x, xi, n): 33 | """ 34 | Compute second-order derivative of L_2 ^ (2 * n) distance 35 | """ 36 | return 2 * n * (2 * n - 1) * (x - xi) ** (2 * n - 2) 37 | 38 | 39 | def kmeans_plusplus( 40 | X, 41 | n_clusters, 42 | x_squared_norms, 43 | dist, 44 | power=1, 45 | random_state=None, 46 | n_local_trials=None, 47 | save_running_results=False, 48 | high_precision=torch.float32, 49 | verbose=False, 50 | ): 51 | """ 52 | Computational component for initialization of n_clusters by 53 | k-means++. Prior validation of data is assumed. 54 | Parameters 55 | ---------- 56 | X : torch.tensor of shape (n_samples, n_features) 57 | The data to pick seeds for. 58 | n_clusters : int 59 | The number of seeds to choose. 60 | x_squared_norms : torch.tensor (n_samples,) 61 | Squared Euclidean norm of each data point. 62 | dist: str 63 | Type of distance function. Options are "l2" or "cos". 64 | power: int 65 | Distance is L_2 ^ (2 * power). 66 | random_state : RandomState instance 67 | The generator used to initialize the centers. 68 | See :term:`Glossary `. 69 | n_local_trials : int, default=None 70 | The number of seeding trials for each center (except the first), 71 | of which the one reducing inertia the most is greedily chosen. 72 | Set to None to make the number of trials depend logarithmically 73 | on the number of seeds (2+log(k)); this is the default. 74 | save_running_results: bool, default=False 75 | Whether to save temporary results during execution. 76 | high_precision: torch.Type 77 | type for high-precision computations. 78 | verbose: bool, default=False 79 | 80 | Returns 81 | ------- 82 | centers : torch.tensor of shape (n_clusters, n_features) 83 | The initial centers for k-means. 84 | indices : ndarray of shape (n_clusters,) 85 | The index location of the chosen centers in the data array X. For a 86 | given index and center, X[index] = center. 
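    A usage sketch matching how generalized_kmeans_1d below calls this function:

        x_squared_norms = torch.linalg.vector_norm(X, dim=1) ** 2
        centers, indices = kmeans_plusplus(X, n_clusters, x_squared_norms, "l2", power=n)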
87 | """ 88 | if random_state is None: 89 | random_state = check_random_state(random_state) 90 | 91 | n_samples, n_features = X.shape 92 | 93 | centers = torch.empty((n_clusters, n_features), dtype=X.dtype).to(X.device) 94 | pots = torch.empty((n_clusters,), device=X.device, dtype=high_precision) 95 | 96 | # Set the number of local seeding trials if none is given 97 | if n_local_trials is None: 98 | n_local_trials = 2 + int(np.log(n_clusters)) 99 | 100 | # Pick first center randomly and track index of point 101 | center_id = random_state.randint(n_samples) 102 | indices = np.full(n_clusters, -1, dtype=int) 103 | centers[0] = X[center_id] 104 | indices[0] = center_id 105 | 106 | # Initialize list of closest distances and calculate current potential 107 | closest_dist_sq = ( 108 | kmg.compute_distance(X[center_id, None], X, x_squared_norms, dist)[0].type( 109 | high_precision 110 | ) 111 | ** power 112 | ) 113 | current_pot = closest_dist_sq.sum() 114 | pots[0] = current_pot 115 | 116 | # Pick the remaining n_clusters-1 points 117 | if verbose: 118 | iterates = tqdm( 119 | range(1, n_clusters), 120 | desc="Genralized kmeans++ initialization", 121 | file=sys.stdout, 122 | bar_format="{l_bar}{bar}{r_bar}", 123 | ) 124 | else: 125 | iterates = range(1, n_clusters) 126 | for c in iterates: 127 | # Choose center candidates by sampling with probability proportional 128 | # to the distance to the closest existing center 129 | rand_vals = ( 130 | torch.tensor(random_state.uniform(size=n_local_trials)).to( 131 | current_pot.device 132 | ) 133 | * current_pot 134 | ) 135 | candidate_ids = torch.searchsorted( 136 | torch.cumsum(closest_dist_sq, dim=0), rand_vals 137 | ) 138 | # numerical imprecision can result in a candidate_id out of range 139 | torch.clip(candidate_ids, None, closest_dist_sq.shape[0] - 1, out=candidate_ids) 140 | 141 | # Compute distances to center candidates 142 | distance_to_candidates = ( 143 | kmg.compute_distance(X[candidate_ids], X, x_squared_norms, dist).type( 144 | high_precision 145 | ) 146 | ** power 147 | ) 148 | 149 | # update closest distances squared and potential for each candidate 150 | torch.minimum( 151 | closest_dist_sq, distance_to_candidates, out=distance_to_candidates 152 | ) 153 | candidates_pot = distance_to_candidates.sum(dim=1) 154 | 155 | # Decide which candidate is the best 156 | best_candidate = torch.argmin(candidates_pot) 157 | current_pot = candidates_pot[best_candidate] 158 | closest_dist_sq = distance_to_candidates[best_candidate] 159 | best_candidate = candidate_ids[best_candidate] 160 | 161 | # Permanently add best center candidate found in local tries 162 | centers[c] = X[best_candidate] 163 | indices[c] = best_candidate 164 | pots[c] = current_pot 165 | 166 | if save_running_results and c % 1000 == 0: 167 | np.save( 168 | "kmpp_running_results.npy", 169 | {"centers": centers.cpu().numpy(), "indices": indices, "iter": c}, 170 | ) 171 | 172 | return centers, indices 173 | 174 | 175 | def compute_centroids(X, n, n_iters=5, method="newton", verbose=False): 176 | """ 177 | Compute k-means centroids given a set of points, according to distortion 178 | function L_2 ^ (2 * n), with Newton method. 179 | """ 180 | if method == "newton": 181 | # Initialize the centroid with L_2^2 means. 
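        # Newton's method is then run on the scalar objective
        #   f(c) = sum_i (c - x_i)^(2n),
        # whose derivatives are given by l2_squared_power_der and
        # l2_squared_power_der2 above; each iteration below applies
        #   c <- c - f'(c) / f''(c).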
182 | c = X.mean() 183 | if len(X) == 1: 184 | return c 185 | for _ in range(n_iters): 186 | if verbose: 187 | f = torch.sum(l2_squared_power(c, X, n)) 188 | print(f, end=", ") 189 | der_f = torch.sum(l2_squared_power_der(c, X, n)) 190 | der2_f = torch.sum(l2_squared_power_der2(c, X, n)) 191 | if der_f == 0: 192 | break 193 | c -= der_f / der2_f 194 | return c 195 | else: 196 | raise ValueError("Method not supported!") 197 | 198 | 199 | def assign_clusters(X, centers, chunk_size=-1): 200 | """ 201 | Assign points to centroids. 202 | """ 203 | cluster_assignment = ( 204 | kmg.assign_clusters(centers, X, "l2", chunk_size=chunk_size, verbose=False) 205 | .cpu() 206 | .numpy() 207 | ) 208 | clusters = create_clusters_from_cluster_assignment(cluster_assignment, len(centers)) 209 | return clusters 210 | 211 | 212 | def update_centroids(X, clusters, n): 213 | """ 214 | Update centroids based on the new clusters after reassignment. 215 | """ 216 | n_clusters = len(clusters) 217 | centers = torch.zeros((n_clusters, 1), device=X.device, dtype=X.dtype) 218 | for cid in range(n_clusters): 219 | if len(clusters[cid]) > 0: 220 | centers[cid, 0] = compute_centroids(X[clusters[cid]], n).item() 221 | return centers 222 | 223 | 224 | def generalized_kmeans_1d( 225 | X, n_clusters, n, n_iters=50, init_method="k-means++", chunk_size=-1 226 | ): 227 | """ 228 | Run generalized k-means with distance L_2 ^ (2 * n) 229 | """ 230 | assert X.ndim == 2 231 | # initialize 232 | if init_method == "k-means++": 233 | x_squared_norms = torch.linalg.vector_norm(X, dim=1) ** 2 234 | centers, _ = kmeans_plusplus(X, n_clusters, x_squared_norms, "l2", n) 235 | else: 236 | centers = X[np.random.choice(len(X), n_clusters, replace=False), :] 237 | clusters = assign_clusters(X, centers, chunk_size=chunk_size) 238 | for _ in tqdm( 239 | range(n_iters), 240 | desc="Generalized kmeans iterations", 241 | file=sys.stdout, 242 | bar_format="{l_bar}{bar}{r_bar}", 243 | ): 244 | centers = update_centroids(X, clusters, n) 245 | clusters = assign_clusters(X, centers, chunk_size=chunk_size) 246 | return centers, clusters 247 | -------------------------------------------------------------------------------- /scripts/hierarchical_kmeans_launcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from argparse import ArgumentParser 8 | from pathlib import Path 9 | 10 | from omegaconf import OmegaConf 11 | 12 | 13 | ROOT = Path().resolve() 14 | MEMORY_LIMIT = 2e8 15 | 16 | def write_main_script(cfg, level_dir, level_id): 17 | """ 18 | Write slurm script for a level of k-means. 
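    Two launch scripts are written into the level directory: "slurm_script.s",
    an sbatch job whose srun command runs run_distributed_kmeans.py, and
    "local_script.sh", a torchrun command for launching without Slurm
    (used when running locally).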
19 | """ 20 | save_dir = level_dir 21 | if cfg.n_splits[level_id - 1] > 1: 22 | save_dir = Path(level_dir, "pre_clusters") 23 | Path(save_dir, "logs").mkdir(parents=True) 24 | 25 | chunk_size = int(MEMORY_LIMIT / cfg.n_clusters[level_id - 1]) 26 | if level_id == 1: 27 | data_path = Path(cfg.embeddings_path).resolve() 28 | else: 29 | data_path = Path(cfg.exp_dir, f"level{level_id-1}/centroids.npy").resolve() 30 | 31 | with open(Path(level_dir, "slurm_script.s"), "w") as f: 32 | f.write( 33 | f"""#!/usr/bin/env bash 34 | 35 | #SBATCH --requeue 36 | #SBATCH --nodes={cfg.nnodes[level_id-1]} 37 | #SBATCH --gpus-per-node={cfg.ngpus_per_node[level_id-1]} 38 | #SBATCH --ntasks-per-node={cfg.ngpus_per_node[level_id-1]} 39 | #SBATCH --job-name=kmeans_level{level_id} 40 | #SBATCH --output={save_dir}/logs/%j_0_log.out 41 | #SBATCH --error={save_dir}/logs/%j_0_log.err 42 | #SBATCH --time=4320 43 | #SBATCH --signal=USR2@300 44 | #SBATCH --open-mode=append\n""" 45 | ) 46 | if cfg.ncpus_per_gpu is not None: 47 | f.write(f"#SBATCH --cpus-per-task={cfg.ncpus_per_gpu}\n") 48 | if cfg.slurm_partition is not None: 49 | f.write(f"#SBATCH --partition={cfg.slurm_partition}\n") 50 | 51 | f.write(f""" 52 | EXPDIR={save_dir} 53 | cd {ROOT} 54 | 55 | PYTHONPATH=.. \\ 56 | srun --unbuffered --output="$EXPDIR"/logs/%j_%t_log.out --error="$EXPDIR"/logs/%j_%t_log.err \\ 57 | python -u run_distributed_kmeans.py \\ 58 | --data_path {data_path} \\ 59 | --n_clusters {cfg.n_clusters[level_id-1]} \\ 60 | --n_iters {cfg.n_iters} \\ 61 | --chunk_size {chunk_size} \\ 62 | --dtype {cfg.dtype} \\ 63 | --high_precision {cfg.high_precision} \\ 64 | --checkpoint_period {cfg.checkpoint_period} \\ 65 | --exp_dir $EXPDIR \\ 66 | --n_steps {cfg.n_resampling_steps[level_id-1]} \\ 67 | --sample_size {cfg.sample_size[level_id-1]} \\ 68 | --sampling_strategy {cfg.sampling_strategy}""" 69 | ) 70 | if level_id == 1 and cfg.subset_indices_path is not None: 71 | f.write(f" \\\n --subset_indices_path {cfg.subset_indices_path}\n") 72 | else: 73 | f.write("\n") 74 | 75 | with open(Path(level_dir, "local_script.sh"), "w") as f: 76 | f.write( 77 | f"""#!/usr/bin/env bash 78 | EXPDIR={save_dir} 79 | cd {ROOT} 80 | 81 | PYTHONPATH=.. \\ 82 | torchrun \\ 83 | --nnodes={cfg.nnodes[level_id-1]} \\ 84 | --nproc_per_node={cfg.ngpus_per_node[level_id-1]} \\ 85 | run_distributed_kmeans.py \\ 86 | --use_torchrun \\ 87 | --data_path {data_path} \\ 88 | --n_clusters {cfg.n_clusters[level_id-1]} \\ 89 | --n_iters {cfg.n_iters} \\ 90 | --chunk_size {chunk_size} \\ 91 | --dtype {cfg.dtype} \\ 92 | --high_precision {cfg.high_precision} \\ 93 | --checkpoint_period {cfg.checkpoint_period} \\ 94 | --exp_dir $EXPDIR \\ 95 | --n_steps {cfg.n_resampling_steps[level_id-1]} \\ 96 | --sample_size {cfg.sample_size[level_id-1]} \\ 97 | --sampling_strategy {cfg.sampling_strategy}""" 98 | ) 99 | if level_id == 1 and cfg.subset_indices_path is not None: 100 | f.write(f" \\\n --subset_indices_path {cfg.subset_indices_path}\n") 101 | else: 102 | f.write("\n") 103 | 104 | 105 | def write_split_clusters_script(cfg, level_dir, level_id): 106 | """ 107 | Write slurm script to split pre-clusters into smaller ones if necessary. 
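    The generated script runs split_clusters.py on the pre-clusters produced
    by the main k-means job of the same level
    ("pre_clusters/sorted_clusters.npy"), splitting each of them into n_splits
    smaller clusters that become the final clusters of the level.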
108 | """ 109 | if level_id == 1: 110 | data_path = Path(cfg.embeddings_path).resolve() 111 | else: 112 | data_path = Path(cfg.exp_dir, f"level{level_id-1}/centroids.npy").resolve() 113 | 114 | with open(Path(level_dir, "slurm_split_clusters_script.s"), "w") as f: 115 | f.write( 116 | f"""#!/usr/bin/env bash 117 | 118 | #SBATCH --requeue 119 | #SBATCH --nodes={cfg.nnodes[level_id-1]} 120 | #SBATCH --gpus-per-node={cfg.ngpus_per_node[level_id-1]} 121 | #SBATCH --ntasks-per-node={cfg.ngpus_per_node[level_id-1]} 122 | #SBATCH --job-name=split_kmeans_level{level_id} 123 | #SBATCH --output={level_dir}/logs/%j_0_log.out 124 | #SBATCH --error={level_dir}/logs/%j_0_log.err 125 | #SBATCH --time=4320 126 | #SBATCH --signal=USR2@300 127 | #SBATCH --open-mode=append\n""" 128 | ) 129 | if cfg.ncpus_per_gpu is not None: 130 | f.write(f"#SBATCH --cpus-per-task={cfg.ncpus_per_gpu}\n") 131 | if cfg.slurm_partition is not None: 132 | f.write(f"#SBATCH --partition={cfg.slurm_partition}\n") 133 | 134 | f.write(f""" 135 | EXPDIR={level_dir} 136 | cd {ROOT} 137 | 138 | PYTHONPATH=.. \\ 139 | srun --unbuffered --output="$EXPDIR"/logs/%j_%t_log.out --error="$EXPDIR"/logs/%j_%t_log.err \\ 140 | python -u split_clusters.py \\ 141 | --data_path {data_path} \\ 142 | --clusters_path "$EXPDIR"/pre_clusters/sorted_clusters.npy \\ 143 | --n_splits {cfg.n_splits[level_id-1]} \\ 144 | --n_iters {cfg.n_iters} \\ 145 | --dtype float32 \\ 146 | --high_precision float32 \\ 147 | --save_path $EXPDIR""" 148 | ) 149 | if level_id == 1 and cfg.subset_indices_path is not None: 150 | f.write(f" \\\n --subset_indices_path {cfg.subset_indices_path}\n") 151 | else: 152 | f.write("\n") 153 | 154 | with open(Path(level_dir, "local_split_clusters_script.sh"), "w") as f: 155 | f.write( 156 | f"""#!/usr/bin/env bash 157 | 158 | EXPDIR={level_dir} 159 | cd {ROOT} 160 | 161 | PYTHONPATH=.. \\ 162 | torchrun \\ 163 | --nnodes={cfg.nnodes[level_id-1]} \\ 164 | --nproc_per_node={cfg.ngpus_per_node[level_id-1]} \\ 165 | split_clusters.py \\ 166 | --data_path {data_path} \\ 167 | --clusters_path "$EXPDIR"/pre_clusters/sorted_clusters.npy \\ 168 | --n_splits {cfg.n_splits[level_id-1]} \\ 169 | --n_iters {cfg.n_iters} \\ 170 | --dtype float32 \\ 171 | --high_precision float32 \\ 172 | --save_path $EXPDIR""" 173 | ) 174 | if level_id == 1 and cfg.subset_indices_path is not None: 175 | f.write(f" \\\n --subset_indices_path {cfg.subset_indices_path}\n") 176 | else: 177 | f.write("\n") 178 | 179 | 180 | def write_slurm_scripts(cfg): 181 | """ 182 | Write slurm scripts for all levels. 183 | """ 184 | for level_id in range(1, cfg.n_levels + 1): 185 | if cfg.n_splits[level_id - 1] > 1 and cfg.n_resampling_steps[level_id - 1] > 1: 186 | raise ValueError("Cannot use cluster_split and resampling simultaneously") 187 | level_dir = Path(cfg.exp_dir, f"level{level_id}").resolve() 188 | level_dir.mkdir() 189 | Path(level_dir, "logs").mkdir() 190 | 191 | write_main_script(cfg, level_dir, level_id) 192 | if cfg.n_splits[level_id - 1] > 1: 193 | write_split_clusters_script(cfg, level_dir, level_id) 194 | 195 | 196 | def write_launcher(exp_dir, n_levels, n_splits): 197 | """ 198 | Write bash script to launch slurm scripts in all levels. 
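    Jobs are submitted with "sbatch --parsable"; every job after the first is
    chained to the previous one via "--dependency=afterok", so a level (and
    its optional cluster-split job) only starts once the preceding job has
    completed successfully.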
199 |     """
200 |     exp_dir = Path(exp_dir).resolve()
201 |     with open(Path(exp_dir, "launcher.sh"), "w") as f:
202 |         f.write(
203 |             f"ID=$(sbatch --parsable {str(exp_dir)}/level1/slurm_script.s | tail -1)\n"
204 |         )
205 |         f.write('echo "Level 1: job $ID"\n')
206 |         if n_splits[0] > 1:
207 |             f.write(
208 |                 f'ID=$(sbatch --parsable --dependency=afterok:"$ID" {str(exp_dir)}/level1/slurm_split_clusters_script.s | tail -1)\n'
209 |             )
210 |             f.write('echo "Level 1, split clusters: job $ID"\n')
211 | 
212 |         for level_id in range(2, n_levels + 1):
213 |             f.write(
214 |                 f'ID=$(sbatch --parsable --dependency=afterok:"$ID" {str(exp_dir)}/level{level_id}/slurm_script.s | tail -1)\n'
215 |             )
216 |             f.write(f'echo "Level {level_id}: job $ID"\n')
217 |             if n_splits[level_id - 1] > 1:
218 |                 f.write(
219 |                     f'ID=$(sbatch --parsable --dependency=afterok:"$ID" {str(exp_dir)}/level{level_id}/slurm_split_clusters_script.s | tail -1)\n'
220 |                 )
221 |                 f.write(f'echo "Level {level_id}, split clusters: job $ID"\n')
222 | 
223 | def write_local_launcher(exp_dir, n_levels, n_splits):
224 |     """
225 |     Write bash script to launch the local (torchrun) scripts of all levels.
226 |     """
227 |     exp_dir = Path(exp_dir).resolve()
228 |     with open(Path(exp_dir, "local_launcher.sh"), "w") as f:
229 |         f.write("set -e\n")
230 |         for level_id in range(1, n_levels + 1):
231 |             f.write(f"bash {str(exp_dir)}/level{level_id}/local_script.sh\n")
232 |             if n_splits[level_id - 1] > 1:
233 |                 f.write(f"bash {str(exp_dir)}/level{level_id}/local_split_clusters_script.sh\n")
234 | 
235 | 
236 | if __name__ == "__main__":
237 |     parser = ArgumentParser()
238 |     parser.add_argument("--exp_dir", type=str, required=True)
239 |     parser.add_argument("--embeddings_path", type=str, required=True)
240 |     parser.add_argument("--config_file", type=str, help="Path to config file")
241 | 
242 |     args, opts = parser.parse_known_args()
243 |     print(f"opts: {opts}")
244 |     config_file = args.config_file
245 |     del args.config_file
246 |     if config_file:
247 |         cfg = OmegaConf.load(config_file)
248 |         cfg = OmegaConf.merge(
249 |             cfg,
250 |             OmegaConf.create(vars(args)),
251 |             OmegaConf.from_cli(opts),
252 |         )
253 |     else:
254 |         cfg = OmegaConf.create(vars(args))
255 |     print("Hierarchical k-means config:")
256 |     print(OmegaConf.to_yaml(cfg))
257 | 
258 |     Path(cfg.exp_dir).mkdir(parents=True)
259 |     with open(Path(cfg.exp_dir, "config.yaml"), "w") as fid:
260 |         OmegaConf.save(config=cfg, f=fid)
261 | 
262 |     write_slurm_scripts(cfg)
263 |     write_launcher(cfg.exp_dir, cfg.n_levels, cfg.n_splits)
264 |     write_local_launcher(cfg.exp_dir, cfg.n_levels, cfg.n_splits)
--------------------------------------------------------------------------------
/src/dist_comm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
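# Utilities for setting up torch.distributed under the launch modes used in
# this repository -- Slurm (sbatch/srun), torchrun, and a plain single-GPU
# run -- together with small helpers (rank and world-size queries,
# gather_tensor, synchronize) used by the distributed k-means implementation.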
6 | 7 | import os 8 | import random 9 | import re 10 | import socket 11 | from typing import Dict, List 12 | 13 | import torch 14 | import torch.distributed as dist 15 | 16 | 17 | _LOCAL_RANK = -1 18 | _LOCAL_WORLD_SIZE = -1 19 | 20 | 21 | def is_distributed_enabled() -> bool: 22 | """ 23 | Returns: 24 | True if distributed training is enabled 25 | """ 26 | return dist.is_available() and dist.is_initialized() 27 | 28 | 29 | def get_global_size() -> int: 30 | """ 31 | Returns: 32 | The number of processes in the process group 33 | """ 34 | return dist.get_world_size() if is_distributed_enabled() else 1 35 | 36 | 37 | def get_global_rank() -> int: 38 | """ 39 | Returns: 40 | The rank of the current process within the global process group. 41 | """ 42 | return dist.get_rank() if is_distributed_enabled() else 0 43 | 44 | 45 | def get_local_rank() -> int: 46 | """ 47 | Returns: 48 | The rank of the current process within the local (per-machine) process group. 49 | """ 50 | if not is_distributed_enabled(): 51 | return 0 52 | assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE 53 | return _LOCAL_RANK 54 | 55 | 56 | def get_local_size() -> int: 57 | """ 58 | Returns: 59 | The size of the per-machine process group, 60 | i.e. the number of processes per machine. 61 | """ 62 | if not is_distributed_enabled(): 63 | return 1 64 | assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE 65 | return _LOCAL_WORLD_SIZE 66 | 67 | 68 | def is_main_process() -> bool: 69 | """ 70 | Returns: 71 | True if the current process is the main one. 72 | """ 73 | return get_global_rank() == 0 74 | 75 | 76 | def save_in_main_process(*args, **kwargs) -> None: 77 | """Utility function to save only from the main process""" 78 | if not is_main_process(): 79 | return 80 | torch.save(*args, **kwargs) 81 | 82 | 83 | def _get_master_port(seed: int = 0) -> int: 84 | MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) 85 | 86 | master_port_str = os.environ.get("MASTER_PORT") 87 | if master_port_str is None: 88 | rng = random.Random(seed) 89 | return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) 90 | 91 | return int(master_port_str) 92 | 93 | 94 | def _get_available_port() -> int: 95 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 96 | # A "" host address means INADDR_ANY i.e. binding to all interfaces. 97 | # Note this is not compatible with IPv6. 
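        # Binding to port 0 lets the OS pick a free ephemeral port; it is read
        # back with getsockname() and the socket is closed, so there is a small
        # window in which another process could grab the port before the
        # process group is initialized.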
98 | s.bind(("", 0)) 99 | port = s.getsockname()[1] 100 | return port 101 | 102 | 103 | _TORCH_DISTRIBUTED_ENV_VARS = ( 104 | "MASTER_ADDR", 105 | "MASTER_PORT", 106 | "RANK", 107 | "WORLD_SIZE", 108 | "LOCAL_RANK", 109 | "LOCAL_WORLD_SIZE", 110 | ) 111 | 112 | 113 | def _collect_env_vars() -> Dict[str, str]: 114 | return { 115 | env_var: os.environ[env_var] 116 | for env_var in _TORCH_DISTRIBUTED_ENV_VARS 117 | if env_var in os.environ 118 | } 119 | 120 | 121 | def _is_slurm_job_process() -> bool: 122 | return "SLURM_JOB_ID" in os.environ 123 | 124 | 125 | def _parse_slurm_node_list(s: str) -> List[str]: 126 | nodes = [] 127 | # Extract "hostname", "hostname[1-2,3,4-5]," substrings 128 | p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") 129 | for m in p.finditer(s): 130 | prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] 131 | for suffix in suffixes.split(","): 132 | span = suffix.split("-") 133 | if len(span) == 1: 134 | nodes.append(prefix + suffix) 135 | else: 136 | width = len(span[0]) 137 | start, end = int(span[0]), int(span[1]) + 1 138 | nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)]) 139 | return nodes 140 | 141 | 142 | class _TorchDistributedEnvironment: 143 | def __init__(self, use_torchrun): 144 | self.master_addr = "127.0.0.1" 145 | self.master_port = 0 146 | self.rank = -1 147 | self.world_size = -1 148 | self.local_rank = -1 149 | self.local_world_size = -1 150 | 151 | if _is_slurm_job_process() and not use_torchrun: 152 | # use_torchrun is set to True to launch computation on nodes with srun or salloc 153 | return self._set_from_slurm_env() 154 | 155 | env_vars = _collect_env_vars() 156 | if not env_vars: 157 | # Environment is not set 158 | pass 159 | elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS): 160 | # Environment is fully set 161 | return self._set_from_preset_env() 162 | else: 163 | # Environment is partially set 164 | collected_env_vars = ", ".join(env_vars.keys()) 165 | raise RuntimeError(f"Partially set environment: {collected_env_vars}") 166 | 167 | if torch.cuda.device_count() > 0: 168 | return self._set_from_local() 169 | 170 | raise RuntimeError("Can't initialize PyTorch distributed environment") 171 | 172 | # Slurm job created with sbatch, submitit, etc... 173 | def _set_from_slurm_env(self): 174 | job_id = int(os.environ["SLURM_JOB_ID"]) 175 | node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) 176 | nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) 177 | assert len(nodes) == node_count 178 | 179 | self.master_addr = nodes[0] 180 | self.master_port = _get_master_port(seed=job_id) 181 | print(f"using {self.master_addr}:{self.master_port}") 182 | self.rank = int(os.environ["SLURM_PROCID"]) 183 | self.world_size = int(os.environ["SLURM_NTASKS"]) 184 | assert self.rank < self.world_size 185 | self.local_rank = int(os.environ["SLURM_LOCALID"]) 186 | self.local_world_size = self.world_size // node_count 187 | assert self.local_rank < self.local_world_size 188 | 189 | # Single node job with preset environment (i.e. 
torchrun) 190 | def _set_from_preset_env(self): 191 | self.master_addr = os.environ["MASTER_ADDR"] 192 | self.master_port = os.environ["MASTER_PORT"] 193 | self.rank = int(os.environ["RANK"]) 194 | self.world_size = int(os.environ["WORLD_SIZE"]) 195 | assert self.rank < self.world_size 196 | self.local_rank = int(os.environ["LOCAL_RANK"]) 197 | self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) 198 | assert self.local_rank < self.local_world_size 199 | 200 | # Single node and GPU job (i.e. local script run) 201 | def _set_from_local(self): 202 | self.master_addr = "127.0.0.1" 203 | self.master_port = _get_available_port() 204 | self.rank = 0 205 | self.world_size = 1 206 | self.local_rank = 0 207 | self.local_world_size = 1 208 | 209 | def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment": 210 | # See the "Environment variable initialization" section from 211 | # https://pytorch.org/docs/stable/distributed.html for the complete list of 212 | # environment variables required for the env:// initialization method. 213 | env_vars = { 214 | "MASTER_ADDR": self.master_addr, 215 | "MASTER_PORT": str(self.master_port), 216 | "RANK": str(self.rank), 217 | "WORLD_SIZE": str(self.world_size), 218 | "LOCAL_RANK": str(self.local_rank), 219 | "LOCAL_WORLD_SIZE": str(self.local_world_size), 220 | } 221 | print(env_vars) 222 | 223 | if not overwrite: 224 | for key in env_vars: 225 | # Only check for difference with preset environment variables 226 | if key in os.environ and os.environ[key] != env_vars[key]: 227 | raise RuntimeError( 228 | f"Cannot export environment variables as {key} is already set" 229 | ) 230 | 231 | os.environ.update(env_vars) 232 | return self 233 | 234 | 235 | def enable_distributed( 236 | *, use_torchrun=False, set_cuda_current_device: bool = True, overwrite: bool = False 237 | ): 238 | """Enable distributed mode 239 | Args: 240 | set_cuda_current_device: If True, call torch.cuda.set_device() to set the 241 | current PyTorch CUDA device to the one matching the local rank. 242 | overwrite: If True, overwrites already set variables. Else fails. 243 | """ 244 | 245 | global _LOCAL_RANK, _LOCAL_WORLD_SIZE 246 | if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0: 247 | raise RuntimeError("Distributed mode has already been enabled") 248 | torch_env = _TorchDistributedEnvironment(use_torchrun) 249 | torch_env.export(overwrite=overwrite) 250 | 251 | if set_cuda_current_device: 252 | torch.cuda.set_device(torch_env.local_rank) 253 | 254 | dist.init_process_group(backend="nccl") 255 | dist.barrier() 256 | 257 | 258 | def gather_tensor(x, do_all_gather=False): 259 | """ 260 | Gather tensors from all ranks to the main rank. 261 | There is an option to all_gather. 
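    Tensors may have different lengths along the first dimension on different
    ranks, so each rank pads its tensor to the largest gathered length and the
    padding is stripped again using the per-rank sizes. With
    do_all_gather=True, every rank returns the concatenated result; otherwise
    only the main rank does and the other ranks return None.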
262 | """ 263 | world_size = dist.get_world_size() 264 | local_size = torch.tensor(x.size(), device=x.device) 265 | all_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] 266 | dist.all_gather(all_sizes, local_size) 267 | 268 | max_length = max(size[0] for size in all_sizes) 269 | 270 | length_diff = max_length.item() - local_size[0].item() 271 | if length_diff: 272 | pad_size = (length_diff, *x.size()[1:]) 273 | padding = torch.zeros(pad_size, device=x.device, dtype=x.dtype) 274 | x = torch.cat((x, padding)) 275 | 276 | if do_all_gather: 277 | all_tensors_padded = [torch.zeros_like(x) for _ in range(world_size)] 278 | dist.all_gather(all_tensors_padded, x) 279 | synchronize() 280 | all_tensors = [] 281 | for tensor_, size in zip(all_tensors_padded, all_sizes): 282 | all_tensors.append(tensor_[: size[0]]) 283 | return torch.cat(all_tensors) 284 | else: 285 | if is_main_process(): 286 | all_tensors_padded = [torch.zeros_like(x) for _ in range(world_size)] 287 | dist.gather(x, all_tensors_padded) 288 | else: 289 | dist.gather(x, dst=0) 290 | synchronize() 291 | if is_main_process(): 292 | all_tensors = [] 293 | for tensor_, size in zip(all_tensors_padded, all_sizes): 294 | all_tensors.append(tensor_[: size[0]]) 295 | return torch.cat(all_tensors) 296 | else: 297 | return None 298 | 299 | 300 | def synchronize(): 301 | """ 302 | Helper function to synchronize (barrier) among all processes when 303 | using distributed training 304 | """ 305 | if not dist.is_available(): 306 | return 307 | if not dist.is_initialized(): 308 | return 309 | world_size = get_global_size() 310 | if world_size == 1: 311 | return 312 | dist.barrier() 313 | -------------------------------------------------------------------------------- /scripts/run_distributed_kmeans.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
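# Single-node usage sketch; the paths and sizes below are illustrative only,
# and the Slurm / torchrun scripts generated by hierarchical_kmeans_launcher.py
# pass the same arguments:
#
#   torchrun --nproc_per_node=2 run_distributed_kmeans.py \
#       --use_torchrun \
#       --data_path /path/to/embeddings.npy \
#       --n_clusters 1000 \
#       --n_iters 50 \
#       --chunk_size 200000 \
#       --dtype float64 \
#       --high_precision float64 \
#       --exp_dir ./level1 \
#       --n_steps 10 \
#       --sample_size 10 \
#       --sampling_strategy c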
6 | 7 | import argparse 8 | import logging 9 | from pathlib import Path 10 | import subprocess 11 | 12 | import numpy as np 13 | import torch 14 | 15 | from src import ( 16 | distributed_kmeans_gpu as dkmg, 17 | kmeans_gpu as kmg, 18 | hierarchical_sampling as hs, 19 | ) 20 | from src.dist_comm import enable_distributed, is_main_process, synchronize 21 | from src.utils import get_last_valid_checkpoint, setup_logging 22 | 23 | 24 | logger = logging.getLogger("hkmeans") 25 | 26 | def check_and_load_npy(load_path, allow_pickle=False, data_name=None): 27 | if load_path.exists(): 28 | if data_name is not None: 29 | logger.info(f"Loading {data_name} from {str(load_path)}") 30 | else: 31 | logger.info(f"Loading from {str(load_path)}") 32 | data = np.load(load_path, allow_pickle=allow_pickle) 33 | return data 34 | else: 35 | return None 36 | 37 | 38 | def check_and_save(save_path, save_data): 39 | if is_main_process(): 40 | if not save_path.exists(): 41 | np.save(save_path, save_data) 42 | synchronize() 43 | 44 | 45 | def main(args): 46 | enable_distributed( 47 | use_torchrun=args.use_torchrun, 48 | overwrite=True, 49 | ) 50 | 51 | X_ori = np.load(args.data_path, mmap_mode="r") 52 | if args.subset_indices_path is not None: 53 | logger.info(f"Using subset with indices in {args.subset_indices_path}") 54 | subset_indices = np.load(args.subset_indices_path) 55 | X_ori = dkmg.ExtendedNumpyMemMap(X_ori, subset_indices) 56 | Xi_ori = dkmg.load_data_to_worker(X_ori, dtype=args.dtype) 57 | 58 | for step_id in range(args.n_steps): 59 | step_dir = Path(args.exp_dir, f"step{step_id}") 60 | sorted_clusters_path = Path(step_dir, "sorted_clusters.npy") 61 | if sorted_clusters_path.exists(): 62 | logger.info( 63 | f"Step {step_id}: sorted clusters exist ({sorted_clusters_path}), skipping" 64 | ) 65 | continue 66 | logger.info(f"Running step {step_id}") 67 | step_dir.mkdir(exist_ok=True) 68 | 69 | # Load resampled points and run kmeans 70 | if step_id == 0: 71 | X = X_ori 72 | Xi = Xi_ori 73 | else: 74 | if is_main_process(): 75 | if not Path(step_dir, "sampled_indices.npy").exists(): 76 | logger.info(f"Sampling points for step {step_id}") 77 | prev_sorted_clusters = np.load( 78 | Path(args.exp_dir, f"step{step_id-1}", "sorted_clusters.npy"), 79 | allow_pickle=True, 80 | ) 81 | logger.info( 82 | f"Sampling from {len(prev_sorted_clusters)} clusters using " 83 | f"'{args.sampling_strategy}' sampling strategy, " 84 | f"{args.sample_size} samples per cluster" 85 | ) 86 | if args.sampling_strategy == "c": 87 | sampler = hs.closest_to_centroid_selection 88 | elif args.sampling_strategy == "r": 89 | sampler = hs.random_selection 90 | else: 91 | raise ValueError( 92 | f"sampling_strategy={args.sampling_strategy} not recognized!" 
93 | ) 94 | sampled_indices = sampler( 95 | prev_sorted_clusters, 96 | list(range(len(prev_sorted_clusters))), 97 | args.sample_size, 98 | ) 99 | sampled_indices = np.sort(sampled_indices) 100 | logger.info(f"Selected {len(sampled_indices)} images") 101 | np.save(Path(step_dir, "sampled_indices.npy"), sampled_indices) 102 | else: 103 | logger.info( 104 | f"sampled_indices.npy exists at " 105 | f"{str(Path(step_dir, 'sampled_indices.npy'))}" 106 | ) 107 | synchronize() 108 | sampled_indices = np.load(Path(step_dir, "sampled_indices.npy")) 109 | X = dkmg.ExtendedNumpyMemMap(X_ori, indices=sampled_indices) 110 | Xi = dkmg.load_data_to_worker(X, dtype=args.dtype) 111 | 112 | # Compute centroids 113 | centroids = check_and_load_npy( 114 | Path(step_dir, "centroids.npy"), 115 | allow_pickle=False, 116 | data_name="centroids" 117 | ) 118 | if centroids is not None: 119 | centroids = torch.tensor(centroids, device="cuda", dtype=args.dtype) 120 | synchronize() 121 | else: 122 | logger.info("Begin distributed kmeans") 123 | centroids, _ = dkmg.distributed_kmeans( 124 | X, 125 | Xi, 126 | args.n_clusters, 127 | n_iters=args.n_iters, 128 | chunk_size=args.chunk_size, 129 | init_method="kmeans++", 130 | save_kmpp_results=True, 131 | save_dir=step_dir, 132 | kmpp_checkpoint_period=args.checkpoint_period, 133 | high_precision=args.high_precision, 134 | ) 135 | check_and_save( 136 | Path(step_dir, "centroids.npy"), 137 | centroids.cpu().numpy() 138 | ) 139 | 140 | # Compute cluster_assignment 141 | cluster_assignment = check_and_load_npy( 142 | Path(step_dir, "cluster_assignment.npy"), 143 | allow_pickle=False, 144 | data_name="cluster_assignment" 145 | ) 146 | if cluster_assignment is None: 147 | logger.info("Assign points to clusters") 148 | cluster_assignment = ( 149 | dkmg.distributed_assign_clusters( 150 | X_ori, 151 | Xi_ori, 152 | centroids, 153 | args.chunk_size, 154 | verbose=True 155 | ) 156 | .cpu() 157 | .numpy() 158 | ) 159 | check_and_save( 160 | Path(step_dir, "cluster_assignment.npy"), 161 | cluster_assignment 162 | ) 163 | 164 | # Compute clusters 165 | clusters = check_and_load_npy( 166 | Path(step_dir, "clusters.npy"), 167 | allow_pickle=True, 168 | data_name="clusters" 169 | ) 170 | if clusters is None: 171 | logger.info("Create clusters from cluster_assignment") 172 | clusters = kmg.create_clusters_from_cluster_assignment( 173 | cluster_assignment, args.n_clusters 174 | ) 175 | check_and_save( 176 | Path(step_dir, "clusters.npy"), 177 | clusters 178 | ) 179 | 180 | if not Path(step_dir, "sorted_clusters.npy").exists(): 181 | centroids = centroids.cpu().numpy() 182 | del X, Xi 183 | if step_id == args.n_steps - 1: 184 | del Xi_ori 185 | torch.cuda.empty_cache() 186 | logger.info("Sort points in each cluster by distance to centroid") 187 | _ = dkmg.distributed_sort_cluster_by_distance( 188 | X_ori, 189 | centroids, 190 | clusters, 191 | dtype=torch.float32, 192 | save_dir=step_dir, 193 | checkpoint_period=args.sort_cluster_checkpoint_period, 194 | ) 195 | 196 | # remove checkpoints 197 | if is_main_process(): 198 | while get_last_valid_checkpoint(step_dir, "kmpp_checkpoint_%d.pth"): 199 | get_last_valid_checkpoint(step_dir, "kmpp_checkpoint_%d.pth").unlink( 200 | missing_ok=True 201 | ) 202 | while get_last_valid_checkpoint(step_dir, "centroids_checkpoint_%d.npy"): 203 | get_last_valid_checkpoint( 204 | step_dir, "centroids_checkpoint_%d.npy" 205 | ).unlink(missing_ok=True) 206 | 207 | if is_main_process(): 208 | last_sorted_clusters_path = str( 209 | Path(args.exp_dir, 
f"step{args.n_steps-1}", "sorted_clusters.npy").resolve() 210 | ) 211 | last_centroids_path = str( 212 | Path(args.exp_dir, f"step{args.n_steps-1}", "centroids.npy").resolve() 213 | ) 214 | 215 | link_command = f'ln -s {last_sorted_clusters_path} {str(Path(args.exp_dir, "sorted_clusters.npy").resolve())}' 216 | process = subprocess.Popen(link_command.split(), stdout=subprocess.PIPE) 217 | _, _ = process.communicate() 218 | 219 | link_command = f'ln -s {last_centroids_path} {str(Path(args.exp_dir, "centroids.npy").resolve())}' 220 | process = subprocess.Popen(link_command.split(), stdout=subprocess.PIPE) 221 | _, _ = process.communicate() 222 | 223 | logger.info("Finished all steps!") 224 | 225 | 226 | if __name__ == "__main__": 227 | parser = argparse.ArgumentParser() 228 | parser.add_argument("--data_path", type=str, required=True) 229 | parser.add_argument("--subset_indices_path", type=str, default=None) 230 | parser.add_argument("--n_clusters", type=int, required=True) 231 | parser.add_argument("--chunk_size", type=int, default=1000) 232 | parser.add_argument("--dtype", type=str, default="float32") 233 | parser.add_argument("--high_precision", type=str, default="float32") 234 | parser.add_argument("--checkpoint_period", type=int, default=1000) 235 | parser.add_argument( 236 | "--sort_cluster_checkpoint_period", 237 | type=int, 238 | default=-1 239 | ) 240 | parser.add_argument("--exp_dir", type=str, default="tmp") 241 | parser.add_argument("--n_iters", type=int, default=10) 242 | parser.add_argument("--use_torchrun", action="store_true") 243 | 244 | parser.add_argument( 245 | "--n_steps", type=int, default=1, help="Number of resampling step" 246 | ) 247 | parser.add_argument( 248 | "--sample_size", 249 | type=int, 250 | required=True, 251 | help="Number of samples per cluster in resampling", 252 | ) 253 | parser.add_argument( 254 | "--sampling_strategy", 255 | type=str, 256 | default="c", 257 | help="resampling with closest (c) or random (r) strategy", 258 | ) 259 | 260 | args = parser.parse_args() 261 | setup_logging() 262 | 263 | def parse_dtype(dtype): 264 | if dtype == "float32": 265 | return torch.float32 266 | elif dtype == "float64": 267 | return torch.float64 268 | elif dtype == "float16": 269 | return torch.float16 270 | else: 271 | raise ValueError(f"Value of args.dtype ({args.dtype}) not regconised") 272 | 273 | args.dtype = parse_dtype(args.dtype) 274 | args.high_precision = parse_dtype(args.high_precision) 275 | if args.dtype == torch.float64: 276 | args.high_precision = torch.float64 277 | assert args.high_precision in [torch.float32, torch.float64] 278 | 279 | logger.info(f"Args: {args}") 280 | 281 | main(args) 282 | synchronize() 283 | -------------------------------------------------------------------------------- /src/kmeans_gpu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | import numpy as np 9 | import torch 10 | from tqdm import tqdm 11 | 12 | from .utils import create_clusters_from_cluster_assignment 13 | from sklearn.utils import check_random_state 14 | 15 | 16 | def matmul_transpose(X, Y): 17 | """ 18 | Compute X . 
Y.T 19 | """ 20 | return torch.matmul(X, Y.T) 21 | 22 | 23 | def compute_distance( 24 | X, Y, Y_squared_norms, dist="l2", X_squared_norm=None, matmul_fn=matmul_transpose 25 | ): 26 | """ 27 | Compute pairwise distance between rows of X and Y. 28 | 29 | Parameters: 30 | X: torch.tensor of shape (n_samples_x, n_features) 31 | Y: torch.tensor of shape (n_samples_y, n_features) 32 | Y is supposed to be larger than X. 33 | Y_squared_norms: torch.tensor of shape (n_samples_y, ) 34 | Squared L2 norm of rows of Y. 35 | It can be provided to avoid re-computation. 36 | dist: 'cos' or 'l2' 37 | If 'cos', assuming that rows of X are normalized 38 | to have L2 norm equal to 1. 39 | X_squared_norm: torch.tensor of shape (n_samples_x, ) 40 | Squared L2 norm of rows of X. 41 | matmul_fn: matmul function. 42 | 43 | Returns: 44 | 45 | Pairwise distance between rows of X and Y. 46 | 47 | """ 48 | 49 | if dist == "cos": 50 | return 2 - 2 * matmul_fn(X, Y) 51 | elif dist == "l2": 52 | if X_squared_norm is None: 53 | X_squared_norm = torch.linalg.vector_norm(X, dim=1) ** 2 54 | return X_squared_norm[:, None] - 2 * matmul_fn(X, Y) + Y_squared_norms[None, :] 55 | else: 56 | raise ValueError(f'dist = "{dist}" not supported!') 57 | 58 | 59 | # A modified version of _kmeans_plusplus 60 | # from https://github.com/scikit-learn/scikit-learn/blob/364c77e04/sklearn/cluster/_kmeans.py#L63 61 | def kmeans_plusplus( 62 | X, 63 | n_clusters, 64 | x_squared_norms, 65 | dist, 66 | random_state=None, 67 | n_local_trials=None, 68 | high_precision=torch.float64, 69 | verbose=False, 70 | ): 71 | """ 72 | Computational component for initialization of n_clusters by 73 | k-means++. Prior validation of data is assumed. 74 | 75 | Parameters 76 | X : torch.tensor of shape (n_samples, n_features) 77 | The data to pick seeds for. 78 | n_clusters : int 79 | The number of seeds to choose. 80 | x_squared_norms : torch.tensor (n_samples,) 81 | Squared Euclidean norm of each data point. 82 | random_state : RandomState instance 83 | The generator used to initialize the centers. 84 | n_local_trials : int, default=None 85 | The number of seeding trials for each center (except the first), 86 | of which the one reducing inertia the most is greedily chosen. 87 | Set to None to make the number of trials depend logarithmically 88 | on the number of seeds (2+log(k)); this is the default. 89 | high_precision: torch.float32 or torch.float64, to save GPU memory, one 90 | can use float32 or float16 for data 'X', 'high_precision' will be 91 | use in aggregation operation to avoid overflow. 92 | 93 | Returns 94 | centers : torch.tensor of shape (n_clusters, n_features) 95 | The initial centers for k-means. 96 | indices : ndarray of shape (n_clusters,) 97 | The index location of the chosen centers in the data array X. For a 98 | given index and center, X[index] = center. 
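    This follows the greedy k-means++ seeding of scikit-learn: at each step,
    n_local_trials candidates are drawn with probability proportional to their
    current distance to the nearest chosen center, and the candidate that
    yields the lowest total potential is kept.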
99 | 100 | """ 101 | if random_state is None: 102 | random_state = check_random_state(random_state) 103 | 104 | n_samples, n_features = X.shape 105 | 106 | centers = torch.empty((n_clusters, n_features), dtype=X.dtype).to(X.device) 107 | pots = torch.empty((n_clusters,), device=X.device, dtype=high_precision) 108 | 109 | # Set the number of local seeding trials if none is given 110 | if n_local_trials is None: 111 | n_local_trials = 2 + int(np.log(n_clusters)) 112 | 113 | # Pick first center randomly and track index of point 114 | center_id = random_state.randint(n_samples) 115 | indices = np.full(n_clusters, -1, dtype=int) 116 | centers[0] = X[center_id] 117 | indices[0] = center_id 118 | 119 | # Initialize list of closest distances and calculate current potential 120 | closest_dist_sq = compute_distance(X[center_id, None], X, x_squared_norms, dist)[ 121 | 0 122 | ].type(high_precision) 123 | current_pot = closest_dist_sq.sum() 124 | pots[0] = current_pot 125 | 126 | # Pick the remaining n_clusters-1 points 127 | if verbose: 128 | iterates = tqdm( 129 | range(1, n_clusters), 130 | desc="Kmeans++ initialization", 131 | file=sys.stdout, 132 | bar_format="{l_bar}{bar}{r_bar}", 133 | ) 134 | else: 135 | iterates = range(1, n_clusters) 136 | for c in iterates: 137 | # Choose center candidates by sampling with probability proportional 138 | # to the squared distance to the closest existing center 139 | rand_vals = ( 140 | torch.tensor(random_state.uniform(size=n_local_trials)).to( 141 | current_pot.device 142 | ) 143 | * current_pot 144 | ) 145 | candidate_ids = torch.searchsorted( 146 | torch.cumsum(closest_dist_sq, dim=0), rand_vals 147 | ) 148 | # numerical imprecision can result in a candidate_id out of range 149 | torch.clip(candidate_ids, None, closest_dist_sq.shape[0] - 1, out=candidate_ids) 150 | 151 | # Compute distances to center candidates 152 | distance_to_candidates = compute_distance( 153 | X[candidate_ids], X, x_squared_norms, dist 154 | ).type(high_precision) 155 | 156 | # update closest distances squared and potential for each candidate 157 | torch.minimum( 158 | closest_dist_sq, distance_to_candidates, out=distance_to_candidates 159 | ) 160 | candidates_pot = distance_to_candidates.sum(dim=1) 161 | 162 | # Decide which candidate is the best 163 | best_candidate = torch.argmin(candidates_pot) 164 | current_pot = candidates_pot[best_candidate] 165 | closest_dist_sq = distance_to_candidates[best_candidate] 166 | best_candidate = candidate_ids[best_candidate] 167 | 168 | # Permanently add best center candidate found in local tries 169 | centers[c] = X[best_candidate] 170 | indices[c] = best_candidate 171 | pots[c] = current_pot 172 | 173 | return centers, indices 174 | 175 | 176 | def assign_clusters(centroids, X, dist, chunk_size=-1, verbose=False): 177 | """ 178 | Assign data points to their closest clusters. 179 | 180 | Parameters: 181 | 182 | centroids: torch.tensor of shape (n_clusters, n_features) 183 | Centroids of the clusters. 184 | X: torch.tensor of shape (n_samples, n_features) 185 | Data. 186 | dist: 'cos' or 'l2' 187 | If 'cos', assuming that rows of X are normalized 188 | to have L2 norm equal to 1. 189 | chunk_size: int 190 | Number of data points that are assigned at once. 191 | Use a small chunk_size if n_clusters is large to avoid 192 | out-of-memory error, e.g. chunk_size <= 1e9/n_clusters. 193 | Default is -1, meaning all data points are assigned at once. 194 | verbose: bool 195 | Whether to print progress bar. 
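
    Example (shapes and values are illustrative):

        centroids = torch.randn(10, 128, device="cuda")
        X = torch.randn(100_000, 128, device="cuda")
        cluster_ids = assign_clusters(centroids, X, "l2", chunk_size=10_000)
        # cluster_ids: tensor of shape (100_000,), values in [0, 10)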
196 | 197 | Returns: 198 | 199 | torch.tensor of shape (n_samples, ) containing the cluster id of 200 | each data point. 201 | 202 | """ 203 | 204 | cluster_ids = [] 205 | n_samples, _ = X.shape 206 | x_squared_norms = torch.linalg.vector_norm(X, dim=1) ** 2 207 | centroid_squared_norm = torch.linalg.vector_norm(centroids, dim=1) ** 2 208 | if chunk_size < 0: 209 | try: 210 | distance_from_centroids = compute_distance( 211 | centroids, X, x_squared_norms, dist, centroid_squared_norm 212 | ) 213 | except Exception as e: 214 | raise MemoryError( 215 | f"matrices are too large, consider setting chunk_size ({chunk_size}) to a smaller number" 216 | ) from e 217 | cluster_ids = torch.argmin(distance_from_centroids, dim=0) 218 | else: 219 | n_iters = (n_samples + chunk_size - 1) // chunk_size 220 | if verbose: 221 | iterates = tqdm( 222 | range(n_iters), 223 | desc="Assigning data points to centroids", 224 | file=sys.stdout, 225 | bar_format="{l_bar}{bar}{r_bar}", 226 | ) 227 | else: 228 | iterates = range(n_iters) 229 | 230 | for chunk_idx in iterates: 231 | begin_idx = chunk_idx * chunk_size 232 | end_idx = min(n_samples, (chunk_idx + 1) * chunk_size) 233 | distance_from_centroids = compute_distance( 234 | centroids, 235 | X[begin_idx:end_idx], 236 | x_squared_norms[begin_idx:end_idx], 237 | dist, 238 | centroid_squared_norm, 239 | ) 240 | cluster_ids.append(torch.argmin(distance_from_centroids, dim=0)) 241 | del distance_from_centroids 242 | cluster_ids = torch.cat(cluster_ids) 243 | return cluster_ids 244 | 245 | 246 | def compute_centroids( 247 | centroids, cluster_assignment, n_clusters, X, high_precision=torch.float32 248 | ): 249 | """ 250 | Compute centroids of each cluster given its data points. 251 | 252 | Parameters: 253 | 254 | centroids: torch.tensor of shape (n_clusters, n_features) 255 | Previous centroids of the clusters. 256 | cluster_assignment: torch.tensor of shape (n_samples, ) 257 | Cluster id of data points. 258 | n_clusters: int 259 | Number of clusters. 260 | X: torch.tensor of shape (n_samples, n_features) 261 | Data. 262 | high_precision: torch.float32 or torch.float64, to save GPU memory, one 263 | can use float32 or float16 for data 'X', 'high_precision' will be 264 | use in aggregation operation to avoid overflow. 265 | 266 | Returns: 267 | 268 | torch.tensor of shape (n_clusters, n_features), new centroids 269 | """ 270 | clusters = create_clusters_from_cluster_assignment(cluster_assignment, n_clusters) 271 | new_centroids = torch.zeros_like(centroids) 272 | for i in range(n_clusters): 273 | if len(clusters[i]) > 0: 274 | new_centroids[i] = torch.mean( 275 | X[clusters[i].astype(int)].type(high_precision), dim=0 276 | ) 277 | else: 278 | new_centroids[i] = centroids[i] 279 | return new_centroids 280 | 281 | 282 | def _kmeans( 283 | X, 284 | n_clusters, 285 | n_iters, 286 | chunk_size=-1, 287 | init_method="kmeans++", 288 | dist="l2", 289 | high_precision=torch.float32, 290 | random_state=None, 291 | verbose=False, 292 | ): 293 | """ 294 | Run kmeans once. 295 | 296 | Parameters: See above. 297 | 298 | Returns: 299 | 300 | centroids: 301 | clusters: np.array of np.array 302 | Indices of points in each cluster. A subarray corresponds to a cluster. 
303 | cluster_assignment: 304 | pot: float, kmeans objective 305 | 306 | """ 307 | if random_state is None: 308 | random_state = check_random_state(random_state) 309 | 310 | x_squared_norms = torch.linalg.vector_norm(X, dim=1) ** 2 311 | if init_method == "kmeans++": 312 | centroids, _ = kmeans_plusplus( 313 | X, 314 | n_clusters, 315 | x_squared_norms, 316 | dist, 317 | high_precision=high_precision, 318 | random_state=random_state, 319 | verbose=verbose, 320 | ) 321 | else: 322 | centroids = torch.tensor( 323 | X[np.sort(random_state.choice(range(len(X)), n_clusters, replace=False))], 324 | device=X.device, 325 | dtype=X.dtype, 326 | ) 327 | 328 | cluster_assignment = assign_clusters(centroids, X, dist, chunk_size).cpu().numpy() 329 | for _iter in range(n_iters): 330 | centroids = compute_centroids( 331 | centroids, cluster_assignment, n_clusters, X, high_precision 332 | ) 333 | cluster_assignment = ( 334 | assign_clusters(centroids, X, dist, chunk_size).cpu().numpy() 335 | ) 336 | clusters = create_clusters_from_cluster_assignment(cluster_assignment, n_clusters) 337 | pot = np.sum( 338 | [ 339 | torch.sum( 340 | torch.cdist( 341 | X[el.astype(int)], X[el.astype(int)].mean(dim=0, keepdim=True) 342 | ) 343 | ** 2 344 | ).item() 345 | for el in clusters 346 | ] 347 | ) 348 | return centroids, clusters, cluster_assignment, pot 349 | 350 | 351 | def kmeans( 352 | X, 353 | n_clusters, 354 | n_iters, 355 | chunk_size=-1, 356 | num_init=10, 357 | init_method="kmeans++", 358 | dist="l2", 359 | high_precision=torch.float32, 360 | random_state=None, 361 | verbose=False, 362 | ): 363 | """ 364 | Run kmeans multiple times and return the clustering with the best objective. 365 | 366 | Parameters: See above and 367 | 368 | num_init: int 369 | Number of kmeans runs. 370 | 371 | Returns: 372 | 373 | Same as _kmeans 374 | 375 | """ 376 | 377 | n_clusters = min(X.shape[0], n_clusters) 378 | best_centroids, best_clusters, best_cluster_assignment, best_pot = ( 379 | None, 380 | None, 381 | None, 382 | np.Inf, 383 | ) 384 | for _ in range(num_init): 385 | centroids, clusters, cluster_assignment, pot = _kmeans( 386 | X, 387 | n_clusters, 388 | n_iters, 389 | chunk_size=chunk_size, 390 | init_method=init_method, 391 | dist=dist, 392 | high_precision=high_precision, 393 | random_state=random_state, 394 | verbose=verbose, 395 | ) 396 | if pot < best_pot: 397 | best_centroids, best_clusters, best_cluster_assignment, best_pot = ( 398 | centroids, 399 | clusters, 400 | cluster_assignment, 401 | pot, 402 | ) 403 | return best_centroids, best_clusters, best_cluster_assignment, best_pot 404 | 405 | 406 | def sort_cluster_by_distance( 407 | X, centroids, clusters, device="cuda", dtype=torch.float32, verbose=False, 408 | ): 409 | """ 410 | Sort data points in each cluster in increasing order of distance to the centroid. 411 | 412 | Parameters: 413 | 414 | X: data 415 | centroids: 416 | clusters: 417 | 418 | Returns: 419 | 420 | sorted_clusters: np.array of np.array 421 | Indices of points in each cluster. A subarray corresponds to a cluster. 
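    Clusters are processed one at a time: only the points of the current
    cluster are copied to the GPU, so peak memory is governed by the largest
    cluster rather than by the whole dataset.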
422 | 423 | """ 424 | 425 | n_clusters, n_dim = centroids.shape[0], centroids.shape[1] 426 | 427 | sorted_clusters = [] 428 | if verbose: 429 | iterates = tqdm( 430 | range(n_clusters), 431 | desc="Sorting clusters by distance", 432 | file=sys.stdout, 433 | bar_format="{l_bar}{bar}{r_bar}", 434 | ) 435 | else: 436 | iterates = range(n_clusters) 437 | for cluster_idx in iterates: 438 | if len(clusters[cluster_idx]) > 0: 439 | point_indices = np.sort(clusters[cluster_idx]).astype(int) 440 | point_feats = torch.tensor(X[point_indices], device=device, dtype=dtype) 441 | _centroid = centroids[cluster_idx].reshape(1, n_dim).type(dtype) 442 | 443 | dist_to_centroid = torch.cdist(point_feats, _centroid).flatten() 444 | sorted_clusters.append( 445 | point_indices[torch.argsort(dist_to_centroid).cpu().numpy()] 446 | ) 447 | del point_feats 448 | else: 449 | sorted_clusters.append(np.array([]).astype(int)) 450 | return np.array(sorted_clusters, dtype=object) 451 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. 
If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. 
Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. 
Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. 
If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. 
automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 399 | 400 | Creative Commons may be contacted at creativecommons.org. 
401 | -------------------------------------------------------------------------------- /src/distributed_kmeans_gpu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | import collections 9 | import logging 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | import torch.distributed as dist 15 | from sklearn.utils import check_random_state 16 | from tqdm import tqdm 17 | 18 | from . import kmeans_gpu as kmg 19 | from .dist_comm import ( 20 | gather_tensor, 21 | get_global_rank, 22 | get_global_size, 23 | is_main_process, 24 | synchronize, 25 | ) 26 | from .utils import ( 27 | _delete_old_checkpoint, 28 | get_last_valid_checkpoint, 29 | ) 30 | 31 | 32 | logger = logging.getLogger("hkmeans") 33 | 34 | class ExtendedNumpyMemMap(object): 35 | """ 36 | Class representing an arbitrary slice of a memmap to a numpy array or an array 37 | """ 38 | 39 | def __init__(self, X, indices): 40 | """ 41 | Parameters: 42 | X: memmap to a numpy array, or an array 43 | indices: array, indices representing the slice 44 | """ 45 | if not isinstance(indices, np.ndarray): 46 | raise ValueError("indices must be a numpy array") 47 | if indices.ndim != 1: 48 | raise ValueError("indices must have dimension 1") 49 | self.X = X 50 | self.indices = indices 51 | self.shape = (len(indices), X.shape[1]) 52 | 53 | def __getitem__(self, ids): 54 | return self.X[self.indices[ids]] 55 | 56 | def __len__(self): 57 | return len(self.indices) 58 | 59 | def numpy(self): 60 | return np.array(self.X[self.indices]) 61 | 62 | def to_tensor(self, dtype, device): 63 | return torch.tensor(self.numpy(), device=device, dtype=dtype) 64 | 65 | 66 | def get_part_indices(num_points, world_size): 67 | """ 68 | Get indices of data points managed by each worker 69 | """ 70 | return [round(num_points / world_size * i) for i in range(world_size + 1)] 71 | 72 | 73 | def get_part_len(part_idx, num_points, world_size): 74 | """ 75 | Get number of data points managed by each worker 76 | """ 77 | return round(num_points / world_size * (part_idx + 1)) - round( 78 | num_points / world_size * part_idx 79 | ) 80 | 81 | 82 | def load_data_to_worker(X, device="cuda", dtype=torch.float32): 83 | """ 84 | Parameters: 85 | X: memmap / array or ExtendedNumpyMemMap, the data matrix 86 | device: 87 | dtype: 88 | 89 | Returns: 90 | part of the data allocated to the current worker 91 | """ 92 | rank = get_global_rank() 93 | part_indices = get_part_indices(X.shape[0], get_global_size()) 94 | logger.info(f"Rank {rank}: Loading data") 95 | Xi = torch.tensor( 96 | np.array(X[part_indices[rank] : part_indices[rank + 1]]), 97 | device=device, 98 | dtype=dtype, 99 | ) 100 | synchronize() 101 | logger.info(f"Rank: {rank}, X.shape: {X.shape}, Xi.shape: {Xi.shape}") 102 | return Xi 103 | 104 | 105 | def distributed_matmul(X, Xi, Y, do_all_gather=False): 106 | """ 107 | Compute matrix multiplication XY in a distributed manner. 108 | 109 | Parameters: 110 | 111 | X: mem_map of an array of shape (n_samples, n_features) or the array itself 112 | Data. 113 | Xi: torch.tensor 114 | Part of data that is managed by the current device. 115 | Y: torch.tensor 116 | Same on all workers.
117 | do_all_gather: bool 118 | Whether to only store the final result in the main 119 | process (False) or to have a copy of it in all processes (True). In the 120 | former case, returns None except for the main process. 121 | 122 | Returns: 123 | 124 | Product of X and Y. 125 | 126 | """ 127 | XY = torch.matmul(Xi, Y) 128 | return gather_tensor(XY, do_all_gather) 129 | 130 | 131 | def compute_data_squared_norms(X, Xi, do_all_gather=False): 132 | """ 133 | Compute squared L2 norm of rows of X in a distributed manner. 134 | 135 | Parameters: 136 | 137 | X: mem_map of an array of shape (n_samples, n_features) or the array itself 138 | Data. 139 | Xi: torch.tensor 140 | Part of data that is managed by the current device. 141 | do_all_gather: bool 142 | Whether to only store the final result in the main 143 | process (False) or to have a copy of it in all processes (True). In the 144 | former case, returns None except for the main process. 145 | 146 | Returns: 147 | 148 | Squared L2 norm of rows of X 149 | 150 | """ 151 | xi_squared_norms = torch.linalg.vector_norm(Xi, dim=1) ** 2 152 | return gather_tensor(xi_squared_norms, do_all_gather) 153 | 154 | 155 | def distributed_squared_euclidean_distance( 156 | X, Xi, Y, X_squared_norms, do_all_gather=False 157 | ): 158 | """ 159 | Compute squared Euclidean distance between X and Y. 160 | 161 | Parameters: 162 | 163 | X: mem_map of an array of shape (n_samples, n_features) or the array itself 164 | Data. 165 | Xi: torch.tensor 166 | Part of data that is managed by the current device. 167 | Y: torch.tensor 168 | Same on all workers. 169 | X_squared_norms: torch.tensor of shape (n_samples, ) 170 | Squared L2 norm of rows of X. 171 | do_all_gather: bool 172 | Whether to only store the final result in the main 173 | process (False) or to have a copy of it in all processes (True). In the 174 | former case, returns None except for the main process. 175 | 176 | Returns: 177 | 178 | Pairwise squared Euclidean distance between rows of X and Y. 179 | 180 | """ 181 | XY = distributed_matmul(X, Xi, Y.T, do_all_gather) 182 | if do_all_gather: 183 | Y_squared_norms = torch.linalg.vector_norm(Y, dim=1) ** 2 184 | XY_dist = X_squared_norms[:, None] - 2 * XY + Y_squared_norms[None, :] 185 | return XY_dist 186 | else: 187 | if is_main_process(): 188 | Y_squared_norms = torch.linalg.vector_norm(Y, dim=1) ** 2 189 | XY_dist = X_squared_norms[:, None] - 2 * XY + Y_squared_norms[None, :] 190 | return XY_dist 191 | else: 192 | return None 193 | 194 | 195 | def select_best_candidate( 196 | X, 197 | Xi, 198 | xi_squared_norms, 199 | candidate_ids, 200 | closest_dist_sq, 201 | high_precision=torch.float32, 202 | ): 203 | """ 204 | The selection sub-procedure of kmeans++ initialization. 205 | Given a list of candidates to select as the next centroid, it finds 206 | the candidate that would result in the smallest partial kmeans objective. 207 | 208 | Parameters: 209 | 210 | X: mem_map of an array of shape (n_samples, n_features) or the array itself 211 | Data. 212 | Xi: torch.tensor 213 | Part of data that is managed by the current device. 214 | candidate_ids: tensor 215 | List of indices of points to select as the next centroid. 216 | closest_dist_sq: torch.tensor of shape (n_samples,) 217 | Squared Euclidean distance to the closest selected centroid. 218 | high_precision: torch.float32 or torch.float64 219 | The precision used when high precision is required.
220 | 221 | Returns: 222 | 223 | int, best candidate in candidate_ids 224 | current_pot: the updated kmeans potential after adding the new centroid 225 | updated closest_dist_sq 226 | """ 227 | 228 | if high_precision not in [torch.float32, torch.float64]: 229 | raise ValueError( 230 | "Only support high_precision value in [torch.float32, torch.float64]" 231 | ) 232 | 233 | part_indices = get_part_indices(X.shape[0], get_global_size()) 234 | rank = get_global_rank() 235 | 236 | # load features of the candidates from X 237 | Y_candidates = torch.tensor( 238 | np.array(X[candidate_ids.detach().cpu().numpy()]), 239 | device=Xi.device, 240 | dtype=Xi.dtype, 241 | ) 242 | # compute squared Euclidean distance from candidates to data points 243 | distance_to_candidates = ( 244 | xi_squared_norms[:, None] 245 | - 2 * torch.matmul(Xi, Y_candidates.T) 246 | + torch.linalg.vector_norm(Y_candidates, dim=1)[None, :] ** 2 247 | ) 248 | distance_to_candidates = distance_to_candidates.type(high_precision).T 249 | 250 | # compute the kmeans potentials if adding each of the candidates 251 | torch.minimum( 252 | closest_dist_sq[part_indices[rank] : part_indices[rank + 1]], 253 | distance_to_candidates, 254 | out=distance_to_candidates, 255 | ) 256 | candidates_pot = distance_to_candidates.sum(dim=1) 257 | dist.all_reduce(candidates_pot, op=dist.ReduceOp.SUM) 258 | 259 | # select the candidate that results in the smallest potential 260 | best_candidate = torch.argmin(candidates_pot) 261 | 262 | # gather closest_dist_sq 263 | new_closest_dist_sq = distance_to_candidates[best_candidate].contiguous() 264 | new_closest_dist_sq = gather_tensor(new_closest_dist_sq, do_all_gather=True) 265 | 266 | # Update potential 267 | current_pot = new_closest_dist_sq.sum() 268 | 269 | return candidate_ids[best_candidate].item(), current_pot, new_closest_dist_sq 270 | 271 | 272 | def distributed_kmeans_plusplus_init( 273 | X, 274 | Xi, 275 | n_clusters, 276 | x_squared_norms, 277 | random_state=None, 278 | n_local_trials=None, 279 | high_precision=torch.float32, 280 | save_dir=None, 281 | checkpoint_period=-1, 282 | max_num_checkpoints=5, 283 | saving_checkpoint_pattern="kmpp_checkpoint_%d.pth", 284 | ): 285 | """Computational component for initialization of n_clusters by 286 | k-means++. Prior validation of data is assumed. 287 | 288 | Parameters 289 | 290 | X: mem_map of an array of shape (n_samples, n_features) or the array itself 291 | Data. 292 | Xi: torch.tensor 293 | Part of data that is managed by the current device. 294 | n_clusters : int 295 | The number of seeds to choose. 296 | x_squared_norms : torch.tensor (n_samples,) 297 | Squared Euclidean norm of each data point. 298 | random_state : RandomState instance 299 | The generator used to initialize the centers. 300 | n_local_trials : int, default=None 301 | The number of seeding trials for each center (except the first), 302 | of which the one reducing inertia the most is greedily chosen. 303 | Set to None to make the number of trials depend logarithmically 304 | on the number of seeds (2+log(k)); this is the default. 305 | high_precision: torch.float32 or torch.float64 306 | The precision used when high precision is required 307 | save_dir: str or Path 308 | Location for saving checkpoints. 309 | checkpoint_period: int 310 | Save a checkpoint after every 'checkpoint_period' iterations; set to -1 311 | to disable checkpointing. 312 | max_num_checkpoints: int 313 | Maximum number of checkpoints to keep; if exceeded, the oldest checkpoint 314 | will be deleted.
315 | 316 | Returns 317 | 318 | centers : torch.tensor of shape (n_clusters, n_features) 319 | The initial centers for k-means. 320 | indices : ndarray of shape (n_clusters,) 321 | The index location of the chosen centers in the data array X. For a 322 | given index and center, X[index] = center. 323 | """ 324 | if checkpoint_period > 0: 325 | assert save_dir 326 | 327 | # Common variables in devices 328 | n_samples, n_features = X.shape 329 | num_candidates = torch.tensor(0, device=Xi.device) 330 | candidate_ids = None 331 | Y_candidates = None 332 | 333 | xi_squared_norms = torch.linalg.vector_norm(Xi, dim=1) ** 2 334 | 335 | # Set the number of local seeding trials if none is given 336 | if n_local_trials is None: 337 | n_local_trials = 2 + int(np.log(n_clusters)) 338 | 339 | if get_last_valid_checkpoint(save_dir, saving_checkpoint_pattern): 340 | # Load data from checkpoint if exists 341 | ckpt_path = get_last_valid_checkpoint(save_dir, saving_checkpoint_pattern) 342 | logger.info(f"Loading checkpoint from {ckpt_path}") 343 | ckpt = torch.load(ckpt_path, map_location="cpu") 344 | begin_iter = ckpt["iter"] + 1 345 | centers = ckpt["centers"].to(Xi.device) 346 | pots = ckpt["pots"].to(Xi.device) 347 | current_pot = ckpt["current_pot"].to(Xi.device) 348 | closest_dist_sq = ckpt["closest_dist_sq"].to(Xi.device) 349 | indices = ckpt["indices"] 350 | random_state = ckpt["random_state"] 351 | else: 352 | logger.info("Initializing the first centroid") 353 | begin_iter = 1 354 | centers = torch.empty( 355 | (n_clusters, n_features), dtype=Xi.dtype, device=Xi.device 356 | ) 357 | pots = torch.empty((n_clusters,), dtype=high_precision, device=Xi.device) 358 | indices = np.full(n_clusters, -1, dtype=int) 359 | 360 | if random_state is None: 361 | random_state = check_random_state(random_state) 362 | 363 | if is_main_process(): 364 | # Pick first center randomly and track index of point 365 | center_id = random_state.randint(n_samples) 366 | centers[0] = torch.tensor( 367 | X[center_id], dtype=centers.dtype, device=centers.device 368 | ) 369 | indices[0] = center_id 370 | 371 | Y_candidates = centers[ 372 | [0] 373 | ] # guarantee that Y_candidates have size (1, n_features) 374 | 375 | if not is_main_process(): 376 | Y_candidates = torch.zeros( 377 | (1, n_features), device=Xi.device, dtype=Xi.dtype 378 | ) 379 | dist.broadcast(Y_candidates, src=0) 380 | synchronize() 381 | 382 | # Initialize list of closest distances and calculate current potential 383 | closest_dist_sq = ( 384 | distributed_squared_euclidean_distance( 385 | X, Xi, Y_candidates, x_squared_norms, do_all_gather=True 386 | ) 387 | .ravel() 388 | .type(high_precision) 389 | ) 390 | current_pot = closest_dist_sq.sum() 391 | pots[0] = current_pot 392 | 393 | synchronize() 394 | logger.info("Begin main loop") 395 | # Pick the remaining n_clusters-1 points 396 | if is_main_process(): 397 | iterates = tqdm( 398 | range(begin_iter, n_clusters), 399 | desc="Distributed kmeans++ initialization", 400 | file=sys.stdout, 401 | bar_format="{l_bar}{bar}{r_bar}", 402 | ) 403 | else: 404 | iterates = range(begin_iter, n_clusters) 405 | for c in iterates: 406 | if is_main_process(): 407 | # Choose center candidates by sampling with probability proportional 408 | # to the squared distance to the closest existing center 409 | rand_vals = ( 410 | torch.tensor(random_state.uniform(size=n_local_trials)).to( 411 | current_pot.device 412 | ) 413 | * current_pot 414 | ) 415 | candidate_ids = torch.searchsorted( 416 | torch.cumsum(closest_dist_sq, dim=0), 
rand_vals 417 | ) 418 | # numerical imprecision can result in a candidate_id out of range 419 | torch.clip( 420 | candidate_ids, None, closest_dist_sq.shape[0] - 1, out=candidate_ids 421 | ) 422 | num_candidates = torch.tensor(len(candidate_ids), device=Xi.device) 423 | 424 | # broadcast candidate_ids candidates to all processes 425 | dist.broadcast(num_candidates, src=0) 426 | synchronize() 427 | if not is_main_process(): 428 | candidate_ids = torch.zeros( 429 | (num_candidates,), device=Xi.device, dtype=torch.int64 430 | ) 431 | dist.broadcast(candidate_ids, src=0) 432 | synchronize() 433 | 434 | best_candidate, current_pot, closest_dist_sq = select_best_candidate( 435 | X, 436 | Xi, 437 | xi_squared_norms, 438 | candidate_ids, 439 | closest_dist_sq, 440 | high_precision=high_precision, 441 | ) 442 | pots[c] = current_pot 443 | 444 | if is_main_process(): 445 | # Permanently add best center candidate found in local tries 446 | centers[c] = torch.tensor( 447 | X[best_candidate], dtype=Xi.dtype, device=Xi.device 448 | ) 449 | indices[c] = best_candidate 450 | 451 | if ( 452 | checkpoint_period > 0 453 | and save_dir 454 | and ((c + 1) % checkpoint_period == 0 or c + 1 == n_clusters) 455 | ): 456 | logger.info("Saving checkpoint to " + saving_checkpoint_pattern % c) 457 | torch.save( 458 | { 459 | "centers": centers.cpu(), 460 | "indices": indices, 461 | "iter": c, 462 | "current_pot": current_pot.cpu(), 463 | "pots": pots.cpu(), 464 | "closest_dist_sq": closest_dist_sq.cpu(), 465 | "random_state": random_state, 466 | }, 467 | Path(save_dir, saving_checkpoint_pattern % c), 468 | pickle_protocol=4, 469 | ) 470 | _delete_old_checkpoint( 471 | save_dir, 472 | c, 473 | checkpoint_period, 474 | max_num_checkpoints, 475 | saving_checkpoint_pattern, 476 | ) 477 | 478 | indices = torch.tensor(indices, device=Xi.device) 479 | logger.info(f"Kmeans potential of kmeans++ initialization: {current_pot}") 480 | dist.broadcast(centers, src=0) 481 | dist.broadcast(indices, src=0) 482 | synchronize() 483 | 484 | return centers, indices 485 | 486 | 487 | def distributed_assign_clusters(X, Xi, centroids, chunk_size, verbose=False): 488 | """ 489 | The assignment sub-procedure of k-means. Given the centroids, assign data points to the index 490 | of the nearest centroids. 491 | 492 | Parameters: 493 | 494 | X: mem_map of an array of shape (n_samples, n_features) or the array itself 495 | Data. Though not used, still put X here to have a consistent function signature. 496 | Xi: torch.tensor 497 | Part of data that is managed by the current device. 498 | centroids: torch.tensor of shape (n_clusters x n_features) 499 | Centroids of clusters. 500 | chunk_size: int 501 | Number of data points that are assigned at once. 502 | Use a small chunk_size if n_clusters is large to avoid 503 | out-of-memory error, e.g. chunk_size <= 1e9/n_clusters. 504 | Default is -1, meaning all data points are assigned at once. 505 | verbose: bool 506 | Whether to print progress. 507 | 508 | Returns: 509 | 510 | The assignment of points in X to centroids, each process has a copy of the final result. 
511 | 512 | """ 513 | 514 | cluster_assignment = kmg.assign_clusters(centroids, Xi, "l2", chunk_size, verbose=verbose) 515 | cluster_assignment = gather_tensor(cluster_assignment, do_all_gather=True) 516 | return cluster_assignment 517 | 518 | 519 | def distributed_compute_centroids( 520 | X, Xi, n_clusters, centroids, cluster_assignment, high_precision=torch.float32 521 | ): 522 | """ 523 | Compute centroids of each cluster given its data points. 524 | 525 | Parameters: 526 | 527 | X: mem_map of an array of shape (n_samples, n_features) or the array itself 528 | Data. Though not used, still put X here to have a consistent function signature. 529 | Xi: torch.tensor 530 | Part of data that is managed by the current device. 531 | n_clusters: int 532 | Number of clusters. 533 | centroids: torch.tensor of shape (n_clusters, n_features) 534 | Previous centroids of the clusters. 535 | cluster_assignment: torch.tensor of shape (n_samples, ) 536 | Cluster id of data points. 537 | high_precision: torch.float32 or torch.float64; to save GPU memory, one 538 | can use float32 or float16 for data 'X', while 'high_precision' will be 539 | used in aggregation operations to avoid overflow. 540 | 541 | Returns: 542 | 543 | torch.tensor of shape (n_clusters, n_features), new centroids. 544 | 545 | """ 546 | part_indices = get_part_indices(X.shape[0], get_global_size()) 547 | rank = get_global_rank() 548 | cluster_assignment_i = cluster_assignment[ 549 | part_indices[rank] : part_indices[rank + 1] 550 | ] 551 | clusters_i = kmg.create_clusters_from_cluster_assignment( 552 | cluster_assignment_i, n_clusters 553 | ) 554 | 555 | in_cluster_sum = torch.zeros( 556 | (n_clusters, Xi.shape[1]), device=Xi.device, dtype=high_precision 557 | ) 558 | for _cluster_idx in range(n_clusters): 559 | if len(clusters_i[_cluster_idx]) > 0: 560 | in_cluster_sum[_cluster_idx] = torch.sum( 561 | Xi[clusters_i[_cluster_idx]].type(high_precision), dim=0 562 | ) 563 | dist.all_reduce(in_cluster_sum, op=dist.ReduceOp.SUM) 564 | 565 | cluster_size = collections.Counter(cluster_assignment) 566 | for _cluster_idx in range(n_clusters): 567 | if cluster_size[_cluster_idx] > 0: 568 | in_cluster_sum[_cluster_idx] = ( 569 | in_cluster_sum[_cluster_idx] / cluster_size[_cluster_idx] 570 | ) 571 | else: 572 | in_cluster_sum[_cluster_idx] = centroids[_cluster_idx] 573 | return in_cluster_sum.type(Xi.dtype) 574 | 575 | 576 | def distributed_kmeans( 577 | X, 578 | Xi, 579 | n_clusters, 580 | n_iters=10, 581 | chunk_size=1000, 582 | init_method="kmeans++", 583 | random_state=None, 584 | save_dir=None, 585 | save_kmpp_results=True, 586 | kmpp_checkpoint_period=-1, 587 | high_precision=torch.float32, 588 | checkpoint_period=5, 589 | checkpoint_pattern="centroids_checkpoint_%d.npy", 590 | ): 591 | """ 592 | Parameters: 593 | 594 | X: mem_map of an array of shape (n_samples, n_features) or the array itself 595 | Data. 596 | Xi: torch.tensor 597 | Part of data that is managed by the current device. 598 | n_clusters: int 599 | Number of clusters. 600 | chunk_size: int 601 | Number of data points that are assigned at once. 602 | Use a small chunk_size if n_clusters is large to avoid 603 | out-of-memory error, e.g. chunk_size <= 1e9/n_clusters. 604 | A value of -1 means all data points are assigned at once; the default here is 1000. 605 | init_method: str 606 | 'kmeans++' or 'random' 607 | save_kmpp_results: bool 608 | Whether to save kmeans++ init results. 609 | save_dir: str or Path 610 | Where to save results.
611 | 612 | Returns: 613 | 614 | centroids: torch.tensor of shape (n_clusters, n_features), the final centroids. 615 | cluster_assignment: array containing the cluster index of each point. 616 | 617 | """ 618 | assert save_dir or ( 619 | not save_kmpp_results 620 | ), "provide save_dir to save kmeans++ init results" 621 | 622 | if get_last_valid_checkpoint(save_dir, checkpoint_pattern): 623 | ckpt_path = get_last_valid_checkpoint(save_dir, checkpoint_pattern) 624 | logger.info(f"Loading checkpoint from {ckpt_path}") 625 | begin_iter = int(Path(ckpt_path).stem.split("_")[-1]) 626 | centroids = torch.tensor(np.load(ckpt_path), dtype=Xi.dtype, device=Xi.device) 627 | cluster_assignment = ( 628 | distributed_assign_clusters(X, Xi, centroids, chunk_size).cpu().numpy() 629 | ) 630 | else: 631 | if random_state is None: 632 | random_state = check_random_state(random_state) 633 | if init_method == "kmeans++": 634 | x_squared_norms = compute_data_squared_norms(X, Xi, do_all_gather=True) 635 | centroids, indices = distributed_kmeans_plusplus_init( 636 | X, 637 | Xi, 638 | n_clusters, 639 | x_squared_norms, 640 | save_dir=save_dir, 641 | high_precision=high_precision, 642 | checkpoint_period=kmpp_checkpoint_period, 643 | random_state=random_state, 644 | ) 645 | if save_kmpp_results: 646 | Path(save_dir).mkdir(parents=True, exist_ok=True) 647 | np.save(Path(save_dir, "kmpp_centers.npy"), centroids.cpu().numpy()) 648 | np.save(Path(save_dir, "kmpp_indices.npy"), indices.cpu().numpy()) 649 | 650 | elif init_method == "random": 651 | indices = torch.tensor( 652 | np.sort(random_state.choice(range(len(X)), n_clusters, replace=False)), 653 | device=Xi.device, 654 | ) 655 | dist.broadcast(indices, src=0) 656 | synchronize() 657 | centroids = torch.tensor( 658 | X[indices.cpu().numpy()], dtype=Xi.dtype, device=Xi.device 659 | ) 660 | else: 661 | raise ValueError(f'Initialization method "{init_method}" not supported!') 662 | 663 | cluster_assignment = ( 664 | distributed_assign_clusters(X, Xi, centroids, chunk_size).cpu().numpy() 665 | ) 666 | begin_iter = 0 667 | 668 | for _iter in tqdm( 669 | range(begin_iter, n_iters), 670 | desc="Distributed kmeans iteration", 671 | file=sys.stdout, 672 | bar_format="{l_bar}{bar}{r_bar}", 673 | ): 674 | centroids = distributed_compute_centroids( 675 | X, 676 | Xi, 677 | n_clusters, 678 | centroids, 679 | cluster_assignment, 680 | high_precision=high_precision, 681 | ) 682 | cluster_assignment = ( 683 | distributed_assign_clusters(X, Xi, centroids, chunk_size).cpu().numpy() 684 | ) 685 | if ( 686 | checkpoint_period > 0 687 | and (_iter + 1) % checkpoint_period == 0 688 | and is_main_process() 689 | ): 690 | logger.info("Saving checkpoint to " + checkpoint_pattern % (_iter + 1)) 691 | np.save( 692 | Path(save_dir, checkpoint_pattern % (_iter + 1)), 693 | centroids.cpu().numpy(), 694 | ) 695 | synchronize() 696 | return centroids, cluster_assignment 697 | 698 | 699 | def distributed_sort_cluster_by_distance( 700 | X, 701 | centroids, 702 | clusters, 703 | device="cuda", 704 | dtype=torch.float32, 705 | save_dir=None, 706 | checkpoint_period=-1, 707 | ): 708 | """ 709 | Parameters: 710 | X: memory map of an array of shape (n_samples, n_features) or the array itself 711 | Data. 712 | centroids: torch.tensor of shape (n_clusters x dim) 713 | Centroids of clusters. 714 | clusters: (n_clusters,) array or list 715 | clusters[i] contains indices of points in cluster i 716 | 717 | Returns: 718 | 719 | sorted_clusters: list 720 | sorted_clusters[i] contains indices of points in cluster i in increasing order 721 | of distance from the centroid.
722 | """ 723 | 724 | n_clusters, n_dim = centroids.shape[0], centroids.shape[1] 725 | part_indices = get_part_indices(n_clusters, get_global_size()) 726 | rank = get_global_rank() 727 | 728 | if checkpoint_period > 0 and Path( 729 | save_dir, 730 | f"sorted_clusters_checkpoint_{rank}.npy" 731 | ).exists(): 732 | cluster_data = np.load( 733 | Path( 734 | save_dir, 735 | f"sorted_clusters_checkpoint_{rank}.npy" 736 | ), 737 | allow_pickle=True 738 | ).item() 739 | sorted_clusters = cluster_data["sorted_clusters"] 740 | prev_item = cluster_data["prev_item"] 741 | else: 742 | sorted_clusters = [] 743 | prev_item = part_indices[rank] - 1 744 | 745 | for cluster_idx in tqdm( 746 | range(prev_item + 1, part_indices[rank + 1]), 747 | desc="Distributed sorting clusters by distance", 748 | file=sys.stdout, 749 | bar_format="{l_bar}{bar}{r_bar}", 750 | ): 751 | point_indices = np.sort(clusters[cluster_idx]) 752 | point_feats = torch.tensor(X[point_indices], device=device, dtype=dtype) 753 | _centroid = torch.tensor( 754 | centroids[cluster_idx], 755 | device=device, 756 | dtype=dtype 757 | ).reshape(1, n_dim) 758 | 759 | dist_to_centroid = torch.cdist(point_feats, _centroid).flatten() 760 | sorted_clusters.append( 761 | point_indices[torch.argsort(dist_to_centroid).cpu().numpy()] 762 | ) 763 | del point_feats 764 | 765 | if( 766 | ( 767 | checkpoint_period > 0 and 768 | cluster_idx % checkpoint_period == 0 769 | ) or 770 | cluster_idx == part_indices[rank + 1] - 1 771 | ): 772 | logger.info(f"Saving checkpoint to {save_dir}/sorted_clusters_checkpoint_{rank}.npy") 773 | np.save( 774 | Path(save_dir, f"sorted_clusters_checkpoint_{rank}.npy"), 775 | { 776 | "sorted_clusters": sorted_clusters, 777 | "prev_item": cluster_idx 778 | } 779 | ) 780 | synchronize() 781 | if is_main_process(): 782 | logger.info("Gathering clusters") 783 | sorted_clusters = [] 784 | for i in tqdm( 785 | range(get_global_size()), 786 | desc="Distributed gathering sorted clusters", 787 | file=sys.stdout, 788 | bar_format="{l_bar}{bar}{r_bar}", 789 | ): 790 | rank_data = np.load( 791 | Path(save_dir, f"sorted_clusters_checkpoint_{i}.npy"), 792 | allow_pickle=True, 793 | ).item() 794 | assert rank_data['prev_item'] == part_indices[i + 1] - 1 795 | sorted_clusters += rank_data["sorted_clusters"] 796 | sorted_clusters = np.array(sorted_clusters, dtype=object) 797 | np.save( 798 | Path(save_dir, "sorted_clusters.npy"), 799 | sorted_clusters 800 | ) 801 | for i in range(get_global_size()): 802 | Path(save_dir, f"sorted_clusters_checkpoint_{i}.npy").unlink(missing_ok=True) 803 | 804 | synchronize() 805 | sorted_clusters = np.load( 806 | Path(save_dir, "sorted_clusters.npy"), 807 | allow_pickle=True 808 | ) 809 | return sorted_clusters 810 | --------------------------------------------------------------------------------
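
Usage note (added for illustration, not a file in the repository): the sketch below shows one way the module above might be driven end to end, assuming the package is installed so that `src` is importable, that one process per GPU was launched with torchrun (e.g. `torchrun --nproc_per_node=2 example.py`), and that the helpers in `src/dist_comm.py` pick up the default torch.distributed process group. The data, paths, and hyperparameters are placeholders.

import os

import numpy as np
import torch
import torch.distributed as dist

from src.distributed_kmeans_gpu import load_data_to_worker, distributed_kmeans

# Assumed setup: one process per GPU launched with torchrun, NCCL backend.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

# Toy data pool; every rank must build the *same* full matrix, hence the fixed seed.
# In practice X is typically a np.memmap of precomputed embeddings.
rng = np.random.RandomState(0)
X = rng.randn(10_000, 256).astype(np.float32)

# Each rank keeps only its contiguous shard of X on its GPU.
Xi = load_data_to_worker(X, device="cuda", dtype=torch.float32)

# Output directory for kmeans++ results and centroid checkpoints.
os.makedirs("kmeans_output", exist_ok=True)

centroids, cluster_assignment = distributed_kmeans(
    X,
    Xi,
    n_clusters=100,
    n_iters=50,
    chunk_size=-1,              # assign all points at once; lower this when n_clusters is large
    init_method="kmeans++",
    save_dir="kmeans_output",   # required here because save_kmpp_results defaults to True
)

dist.destroy_process_group()

Here `centroids` is a tensor living on each GPU and `cluster_assignment` is a NumPy array holding the cluster index of each row of X, as returned by `distributed_kmeans`.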