├── .gitignore
├── LICENSE
├── README.md
├── annotate_data
│   ├── annotate.sh
│   ├── clusters.py
│   ├── domains.py
│   ├── edu.py
│   ├── embed.py
│   ├── fasttext.py
│   ├── perplexity.py
│   └── tokens.py
├── define_domains
│   ├── k-means-clustering
│   │   ├── CODE_OF_CONDUCT.md
│   │   ├── CONTRIBUTING.md
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── configs
│   │   │   ├── 1level_dcml.yaml
│   │   │   ├── 2levels_random_embeddings.yaml
│   │   │   └── 4levels_web_based_images.yaml
│   │   ├── exps
│   │   │   ├── dclm-1level-k118
│   │   │   │   └── level1
│   │   │   │       ├── centroids.npy
│   │   │   │       └── slurm_script.s
│   │   │   ├── dclm-1level-k13824
│   │   │   │   └── level1
│   │   │   │       ├── centroids.npy
│   │   │   │       └── slurm_script.s
│   │   │   ├── dclm-1level-k24
│   │   │   │   └── level1
│   │   │   │       ├── centroids.npy
│   │   │   │       └── slurm_script.s
│   │   │   ├── dclm-1level-k2822
│   │   │   │   └── level1
│   │   │   │       ├── centroids.npy
│   │   │   │       └── slurm_script.s
│   │   │   ├── dclm-1level-k576
│   │   │   │   └── level1
│   │   │   │       ├── centroids.npy
│   │   │   │       └── slurm_script.s
│   │   │   └── dclm-1level-k67723
│   │   │       └── level1
│   │   │           ├── centroids.npy
│   │   │           └── slurm_script.s
│   │   ├── images
│   │   │   ├── curation_pipeline.png
│   │   │   └── toy_example.png
│   │   ├── requirements.txt
│   │   ├── scripts
│   │   │   ├── __init__.py
│   │   │   ├── hierarchical_kmeans_launcher.py
│   │   │   ├── run_distributed_kmeans.py
│   │   │   ├── run_hierarchical_sampling.py
│   │   │   └── split_clusters.py
│   │   ├── setup.py
│   │   ├── src
│   │   │   ├── __init__.py
│   │   │   ├── clusters.py
│   │   │   ├── dist_comm.py
│   │   │   ├── distributed_kmeans_gpu.py
│   │   │   ├── hierarchical_kmeans_gpu.py
│   │   │   ├── hierarchical_sampling.py
│   │   │   ├── kmeans_gpu.py
│   │   │   └── utils.py
│   │   └── vis
│   │       ├── __init__.py
│   │       ├── generalized_kmeans_1d.py
│   │       └── notebook.ipynb
│   ├── prompt_classify.py
│   ├── prompt_classify.sh
│   ├── taxonomies
│   │   ├── formats.yaml
│   │   └── topics.yaml
│   ├── train_classifier.py
│   └── train_classifier.sh
├── domain_statistics.py
├── learn_mixtures
│   ├── average_mixtures.py
│   ├── combine_mixtures.py
│   └── training_mixes.py
├── select_training_data.py
└── website
│   ├── assets
│   ├── data
│ ├── examples
│ │ ├── format0.json
│ │ ├── format1.json
│ │ ├── format10.json
│ │ ├── format11.json
│ │ ├── format12.json
│ │ ├── format13.json
│ │ ├── format14.json
│ │ ├── format15.json
│ │ ├── format16.json
│ │ ├── format17.json
│ │ ├── format18.json
│ │ ├── format19.json
│ │ ├── format2.json
│ │ ├── format20.json
│ │ ├── format21.json
│ │ ├── format22.json
│ │ ├── format23.json
│ │ ├── format3.json
│ │ ├── format4.json
│ │ ├── format5.json
│ │ ├── format6.json
│ │ ├── format7.json
│ │ ├── format8.json
│ │ ├── format9.json
│ │ ├── topic0.json
│ │ ├── topic0_format0.json
│ │ ├── topic0_format1.json
│ │ ├── topic0_format10.json
│ │ ├── topic0_format11.json
│ │ ├── topic0_format12.json
│ │ ├── topic0_format13.json
│ │ ├── topic0_format14.json
│ │ ├── topic0_format15.json
│ │ ├── topic0_format16.json
│ │ ├── topic0_format17.json
│ │ ├── topic0_format18.json
│ │ ├── topic0_format19.json
│ │ ├── topic0_format2.json
│ │ ├── topic0_format20.json
│ │ ├── topic0_format21.json
│ │ ├── topic0_format22.json
│ │ ├── topic0_format23.json
│ │ ├── topic0_format3.json
│ │ ├── topic0_format4.json
│ │ ├── topic0_format5.json
│ │ ├── topic0_format6.json
│ │ ├── topic0_format7.json
│ │ ├── topic0_format8.json
│ │ ├── topic0_format9.json
│ │ ├── topic1.json
│ │ ├── topic10.json
│ │ ├── topic10_format0.json
│ │ ├── topic10_format1.json
│ │ ├── topic10_format10.json
│ │ ├── topic10_format11.json
│ │ ├── topic10_format12.json
│ │ ├── topic10_format13.json
│ │ ├── topic10_format14.json
│ │ ├── topic10_format15.json
│ │ ├── topic10_format16.json
│ │ ├── topic10_format17.json
│ │ ├── topic10_format18.json
│ │ ├── topic10_format19.json
│ │ ├── topic10_format2.json
│ │ ├── topic10_format20.json
│ │ ├── topic10_format21.json
│ │ ├── topic10_format22.json
│ │ ├── topic10_format23.json
│ │ ├── topic10_format3.json
│ │ ├── topic10_format4.json
│ │ ├── topic10_format5.json
│ │ ├── topic10_format6.json
│ │ ├── topic10_format7.json
│ │ ├── topic10_format8.json
│ │ ├── topic10_format9.json
│ │ ├── topic11.json
│ │ ├── topic11_format0.json
│ │ ├── topic11_format1.json
│ │ ├── topic11_format10.json
│ │ ├── topic11_format11.json
│ │ ├── topic11_format12.json
│ │ ├── topic11_format13.json
│ │ ├── topic11_format14.json
│ │ ├── topic11_format15.json
│ │ ├── topic11_format16.json
│ │ ├── topic11_format17.json
│ │ ├── topic11_format18.json
│ │ ├── topic11_format19.json
│ │ ├── topic11_format2.json
│ │ ├── topic11_format20.json
│ │ ├── topic11_format21.json
│ │ ├── topic11_format22.json
│ │ ├── topic11_format23.json
│ │ ├── topic11_format3.json
│ │ ├── topic11_format4.json
│ │ ├── topic11_format5.json
│ │ ├── topic11_format6.json
│ │ ├── topic11_format7.json
│ │ ├── topic11_format8.json
│ │ ├── topic11_format9.json
│ │ ├── topic12.json
│ │ ├── topic12_format0.json
│ │ ├── topic12_format1.json
│ │ ├── topic12_format10.json
│ │ ├── topic12_format11.json
│ │ ├── topic12_format12.json
│ │ ├── topic12_format13.json
│ │ ├── topic12_format14.json
│ │ ├── topic12_format15.json
│ │ ├── topic12_format16.json
│ │ ├── topic12_format17.json
│ │ ├── topic12_format18.json
│ │ ├── topic12_format19.json
│ │ ├── topic12_format2.json
│ │ ├── topic12_format20.json
│ │ ├── topic12_format21.json
│ │ ├── topic12_format22.json
│ │ ├── topic12_format23.json
│ │ ├── topic12_format3.json
│ │ ├── topic12_format4.json
│ │ ├── topic12_format5.json
│ │ ├── topic12_format6.json
│ │ ├── topic12_format7.json
│ │ ├── topic12_format8.json
│ │ ├── topic12_format9.json
│ │ ├── topic13.json
│ │ ├── topic13_format0.json
│ │ ├── topic13_format1.json
│ │ ├── topic13_format10.json
│ │ ├── topic13_format11.json
│ │ ├── topic13_format12.json
│ │ ├── topic13_format13.json
│ │ ├── topic13_format14.json
│ │ ├── topic13_format15.json
│ │ ├── topic13_format16.json
│ │ ├── topic13_format17.json
│ │ ├── topic13_format18.json
│ │ ├── topic13_format19.json
│ │ ├── topic13_format2.json
│ │ ├── topic13_format20.json
│ │ ├── topic13_format21.json
│ │ ├── topic13_format22.json
│ │ ├── topic13_format23.json
│ │ ├── topic13_format3.json
│ │ ├── topic13_format4.json
│ │ ├── topic13_format5.json
│ │ ├── topic13_format6.json
│ │ ├── topic13_format7.json
│ │ ├── topic13_format8.json
│ │ ├── topic13_format9.json
│ │ ├── topic14.json
│ │ ├── topic14_format0.json
│ │ ├── topic14_format1.json
│ │ ├── topic14_format10.json
│ │ ├── topic14_format11.json
│ │ ├── topic14_format12.json
│ │ ├── topic14_format13.json
│ │ ├── topic14_format14.json
│ │ ├── topic14_format15.json
│ │ ├── topic14_format16.json
│ │ ├── topic14_format17.json
│ │ ├── topic14_format18.json
│ │ ├── topic14_format19.json
│ │ ├── topic14_format2.json
│ │ ├── topic14_format20.json
│ │ ├── topic14_format21.json
│ │ ├── topic14_format22.json
│ │ ├── topic14_format23.json
│ │ ├── topic14_format3.json
│ │ ├── topic14_format4.json
│ │ ├── topic14_format5.json
│ │ ├── topic14_format6.json
│ │ ├── topic14_format7.json
│ │ ├── topic14_format8.json
│ │ ├── topic14_format9.json
│ │ ├── topic15.json
│ │ ├── topic15_format0.json
│ │ ├── topic15_format1.json
│ │ ├── topic15_format10.json
│ │ ├── topic15_format11.json
│ │ ├── topic15_format12.json
│ │ ├── topic15_format13.json
│ │ ├── topic15_format14.json
│ │ ├── topic15_format15.json
│ │ ├── topic15_format16.json
│ │ ├── topic15_format17.json
│ │ ├── topic15_format18.json
│ │ ├── topic15_format19.json
│ │ ├── topic15_format2.json
│ │ ├── topic15_format20.json
│ │ ├── topic15_format21.json
│ │ ├── topic15_format22.json
│ │ ├── topic15_format23.json
│ │ ├── topic15_format3.json
│ │ ├── topic15_format4.json
│ │ ├── topic15_format5.json
│ │ ├── topic15_format6.json
│ │ ├── topic15_format7.json
│ │ ├── topic15_format8.json
│ │ ├── topic15_format9.json
│ │ ├── topic16.json
│ │ ├── topic16_format0.json
│ │ ├── topic16_format1.json
│ │ ├── topic16_format10.json
│ │ ├── topic16_format11.json
│ │ ├── topic16_format12.json
│ │ ├── topic16_format13.json
│ │ ├── topic16_format14.json
│ │ ├── topic16_format15.json
│ │ ├── topic16_format16.json
│ │ ├── topic16_format17.json
│ │ ├── topic16_format18.json
│ │ ├── topic16_format19.json
│ │ ├── topic16_format2.json
│ │ ├── topic16_format20.json
│ │ ├── topic16_format21.json
│ │ ├── topic16_format22.json
│ │ ├── topic16_format23.json
│ │ ├── topic16_format3.json
│ │ ├── topic16_format4.json
│ │ ├── topic16_format5.json
│ │ ├── topic16_format6.json
│ │ ├── topic16_format7.json
│ │ ├── topic16_format8.json
│ │ ├── topic16_format9.json
│ │ ├── topic17.json
│ │ ├── topic17_format0.json
│ │ ├── topic17_format1.json
│ │ ├── topic17_format10.json
│ │ ├── topic17_format11.json
│ │ ├── topic17_format12.json
│ │ ├── topic17_format13.json
│ │ ├── topic17_format14.json
│ │ ├── topic17_format15.json
│ │ ├── topic17_format16.json
│ │ ├── topic17_format17.json
│ │ ├── topic17_format18.json
│ │ ├── topic17_format19.json
│ │ ├── topic17_format2.json
│ │ ├── topic17_format20.json
│ │ ├── topic17_format21.json
│ │ ├── topic17_format22.json
│ │ ├── topic17_format23.json
│ │ ├── topic17_format3.json
│ │ ├── topic17_format4.json
│ │ ├── topic17_format5.json
│ │ ├── topic17_format6.json
│ │ ├── topic17_format7.json
│ │ ├── topic17_format8.json
│ │ ├── topic17_format9.json
│ │ ├── topic18.json
│ │ ├── topic18_format0.json
│ │ ├── topic18_format1.json
│ │ ├── topic18_format10.json
│ │ ├── topic18_format11.json
│ │ ├── topic18_format12.json
│ │ ├── topic18_format13.json
│ │ ├── topic18_format14.json
│ │ ├── topic18_format15.json
│ │ ├── topic18_format16.json
│ │ ├── topic18_format17.json
│ │ ├── topic18_format18.json
│ │ ├── topic18_format19.json
│ │ ├── topic18_format2.json
│ │ ├── topic18_format20.json
│ │ ├── topic18_format21.json
│ │ ├── topic18_format22.json
│ │ ├── topic18_format23.json
│ │ ├── topic18_format3.json
│ │ ├── topic18_format4.json
│ │ ├── topic18_format5.json
│ │ ├── topic18_format6.json
│ │ ├── topic18_format7.json
│ │ ├── topic18_format8.json
│ │ ├── topic18_format9.json
│ │ ├── topic19.json
│ │ ├── topic19_format0.json
│ │ ├── topic19_format1.json
│ │ ├── topic19_format10.json
│ │ ├── topic19_format11.json
│ │ ├── topic19_format12.json
│ │ ├── topic19_format13.json
│ │ ├── topic19_format14.json
│ │ ├── topic19_format15.json
│ │ ├── topic19_format16.json
│ │ ├── topic19_format17.json
│ │ ├── topic19_format18.json
│ │ ├── topic19_format19.json
│ │ ├── topic19_format2.json
│ │ ├── topic19_format20.json
│ │ ├── topic19_format21.json
│ │ ├── topic19_format22.json
│ │ ├── topic19_format23.json
│ │ ├── topic19_format3.json
│ │ ├── topic19_format4.json
│ │ ├── topic19_format5.json
│ │ ├── topic19_format6.json
│ │ ├── topic19_format7.json
│ │ ├── topic19_format8.json
│ │ ├── topic19_format9.json
│ │ ├── topic1_format0.json
│ │ ├── topic1_format1.json
│ │ ├── topic1_format10.json
│ │ ├── topic1_format11.json
│ │ ├── topic1_format12.json
│ │ ├── topic1_format13.json
│ │ ├── topic1_format14.json
│ │ ├── topic1_format15.json
│ │ ├── topic1_format16.json
│ │ ├── topic1_format17.json
│ │ ├── topic1_format18.json
│ │ ├── topic1_format19.json
│ │ ├── topic1_format2.json
│ │ ├── topic1_format20.json
│ │ ├── topic1_format21.json
│ │ ├── topic1_format22.json
│ │ ├── topic1_format23.json
│ │ ├── topic1_format3.json
│ │ ├── topic1_format4.json
│ │ ├── topic1_format5.json
│ │ ├── topic1_format6.json
│ │ ├── topic1_format7.json
│ │ ├── topic1_format8.json
│ │ ├── topic1_format9.json
│ │ ├── topic2.json
│ │ ├── topic20.json
│ │ ├── topic20_format0.json
│ │ ├── topic20_format1.json
│ │ ├── topic20_format10.json
│ │ ├── topic20_format11.json
│ │ ├── topic20_format12.json
│ │ ├── topic20_format13.json
│ │ ├── topic20_format14.json
│ │ ├── topic20_format15.json
│ │ ├── topic20_format16.json
│ │ ├── topic20_format17.json
│ │ ├── topic20_format18.json
│ │ ├── topic20_format19.json
│ │ ├── topic20_format2.json
│ │ ├── topic20_format20.json
│ │ ├── topic20_format21.json
│ │ ├── topic20_format22.json
│ │ ├── topic20_format23.json
│ │ ├── topic20_format3.json
│ │ ├── topic20_format4.json
│ │ ├── topic20_format5.json
│ │ ├── topic20_format6.json
│ │ ├── topic20_format7.json
│ │ ├── topic20_format8.json
│ │ ├── topic20_format9.json
│ │ ├── topic21.json
│ │ ├── topic21_format0.json
│ │ ├── topic21_format1.json
│ │ ├── topic21_format10.json
│ │ ├── topic21_format11.json
│ │ ├── topic21_format12.json
│ │ ├── topic21_format13.json
│ │ ├── topic21_format14.json
│ │ ├── topic21_format15.json
│ │ ├── topic21_format16.json
│ │ ├── topic21_format17.json
│ │ ├── topic21_format18.json
│ │ ├── topic21_format19.json
│ │ ├── topic21_format2.json
│ │ ├── topic21_format20.json
│ │ ├── topic21_format21.json
│ │ ├── topic21_format22.json
│ │ ├── topic21_format23.json
│ │ ├── topic21_format3.json
│ │ ├── topic21_format4.json
│ │ ├── topic21_format5.json
│ │ ├── topic21_format6.json
│ │ ├── topic21_format7.json
│ │ ├── topic21_format8.json
│ │ ├── topic21_format9.json
│ │ ├── topic22.json
│ │ ├── topic22_format0.json
│ │ ├── topic22_format1.json
│ │ ├── topic22_format10.json
│ │ ├── topic22_format11.json
│ │ ├── topic22_format12.json
│ │ ├── topic22_format13.json
│ │ ├── topic22_format14.json
│ │ ├── topic22_format15.json
│ │ ├── topic22_format16.json
│ │ ├── topic22_format17.json
│ │ ├── topic22_format18.json
│ │ ├── topic22_format19.json
│ │ ├── topic22_format2.json
│ │ ├── topic22_format20.json
│ │ ├── topic22_format21.json
│ │ ├── topic22_format22.json
│ │ ├── topic22_format23.json
│ │ ├── topic22_format3.json
│ │ ├── topic22_format4.json
│ │ ├── topic22_format5.json
│ │ ├── topic22_format6.json
│ │ ├── topic22_format7.json
│ │ ├── topic22_format8.json
│ │ ├── topic22_format9.json
│ │ ├── topic23.json
│ │ ├── topic23_format0.json
│ │ ├── topic23_format1.json
│ │ ├── topic23_format10.json
│ │ ├── topic23_format11.json
│ │ ├── topic23_format12.json
│ │ ├── topic23_format13.json
│ │ ├── topic23_format14.json
│ │ ├── topic23_format15.json
│ │ ├── topic23_format16.json
│ │ ├── topic23_format17.json
│ │ ├── topic23_format18.json
│ │ ├── topic23_format19.json
│ │ ├── topic23_format2.json
│ │ ├── topic23_format20.json
│ │ ├── topic23_format21.json
│ │ ├── topic23_format22.json
│ │ ├── topic23_format23.json
│ │ ├── topic23_format3.json
│ │ ├── topic23_format4.json
│ │ ├── topic23_format5.json
│ │ ├── topic23_format6.json
│ │ ├── topic23_format7.json
│ │ ├── topic23_format8.json
│ │ ├── topic23_format9.json
│ │ ├── topic2_format0.json
│ │ ├── topic2_format1.json
│ │ ├── topic2_format10.json
│ │ ├── topic2_format11.json
│ │ ├── topic2_format12.json
│ │ ├── topic2_format13.json
│ │ ├── topic2_format14.json
│ │ ├── topic2_format15.json
│ │ ├── topic2_format16.json
│ │ ├── topic2_format17.json
│ │ ├── topic2_format18.json
│ │ ├── topic2_format19.json
│ │ ├── topic2_format2.json
│ │ ├── topic2_format20.json
│ │ ├── topic2_format21.json
│ │ ├── topic2_format22.json
│ │ ├── topic2_format23.json
│ │ ├── topic2_format3.json
│ │ ├── topic2_format4.json
│ │ ├── topic2_format5.json
│ │ ├── topic2_format6.json
│ │ ├── topic2_format7.json
│ │ ├── topic2_format8.json
│ │ ├── topic2_format9.json
│ │ ├── topic3.json
│ │ ├── topic3_format0.json
│ │ ├── topic3_format1.json
│ │ ├── topic3_format10.json
│ │ ├── topic3_format11.json
│ │ ├── topic3_format12.json
│ │ ├── topic3_format13.json
│ │ ├── topic3_format14.json
│ │ ├── topic3_format15.json
│ │ ├── topic3_format16.json
│ │ ├── topic3_format17.json
│ │ ├── topic3_format18.json
│ │ ├── topic3_format19.json
│ │ ├── topic3_format2.json
│ │ ├── topic3_format20.json
│ │ ├── topic3_format21.json
│ │ ├── topic3_format22.json
│ │ ├── topic3_format23.json
│ │ ├── topic3_format3.json
│ │ ├── topic3_format4.json
│ │ ├── topic3_format5.json
│ │ ├── topic3_format6.json
│ │ ├── topic3_format7.json
│ │ ├── topic3_format8.json
│ │ ├── topic3_format9.json
│ │ ├── topic4.json
│ │ ├── topic4_format0.json
│ │ ├── topic4_format1.json
│ │ ├── topic4_format10.json
│ │ ├── topic4_format11.json
│ │ ├── topic4_format12.json
│ │ ├── topic4_format13.json
│ │ ├── topic4_format14.json
│ │ ├── topic4_format15.json
│ │ ├── topic4_format16.json
│ │ ├── topic4_format17.json
│ │ ├── topic4_format18.json
│ │ ├── topic4_format19.json
│ │ ├── topic4_format2.json
│ │ ├── topic4_format20.json
│ │ ├── topic4_format21.json
│ │ ├── topic4_format22.json
│ │ ├── topic4_format23.json
│ │ ├── topic4_format3.json
│ │ ├── topic4_format4.json
│ │ ├── topic4_format5.json
│ │ ├── topic4_format6.json
│ │ ├── topic4_format7.json
│ │ ├── topic4_format8.json
│ │ ├── topic4_format9.json
│ │ ├── topic5.json
│ │ ├── topic5_format0.json
│ │ ├── topic5_format1.json
│ │ ├── topic5_format10.json
│ │ ├── topic5_format11.json
│ │ ├── topic5_format12.json
│ │ ├── topic5_format13.json
│ │ ├── topic5_format14.json
│ │ ├── topic5_format15.json
│ │ ├── topic5_format16.json
│ │ ├── topic5_format17.json
│ │ ├── topic5_format18.json
│ │ ├── topic5_format19.json
│ │ ├── topic5_format2.json
│ │ ├── topic5_format20.json
│ │ ├── topic5_format21.json
│ │ ├── topic5_format22.json
│ │ ├── topic5_format23.json
│ │ ├── topic5_format3.json
│ │ ├── topic5_format4.json
│ │ ├── topic5_format5.json
│ │ ├── topic5_format6.json
│ │ ├── topic5_format7.json
│ │ ├── topic5_format8.json
│ │ ├── topic5_format9.json
│ │ ├── topic6.json
│ │ ├── topic6_format0.json
│ │ ├── topic6_format1.json
│ │ ├── topic6_format10.json
│ │ ├── topic6_format11.json
│ │ ├── topic6_format12.json
│ │ ├── topic6_format13.json
│ │ ├── topic6_format14.json
│ │ ├── topic6_format15.json
│ │ ├── topic6_format16.json
│ │ ├── topic6_format17.json
│ │ ├── topic6_format18.json
│ │ ├── topic6_format19.json
│ │ ├── topic6_format2.json
│ │ ├── topic6_format20.json
│ │ ├── topic6_format21.json
│ │ ├── topic6_format22.json
│ │ ├── topic6_format23.json
│ │ ├── topic6_format3.json
│ │ ├── topic6_format4.json
│ │ ├── topic6_format5.json
│ │ ├── topic6_format6.json
│ │ ├── topic6_format7.json
│ │ ├── topic6_format8.json
│ │ ├── topic6_format9.json
│ │ ├── topic7.json
│ │ ├── topic7_format0.json
│ │ ├── topic7_format1.json
│ │ ├── topic7_format10.json
│ │ ├── topic7_format11.json
│ │ ├── topic7_format12.json
│ │ ├── topic7_format13.json
│ │ ├── topic7_format14.json
│ │ ├── topic7_format15.json
│ │ ├── topic7_format16.json
│ │ ├── topic7_format17.json
│ │ ├── topic7_format18.json
│ │ ├── topic7_format19.json
│ │ ├── topic7_format2.json
│ │ ├── topic7_format20.json
│ │ ├── topic7_format21.json
│ │ ├── topic7_format22.json
│ │ ├── topic7_format23.json
│ │ ├── topic7_format3.json
│ │ ├── topic7_format4.json
│ │ ├── topic7_format5.json
│ │ ├── topic7_format6.json
│ │ ├── topic7_format7.json
│ │ ├── topic7_format8.json
│ │ ├── topic7_format9.json
│ │ ├── topic8.json
│ │ ├── topic8_format0.json
│ │ ├── topic8_format1.json
│ │ ├── topic8_format10.json
│ │ ├── topic8_format11.json
│ │ ├── topic8_format12.json
│ │ ├── topic8_format13.json
│ │ ├── topic8_format14.json
│ │ ├── topic8_format15.json
│ │ ├── topic8_format16.json
│ │ ├── topic8_format17.json
│ │ ├── topic8_format18.json
│ │ ├── topic8_format19.json
│ │ ├── topic8_format2.json
│ │ ├── topic8_format20.json
│ │ ├── topic8_format21.json
│ │ ├── topic8_format22.json
│ │ ├── topic8_format23.json
│ │ ├── topic8_format3.json
│ │ ├── topic8_format4.json
│ │ ├── topic8_format5.json
│ │ ├── topic8_format6.json
│ │ ├── topic8_format7.json
│ │ ├── topic8_format8.json
│ │ ├── topic8_format9.json
│ │ ├── topic9.json
│ │ ├── topic9_format0.json
│ │ ├── topic9_format1.json
│ │ ├── topic9_format10.json
│ │ ├── topic9_format11.json
│ │ ├── topic9_format12.json
│ │ ├── topic9_format13.json
│ │ ├── topic9_format14.json
│ │ ├── topic9_format15.json
│ │ ├── topic9_format16.json
│ │ ├── topic9_format17.json
│ │ ├── topic9_format18.json
│ │ ├── topic9_format19.json
│ │ ├── topic9_format2.json
│ │ ├── topic9_format20.json
│ │ ├── topic9_format21.json
│ │ ├── topic9_format22.json
│ │ ├── topic9_format23.json
│ │ ├── topic9_format3.json
│ │ ├── topic9_format4.json
│ │ ├── topic9_format5.json
│ │ ├── topic9_format6.json
│ │ ├── topic9_format7.json
│ │ ├── topic9_format8.json
│ │ └── topic9_format9.json
│ ├── formats.json
│ ├── statistics.json
│ └── topics.json
│   ├── images
│ ├── ai2_logo.png
│ ├── icon.png
│ ├── mixtures_implicit.png
│ ├── mixtures_regmix.png
│ ├── pli_logo.svg
│ ├── princeton_logo.png
│ ├── results_main.png
│ ├── treemaps.png
│ ├── uc_berkeley_logo.png
│ └── uw_logo.png
│   ├── js
│   │   └── treemaps.js
│   └── index.html
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .DS_Store
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Organize the Web: Constructing Domains Enhances Pre-Training Data Curation
2 |
3 | [[Paper](https://arxiv.org/pdf/2502.10341.pdf)] [[Website](https://weborganizer.allen.ai)] [[Hugging Face](https://huggingface.co/WebOrganizer)]
4 |
5 |
6 |
7 | *Interactively explore these domains and examples of web pages they contain at [https://weborganizer.allen.ai](https://weborganizer.allen.ai)*
8 |
9 |
10 |
11 | ## Resources
12 |
13 | #### Domain Classifiers
14 | All our domain classifiers are available on the Hugging Face Hub. Our default domain classifiers use both the URL and the website content to make predictions. We also provide two additional models that use only the website content and can therefore be applied to a wider variety of documents. A usage sketch follows the list below.
15 | 1. __Topic__: [WebOrganizer/TopicClassifier](https://huggingface.co/WebOrganizer/TopicClassifier) ([-NoURL version](https://huggingface.co/WebOrganizer/TopicClassifier-NoURL))
16 | 2. __Format__: [WebOrganizer/FormatClassifier](https://huggingface.co/WebOrganizer/FormatClassifier) ([-NoURL version](https://huggingface.co/WebOrganizer/FormatClassifier-NoURL))
17 |
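As a rough sketch of how the default classifiers can be applied, mirroring the preprocessing in `annotate_data/domains.py` (the model cards on Hugging Face are the authoritative reference; the example URL and text below are made up):

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "WebOrganizer/TopicClassifier"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)

# Same input template as annotate_data/domains.py: the URL, a blank line, then the page text
document = "http://www.example.com/recipes/pancakes\n\nMix the flour, eggs and milk, then fry the batter..."
inputs = tokenizer([document], padding=True, truncation=True, max_length=8192, return_tensors="pt")
topic_index = model(**inputs).logits.argmax(dim=-1).item()  # index into the topic taxonomy (see define_domains/taxonomies/topics.yaml)
```
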
18 | These domain classifiers are trained on the following datasets:
19 | 1. In a first stage, 1M web pages classified by __Llama-3.1-8B__, available on the Hugging Face Hub:
20 | * [WebOrganizer/TopicAnnotations-Llama-3.1-8B](https://huggingface.co/datasets/WebOrganizer/TopicAnnotations-Llama-3.1-8B)
21 | * [WebOrganizer/FormatAnnotations-Llama-3.1-8B](https://huggingface.co/datasets/WebOrganizer/FormatAnnotations-Llama-3.1-8B)
22 | 2. In a second stage, 100K pages classified by __Llama-3.1-405B-FP8__, available on the Hugging Face Hub:
23 | * [WebOrganizer/TopicAnnotations-Llama-3.1-405B-FP8](https://huggingface.co/datasets/WebOrganizer/TopicAnnotations-Llama-3.1-405B-FP8)
24 | * [WebOrganizer/FormatAnnotations-Llama-3.1-405B-FP8](https://huggingface.co/datasets/WebOrganizer/FormatAnnotations-Llama-3.1-405B-FP8)
25 |
26 | The __topic and format definitions__ and instructions for prompting large language models to classify documents are available in `define_domains/taxonomies`. The script for prompting models is `define_domains/prompt_classify.sh`. The 1M web pages were randomly sampled from DCLM RefinedWeb.
27 |
28 |
29 | #### Corpus Annotations
30 | We pre-process the `1b-1x` pool from DataComp-LM using [RefinedWeb filters](https://github.com/mlfoundations/dclm/blob/main/baselines/baselines_configs/dclm_baseline_refinedweb.yaml) and [BFF deduplication](https://github.com/mlfoundations/dclm/tree/main/dedup/bff).
31 | The resulting 200B token corpus, together with the annotations, is available at [WebOrganizer/Corpus-200B](https://huggingface.co/datasets/WebOrganizer/Corpus-200B).
32 |
33 | __Download the dataset by cloning the repository with Git LFS instead of HuggingFace's `load_dataset()`.__
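
With Git LFS installed, this amounts to something like:

```bash
git lfs install
git clone https://huggingface.co/datasets/WebOrganizer/Corpus-200B
```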
34 |
35 | The dataset has the following folder structure:
36 | ```bash
37 | Corpus-200B/
38 | documents/ # Pre-processed web documents
39 | - CC_shard_00000000_processed.jsonl.zst
40 | - CC_shard_00000001_processed.jsonl.zst
41 | - ...
42 | tokens/ # number of tokens per document (GPT-NeoX tokenizer)
43 | - CC_shard_00000000_processed.npy
44 | - CC_shard_00000001_processed.npy
45 | - ...
46 | scores_dclm-fasttext/ # DCLM-fasttext score
47 | - CC_shard_00000000_processed.npy
48 | - ...
49 | scores_fineweb-edu/ # FineWeb-Edu score
50 | - CC_shard_00000000_processed.npy
51 | - ...
52 | scores_fineweb-edu__rounded/ # Rounded FineWeb-Edu score
53 | - CC_shard_00000000_processed__rounded.npy
54 | - ...
55 | domains_topics/ # TopicClassifier annotations
56 | - CC_shard_00000000_processed__choice.npy # index of top choice
57 | - ...
58 | domains_topics__logits/
59 | - CC_shard_00000000_processed__logits.npy # logits for each topic
60 | - ...
61 | domains_formats/ # FormatClassifier annotations
62 | - CC_shard_00000000_processed__choice.npy # index of top choice
63 | - ...
64 | domains_formats__logits/
65 | - CC_shard_00000000_processed__logits.npy # logits for each format
66 | - ...
67 | domains_clusters-k24/ # K-means clusters
68 | - CC_shard_00000000_processed.npy # cluster assignment for each document
69 | - ...
70 | ```
71 | We also include statistics about the presence and co-occurrence of domains in the `domain_statistics/` folder, computed with the `domain_statistics.py` script.
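
As an illustrative sketch of how the annotations line up with a documents shard (each `.npy` file holds one entry per document, in the same order as the corresponding `.jsonl.zst` shard; `numpy` and `zstandard` are assumed to be installed):

```python
import json
import numpy as np
import zstandard as zstd

shard = "CC_shard_00000000_processed"
topics = np.load(f"Corpus-200B/domains_topics/{shard}__choice.npy")    # topic index per document
quality = np.load(f"Corpus-200B/scores_dclm-fasttext/{shard}.npy")     # DCLM-fastText score per document

with open(f"Corpus-200B/documents/{shard}.jsonl.zst", "rb") as f:
    docs = zstd.ZstdDecompressor().stream_reader(f).read().decode("utf-8").splitlines()

first_doc = json.loads(docs[0])
print(first_doc["url"], topics[0], quality[0])
```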
72 |
73 | ## Installation
74 | Different steps in this repository require different dependencies:
75 |
76 | * __Data pre-processing__: *coming soon*
77 | ```bash
78 | # install datatools and gte...
79 | ```
80 |
81 | * __K-means clustering__: The code in `define_domains/k-means-clustering` is a fork of [facebookresearch/ssl-data-curation](https://github.com/facebookresearch/ssl-data-curation/tree/main). Please read the README in this directory for installation instructions and to see our modifications.
82 |
83 | * __DataComp-LM tokenization and training__: Please refer to the [DataComp-LM repository](https://github.com/mlfoundations/dclm) for instructions on how to tokenize and train models for DataComp-LM.
84 |
85 |
86 | ## Training New Domain Classifiers
87 | You can define a new taxonomy config in `define_domains/taxonomies` and then annotate web pages with it by prompting an LLM via the `define_domains/prompt_classify.sh` script.
88 | To distill the Llama annotations into a new domain classifier, use the `define_domains/train_classifier.sh` script and pass the new training dataset as a script option. For two-stage training, simply run the training script twice with different training datasets, and initialize the second stage from the model checkpoint of the first stage.
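
For example, a two-stage run could look roughly like the following; the flag names are purely illustrative (check `train_classifier.sh` for the actual options):

```bash
# Stage 1: distill the 1M Llama-3.1-8B annotations (illustrative flags)
bash define_domains/train_classifier.sh \
    --train_dataset WebOrganizer/TopicAnnotations-Llama-3.1-8B \
    --output_dir checkpoints/topic-stage1

# Stage 2: continue from the stage-1 checkpoint on the 100K Llama-3.1-405B-FP8 annotations
bash define_domains/train_classifier.sh \
    --train_dataset WebOrganizer/TopicAnnotations-Llama-3.1-405B-FP8 \
    --model_name_or_path checkpoints/topic-stage1 \
    --output_dir checkpoints/topic-stage2
```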
89 |
90 | ## Annotating Data
91 | The script `annotate_data/annotate.sh` performs large-scale data annotation using a Slurm job array: it iterates through the document shards in the `Corpus-200B` folder and annotates each document with quality scores and domains, which are stored as numpy arrays in separate annotation folders.
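
A minimal launch could look like the following, assuming the corpus layout described above and that the job is submitted from inside `annotate_data/` (the script invokes `tokens.py`, `domains.py`, etc. by relative path):

```bash
cd annotate_data
mkdir -p slurm   # Slurm does not create the --output directory for you
DATA_ROOT=/path/to/Corpus-200B DOCUMENTS_DIR=documents sbatch annotate.sh
```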
92 |
93 | ## Predict a Training Distribution with RegMix
94 | *Coming soon...*
95 |
96 | ## Selecting Training Data for Language Models
97 | `select_training_data.py` uses the folder structure of the `Corpus-200B` dataset, which is also used by the annotation scripts, to select training data for language models.
98 |
99 | Example usage:
100 | ```bash
101 | python select_training_data.py \
102 | --input_base "datasets/Corpus-200B" \
103 | --output_base "datasets/selected/Baseline-30B" \
104 | --num_tokens 30000000000 \
105 | --do_sample \
106 | --num_proc 16
107 | ```
108 |
109 | It supports various options for quality filtering and domain mixing.
110 | The script first writes selection indices for each document shard in the `Corpus-200B` folder and then uses multiple workers to write the selected data in parallel.
111 | You can use the `domain_statistics.py` script to summarize the domain distribution of a dataset and pass the resulting statistics to `--ref_distribution` to select training data that matches this distribution, as sketched below.
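
For instance, the two steps could be chained roughly as follows; the `domain_statistics.py` flags are illustrative (see the script for its actual interface), while the `select_training_data.py` flags mirror the example above:

```bash
# Summarize the topic distribution of the corpus (illustrative flags)
python domain_statistics.py \
    --input_base "datasets/Corpus-200B" \
    --output topic_statistics.json

# Select 30B tokens whose topic distribution matches the reference statistics
python select_training_data.py \
    --input_base "datasets/Corpus-200B" \
    --output_base "datasets/selected/Topic-Matched-30B" \
    --num_tokens 30000000000 \
    --ref_distribution topic_statistics.json \
    --num_proc 16
```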
112 |
113 | The folder of selected documents can then be used with the tokenization and training scripts from the [DCLM repository](https://github.com/mlfoundations/dclm) to train a new language model.
114 |
115 |
116 | ## Citation
117 | ```bibtex
118 | @article{wettig2025organize,
119 | title={Organize the Web: Constructing Domains Enhances Pre-Training Data Curation},
120 | author={Alexander Wettig and Kyle Lo and Sewon Min and Hannaneh Hajishirzi and Danqi Chen and Luca Soldaini},
121 | journal={arXiv preprint arXiv:2502.10341},
122 | year={2025}
123 | }
124 | ```
125 |
--------------------------------------------------------------------------------
/annotate_data/annotate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH -J annotate
3 | #SBATCH -N 1 -c 9 --gres=gpu:1 --mem=72G
4 | #SBATCH --output=slurm/%x-%A_%a.out
5 | #SBATCH -t 0-24
6 | #SBATCH -a 0-31
7 |
8 | # Point to the root data directory
9 | # Each job will iterate through a subset of files in $DATA_ROOT/$DOCUMENTS_DIR
10 | # and annotate them with quality scores and domains
11 | data_root=${DATA_ROOT:-}
12 | documents_dir=${DOCUMENTS_DIR:-"documents"}
13 |
14 | # Use WORKER/NUM_WORKERS env variables, slurm array variables or default to 0/1
15 | num_workers=${NUM_WORKERS:-${SLURM_ARRAY_TASK_COUNT:-1}}
16 | worker=${WORKER:-${SLURM_ARRAY_TASK_ID:-0}}
17 |
18 | files=( $(ls -1 "$data_root/$documents_dir" | grep "\.jsonl\.zst$") )
19 | num_files=${#files[@]}
20 |
21 | # Iterate through the files assigned to this worker
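# (worker i handles file indices i, i+num_workers, i+2*num_workers, ...)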
22 | for id in $(jq -n "range($worker; $num_files; $num_workers)"); do
23 | file=${files[$id]}
24 | output_file=${file%%.*}
25 |
26 | # Tokenize data and compute length
27 | python tokens.py \
28 | $data_root/$documents_dir/$file \
29 | $data_root/tokens/$output_file
30 |
31 | # Compute DCLM-fasttext scores
32 | python fasttext.py \
33 | $data_root/$documents_dir/$file \
34 | $data_root/scores_dclm-fasttext/$output_file \
35 |         --model_path  # NOTE: supply the path to the DCLM fastText model here (see the --model_path default in fasttext.py)
36 |
37 |
38 | # ^ The two scripts above do not make use of a GPU and should be run separately
39 | # Everything below is accelerated a lot with GPUs
40 |
41 | # Compute FineWeb-Edu scores
42 | python edu.py \
43 | $data_root/$documents_dir/$file \
44 | $data_root/scores_fineweb-edu/$output_file \
45 | --model_name HuggingFaceTB/fineweb-edu-classifier
46 |
47 | # Compute Topic and Format domains
48 | python domains.py \
49 | $data_root/$documents_dir/$file \
50 | $data_root/domains_topics/$output_file \
51 | --model_name WebOrganizer/WebOrganizer-TopicClassifier
52 | python domains.py \
53 | $data_root/$documents_dir/$file \
54 | $data_root/domains_formats/$output_file \
55 | --model_name WebOrganizer/WebOrganizer-FormatClassifier
56 |
57 | # # For annotating kmeans clusters
58 | # python embed.py \
59 | # $data_root/$documents_dir/$file \
60 | # $data_root/embeds/$output_file
61 | # python clusters.py \
62 | # $data_root/embeds/${output_file}.npy \
63 | # $data_root/domains_clusters-k24/$output_file \
64 | # --clustering_folder ../define_domains/k-means-clustering/exps/dclm-k24
65 | done
66 |
--------------------------------------------------------------------------------
/annotate_data/clusters.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import numpy as np
3 | import torch
4 | from dataclasses import dataclass
5 | from functools import partial
6 | from pathlib import Path
7 | from typing import Optional
8 |
9 | from datatools.process import process, ProcessOptions
10 | from datatools.load import load, LoadOptions
11 | from simple_parsing import ArgumentParser
12 |
13 |
14 | @dataclass
15 | class ScriptOptions:
16 | clustering_folder: Path
17 | batch_size: int = 4192
18 | device: str = "cpu"
19 |
20 |
21 | def assign_clusters(dataset, indices, process_id, options):
22 | centroids_paths = sorted(options.clustering_folder.glob("level*/centroids.npy"))
23 |
24 | centroids_by_level = [torch.tensor(np.load(centroids_path)).to(options.device) for centroids_path in centroids_paths]
25 |
26 | for i in tqdm(range(0, len(dataset), options.batch_size), disable = process_id != 0):
27 | batch = [dataset[j] for j in range(i, min(i + options.batch_size, len(dataset)))]
28 | embeddings = torch.tensor(np.stack(batch)).to(options.device)
29 |
30 | assignments_by_level = []
31 |
32 | for centroids in centroids_by_level:
33 | # Compute distances
34 | distances = torch.cdist(embeddings, centroids)
35 |
36 | # Get cluster assignments
37 | cluster_ids = torch.argmin(distances, dim=1)
38 | assignments_by_level.append(cluster_ids.cpu().numpy())
39 |
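            # For the next level, represent each point by its assigned centroid (hierarchical assignment)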
40 | embeddings = centroids[cluster_ids]
41 |
42 | for cluster_id_by_level in zip(*assignments_by_level):
43 | if len(cluster_id_by_level) == 1:
44 | yield {
45 | "": cluster_id_by_level[0]
46 | }
47 | else:
48 | yield {
49 | f"level{i+1}": cluster_id
50 | for i, cluster_id in enumerate(cluster_id_by_level)
51 | }
52 |
53 |
54 | if __name__ == "__main__":
55 | parser = ArgumentParser()
56 |
57 | parser.add_argument("inputs", type=Path, nargs="+", help="Input embeds paths")
58 | parser.add_argument("output", type=Path, help="Output dataset path")
59 |
60 | parser.add_arguments(ScriptOptions, dest="script_options")
61 | parser.add_arguments(LoadOptions, dest="load_options")
62 | parser.add_arguments(ProcessOptions, dest="process_options")
63 |
64 | args = parser.parse_args()
65 | args.process_options.ndarray = True
66 |
67 | print("Arguments:", args)
68 | dataset = load(*args.inputs, options=args.load_options)
69 | N = len(dataset)
70 | print(f"Loaded dataset with {N} samples")
71 |
72 | process(dataset, partial(assign_clusters, options=args.script_options), args.output, args.process_options)
--------------------------------------------------------------------------------
/annotate_data/domains.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from functools import partial
3 | from pathlib import Path
4 | from tqdm import tqdm
5 |
6 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import numpy
10 |
11 | from datatools.process import process, ProcessOptions
12 | from datatools.load import load, LoadOptions
13 | from simple_parsing import ArgumentParser, field
14 | from typing import Dict, Any
15 |
16 | @dataclass
17 | class EmbedOptions:
18 | model_name: str
19 | batch_size: int = 128
20 | num_dataloader_workers: int = 8
21 | max_length: int = 8192
22 | input_template: str = """{url}
23 |
24 | {text}"""
25 |
26 |
27 | class DataCollator:
28 | def __init__(self, tokenizer, options):
29 | self.tokenizer = tokenizer
30 | self.options = options
31 |
32 | @torch.no_grad()
33 | def __call__(self, features):
34 | documents = [self.options.input_template.format(**f) for f in features]
35 | return self.tokenizer(documents, padding=True, truncation=True, return_tensors='pt', max_length=self.options.max_length)
36 |
37 |
38 | def load_model_and_tokenizer(options):
39 | tokenizer = AutoTokenizer.from_pretrained(options.model_name)
40 | model = AutoModelForSequenceClassification.from_pretrained(options.model_name, trust_remote_code=True)
41 | return model, tokenizer
42 |
43 |
44 |
45 | @torch.inference_mode()
46 | def predict_fn(subset, indices, process_id, options):
47 |
48 | model, tokenizer = load_model_and_tokenizer(options)
49 | model.to(torch.bfloat16)
50 | model.cuda()
51 | model.eval()
52 |
53 | data_loader = DataLoader(subset,
54 | batch_size=options.batch_size,
55 | collate_fn=DataCollator(tokenizer, options),
56 | num_workers=options.num_dataloader_workers,
57 | prefetch_factor=4,
58 | pin_memory=True,
59 | shuffle=False)
60 |
61 | for batch in tqdm(data_loader, disable=(process_id != 0)):
62 | for key in batch:
63 | batch[key] = batch[key].cuda()
64 |
65 | model_output = model(**batch)
66 |
67 | logits = model_output.logits.float().cpu()
68 | choices = logits.argmax(axis=-1).cpu()
69 |
70 | for seq_logits, seq_choice in zip(logits, choices):
71 | yield {"logits": seq_logits.numpy(), "choice": seq_choice.item()}
72 |
73 |
74 | if __name__ == "__main__":
75 | parser = ArgumentParser()
76 |
77 | parser.add_argument("inputs", type=Path, nargs="+", help="Input dataset paths")
78 | parser.add_argument("output", type=Path, help="Output dataset path")
79 |
80 | parser.add_arguments(EmbedOptions, dest="embed_options")
81 | parser.add_arguments(LoadOptions, dest="load_options")
82 | parser.add_arguments(ProcessOptions, dest="process_options")
83 |
84 | args = parser.parse_args()
85 | args.process_options.ndarray = True
86 |
87 | print("Arguments:", args)
88 | dataset = load(*args.inputs, options=args.load_options)
89 | N = len(dataset)
90 | print(f"Loaded dataset with {N} samples")
91 |
92 | process(
93 | dataset,
94 | partial(predict_fn, options=args.embed_options),
95 | args.output,
96 | args.process_options
97 | )
98 |
--------------------------------------------------------------------------------
/annotate_data/edu.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from functools import partial
3 | from pathlib import Path
4 | from tqdm import tqdm
5 |
6 | from transformers import AutoModelForSequenceClassification, AutoTokenizer
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import numpy
10 |
11 | from datatools.process import process, ProcessOptions
12 | from datatools.load import load, LoadOptions
13 | from simple_parsing import ArgumentParser, field
14 | from typing import Dict, Any
15 |
16 | @dataclass
17 | class EmbedOptions:
18 | model_name: str = "HuggingFaceTB/fineweb-edu-classifier"
19 | batch_size: int = 128
20 | num_dataloader_workers: int = 8
21 | max_length: int = 512
22 | input_template: str = "{text}"
23 |
24 |
25 | class DataCollator:
26 | def __init__(self, tokenizer, options):
27 | self.tokenizer = tokenizer
28 | self.options = options
29 |
30 | @torch.no_grad()
31 | def __call__(self, features):
32 | documents = [self.options.input_template.format(**f) for f in features]
33 | return self.tokenizer(documents, padding=True, truncation=True, return_tensors='pt', max_length=self.options.max_length)
34 |
35 |
36 | def load_model_and_tokenizer(options):
37 | tokenizer = AutoTokenizer.from_pretrained(options.model_name)
38 | model = AutoModelForSequenceClassification.from_pretrained(options.model_name)
39 | return model, tokenizer
40 |
41 |
42 |
43 | @torch.inference_mode()
44 | def predict_fn(subset, indices, process_id, options):
45 |
46 | model, tokenizer = load_model_and_tokenizer(options)
47 | model.to(torch.bfloat16)
48 | model.cuda()
49 | model.eval()
50 |
51 | data_loader = DataLoader(subset,
52 | batch_size=options.batch_size,
53 | collate_fn=DataCollator(tokenizer, options),
54 | num_workers=options.num_dataloader_workers,
55 | prefetch_factor=4,
56 | pin_memory=True,
57 | shuffle=False)
58 |
59 | for batch in tqdm(data_loader, disable=(process_id != 0)):
60 | for key in batch:
61 | batch[key] = batch[key].cuda()
62 |
63 | model_output = model(**batch)
64 |
65 | scores = model_output.logits.squeeze(-1).float().cpu().detach().numpy()
66 |
67 | for seq_score in scores:
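            # Clamp to the FineWeb-Edu 0-5 scale and round to the nearest integer for the "rounded" field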
68 | yield {
69 | "": seq_score.item(),
70 | "rounded": int(round(max(0, min(seq_score.item(), 5))))
71 | }
72 |
73 |
74 | if __name__ == "__main__":
75 | parser = ArgumentParser()
76 |
77 | parser.add_argument("inputs", type=Path, nargs="+", help="Input dataset paths")
78 | parser.add_argument("output", type=Path, help="Output dataset path")
79 |
80 | parser.add_arguments(EmbedOptions, dest="embed_options")
81 | parser.add_arguments(LoadOptions, dest="load_options")
82 | parser.add_arguments(ProcessOptions, dest="process_options")
83 |
84 | args = parser.parse_args()
85 | args.process_options.ndarray = True
86 |
87 | print("Arguments:", args)
88 | dataset = load(*args.inputs, options=args.load_options)
89 | N = len(dataset)
90 | print(f"Loaded dataset with {N} samples")
91 |
92 | process(
93 | dataset,
94 | partial(predict_fn, options=args.embed_options),
95 | args.output,
96 | args.process_options
97 | )
98 |
--------------------------------------------------------------------------------
/annotate_data/embed.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from functools import partial
3 | from pathlib import Path
4 | from tqdm import tqdm
5 |
6 | from transformers import AutoModel, AutoTokenizer
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import numpy
10 |
11 | from datatools.process import process, ProcessOptions
12 | from datatools.load import load, LoadOptions
13 | from simple_parsing import ArgumentParser, field
14 | from typing import Dict, Any
15 |
16 | @dataclass
17 | class EmbedOptions:
18 | model_name: str = "Alibaba-NLP/gte-base-en-v1.5"
19 | batch_size: int = 128
20 | num_dataloader_workers: int = 8
21 | pooling_strategy: str = "cls"
22 | normalize_embeddings: bool = True
23 | max_length: int = 8192
24 | input_template: str = "{text}"
25 |
26 |
27 | class DataCollator:
28 | def __init__(self, tokenizer, options):
29 | self.tokenizer = tokenizer
30 | self.options = options
31 |
32 | @torch.no_grad()
33 | def __call__(self, features):
34 | documents = [self.options.input_template.format(**f) for f in features]
35 | return self.tokenizer(documents, padding=True, truncation=True, return_tensors='pt', max_length=self.options.max_length)
36 |
37 |
38 | @torch.inference_mode()
39 | def pooling(model_output, attention_mask, pooling_strategy):
40 | if pooling_strategy == "cls":
41 | return model_output.last_hidden_state[:, 0].float()
42 | elif pooling_strategy == "mean":
43 | token_embeddings = model_output[0]
44 | input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
45 | return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
46 |
47 |
48 | def load_model_and_tokenizer(options):
49 | if options.model_name.startswith("nomic-ai/"):
50 | try:
51 | from contrastors.models.encoder.modeling_nomic_bert import NomicBertModel
52 |         except ImportError:
53 |             raise ImportError("Could not import NomicBertModel. Please install https://github.com/nomic-ai/contrastors in this folder")
54 |
55 | tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
56 | model = NomicBertModel.from_pretrained('nomic-ai/nomic-embed-text-v1', add_pooling_layer=False)
57 | else:
58 | tokenizer = AutoTokenizer.from_pretrained(options.model_name)
59 | model = AutoModel.from_pretrained(options.model_name, trust_remote_code=True)
60 | return model, tokenizer
61 |
62 |
63 |
64 | @torch.inference_mode()
65 | def predict_fn(subset, indices, process_id, options):
66 |
67 | model, tokenizer = load_model_and_tokenizer(options)
68 | model.to(torch.bfloat16)
69 | model.cuda()
70 | model.eval()
71 |
72 | data_loader = DataLoader(subset,
73 | batch_size=options.batch_size,
74 | collate_fn=DataCollator(tokenizer, options),
75 | num_workers=options.num_dataloader_workers,
76 | prefetch_factor=4,
77 | pin_memory=True,
78 | shuffle=False)
79 |
80 | for batch in tqdm(data_loader, disable=(process_id != 0)):
81 | for key in batch:
82 | batch[key] = batch[key].cuda()
83 |
84 | model_output = model(**batch)
85 | embeddings = pooling(model_output, batch['attention_mask'], options.pooling_strategy)
86 |
87 | if options.normalize_embeddings:
88 | embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
89 |
90 | for embed in embeddings:
91 | yield {"": embed.cpu().numpy()}
92 |
93 |
94 | if __name__ == "__main__":
95 | parser = ArgumentParser()
96 |
97 | parser.add_argument("inputs", type=Path, nargs="+", help="Input dataset paths")
98 | parser.add_argument("output", type=Path, help="Output dataset path")
99 |
100 | parser.add_arguments(EmbedOptions, dest="embed_options")
101 | parser.add_arguments(LoadOptions, dest="load_options")
102 | parser.add_arguments(ProcessOptions, dest="process_options")
103 |
104 | args = parser.parse_args()
105 | args.process_options.ndarray = True
106 |
107 | print("Arguments:", args)
108 | dataset = load(*args.inputs, options=args.load_options)
109 | N = len(dataset)
110 | print(f"Loaded dataset with {N} samples")
111 |
112 | process(
113 | dataset,
114 | partial(predict_fn, options=args.embed_options),
115 | args.output,
116 | args.process_options
117 | )
118 |
--------------------------------------------------------------------------------
/annotate_data/fasttext.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, List, Callable
3 | from functools import partial
4 |
5 | import fasttext
6 | from tqdm import tqdm
7 |
8 | from pathlib import Path
9 |
10 | from simple_parsing import ArgumentParser, field
11 | from datatools.process import process, ProcessOptions
12 | from datatools.load import load, LoadOptions
13 |
14 |
15 | def classify_fasttext_hq_prob(model: fasttext.FastText._FastText, content: str):
16 | # Clean the input text by joining all lines into a single string
17 | text = " ".join(content.strip().splitlines())
18 |
19 | # Make the prediction
20 | pred = model.predict(text)
21 |
22 | # Extract the predicted label and its probability
23 | (pred_label, pred_prob) = pred
24 | pred_label = pred_label[0]
25 | hq_prob = pred_prob[0]
26 |
27 |     # If the predicted label is '__label__cc', convert the score into the probability of the high-quality label
28 | if pred_label == "__label__cc":
29 | hq_prob = 1 - hq_prob
30 |
31 | # Return the output
32 | return hq_prob
33 |
34 |
35 | def predict_fn(dataset, indices, process_id, model_path, text_field="text"):
36 | model = fasttext.load_model(model_path)
37 |
38 | for i in tqdm(range(len(dataset)), disable=(process_id != 0)):
39 | text = dataset[i][text_field]
40 | hq_prob = classify_fasttext_hq_prob(model, text)
41 | yield {"": hq_prob}
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = ArgumentParser(add_config_path_arg=True)
46 |
47 | parser.add_argument("inputs", type=Path, nargs="+", help="Input dataset paths")
48 | parser.add_argument("output", type=Path, help="Output dataset path")
49 |
50 |
51 | parser.add_argument("--text_field", type=str, default="text", help="Name of the field containing the text to classify")
52 | parser.add_argument("--model_path", type=str, default="fasttext_oh_eli5.bin", help="Path to the FastText model")
53 |
54 | parser.add_arguments(LoadOptions, dest="load_options")
55 | parser.add_arguments(ProcessOptions, dest="process_options")
56 |
57 | args = parser.parse_args()
58 | args.process_options.ndarray = True
59 |
60 | print("Arguments:", args)
61 | dataset = load(*args.inputs, options=args.load_options)
62 | N = len(dataset)
63 | print(f"Loaded dataset with {N} samples")
64 |
65 |
66 | process(
67 | dataset,
68 | partial(
69 | predict_fn,
70 | text_field=args.text_field,
71 | model_path=args.model_path
72 | ),
73 | args.output, args.process_options
74 | )
75 |
--------------------------------------------------------------------------------
/annotate_data/perplexity.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from functools import partial
3 | from pathlib import Path
4 | from tqdm import tqdm
5 |
6 | from transformers import AutoModelForCausalLM, AutoTokenizer
7 | import torch
8 | from torch.utils.data import DataLoader
9 | import numpy
10 |
11 | from datatools.process import process, ProcessOptions
12 | from datatools.load import load, LoadOptions
13 | from simple_parsing import ArgumentParser, field
14 | from typing import Dict, Any
15 |
16 |
17 | @dataclass
18 | class PerplexityOptions:
19 | model_name: str = "EleutherAI/pythia-160m"
20 | batch_size: int = 32
21 | num_dataloader_workers: int = 8
22 | max_length: int = 2048
23 |
24 |
25 | class DataCollator:
26 | def __init__(self, max_length):
27 | self.max_length = max_length
28 |
29 | @torch.no_grad()
30 | def __call__(self, features):
31 | bsz = len(features)
32 | seqs = [features[i]["input_ids"] for i in range(bsz)]
33 | max_length = min(max(len(seq) for seq in seqs), self.max_length)
34 |
35 | input_ids = torch.zeros(bsz, max_length, dtype=torch.long)
36 | attention_mask = torch.zeros(bsz, max_length, dtype=torch.long)
37 |
38 | for i, seq in enumerate(seqs):
39 | seq = seq[:max_length]
40 | input_ids[i, :len(seq)] = torch.tensor(seq)
41 | attention_mask[i, :len(seq)] = 1
42 |
43 | return {
44 | "input_ids": input_ids,
45 | "attention_mask": attention_mask
46 | }
47 |
48 |
49 | @torch.inference_mode()
50 | def predict_fn(subset, indices, process_id, options):
51 | model = AutoModelForCausalLM.from_pretrained(options.model_name, attn_implementation="flash_attention_2")
52 | model.to(torch.bfloat16)
53 | model.cuda()
54 | model.eval()
55 |
56 | data_loader = DataLoader(subset,
57 | batch_size=options.batch_size,
58 | collate_fn=DataCollator(options.max_length),
59 | num_workers=options.num_dataloader_workers,
60 | prefetch_factor=4,
61 | pin_memory=True,
62 | shuffle=False)
63 |
64 | for batch in tqdm(data_loader, disable=(process_id != 0)):
65 | input_ids = batch["input_ids"].cuda()
66 | attention_mask = batch["attention_mask"].cuda()
67 |
68 | logits = model(input_ids=input_ids, attention_mask=attention_mask).logits[:, :-1].float()
69 | labels = torch.where(attention_mask == 1, input_ids, torch.zeros_like(input_ids) - 100)[:, 1:]
70 | seq_lens = attention_mask.sum(1)
71 |
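        # Cross-entropy per token (padding is masked via ignore_index -100), summed over the sequence and divided by its length to get the per-token loss (log-perplexity)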
72 | seq_losses = torch.nn.functional.cross_entropy(logits.transpose(1, 2), labels, reduction='none').sum(1) / seq_lens
73 |
74 |         for seq_loss in seq_losses:
75 | yield {"": seq_loss.cpu().numpy()}
76 |
77 |
78 | if __name__ == "__main__":
79 | parser = ArgumentParser()
80 |
81 | parser.add_argument("inputs", type=Path, nargs="+", help="Input dataset paths")
82 | parser.add_argument("output", type=Path, help="Output dataset path")
83 |
84 | parser.add_arguments(PerplexityOptions, dest="embed_options")
85 | parser.add_arguments(LoadOptions, dest="load_options")
86 | parser.add_arguments(ProcessOptions, dest="process_options")
87 |
88 | args = parser.parse_args()
89 | args.process_options.ndarray = True
90 |
91 | print("Arguments:", args)
92 | dataset = load(*args.inputs, options=args.load_options)
93 | N = len(dataset)
94 | print(f"Loaded dataset with {N} samples")
95 |
96 | process(
97 | dataset,
98 | partial(predict_fn, options=args.embed_options),
99 | args.output,
100 | args.process_options
101 | )
102 |
--------------------------------------------------------------------------------
/annotate_data/tokens.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, List, Callable
3 | from functools import partial
4 |
5 | from tqdm import tqdm
6 |
7 | from pathlib import Path
8 |
9 | from simple_parsing import ArgumentParser, field
10 | from datatools.process import process, ProcessOptions
11 | from datatools.load import load, LoadOptions
12 | from transformers import AutoTokenizer
13 | import numpy as np
14 |
15 | def predict_fn(dataset, indices, process_id, tokenizer_name, text_field="text"):
16 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
17 |
18 | for i in tqdm(range(len(dataset)), disable=(process_id != 0)):
19 | text = dataset[i][text_field]
20 | tokens = tokenizer.encode(text)
21 |         num_tokens = len(tokens) + 1  # + 1 for the EOS token appended during tokenization
22 | yield {
23 | "": num_tokens,
24 | # "bin": np.clip(np.log2(num_tokens).astype(int), 6, 11)
25 | }
26 |
27 |
28 | if __name__ == "__main__":
29 | parser = ArgumentParser(add_config_path_arg=True)
30 |
31 | parser.add_argument("inputs", type=Path, nargs="+", help="Input dataset paths")
32 | parser.add_argument("output", type=Path, help="Output dataset path")
33 |
34 | parser.add_argument("--text_field", type=str, default="text", help="Name of the field containing the text to classify")
35 |     parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b", help="Name or path of the HuggingFace tokenizer")
36 |
37 | parser.add_arguments(LoadOptions, dest="load_options")
38 | parser.add_arguments(ProcessOptions, dest="process_options")
39 |
40 | args = parser.parse_args()
41 | args.process_options.ndarray = True
42 |
43 | print("Arguments:", args)
44 | dataset = load(*args.inputs, options=args.load_options)
45 | N = len(dataset)
46 | print(f"Loaded dataset with {N} samples")
47 |
48 |
49 | process(
50 | dataset,
51 | partial(
52 | predict_fn,
53 | text_field=args.text_field,
54 | tokenizer_name=args.tokenizer
55 | ),
56 | args.output, args.process_options
57 | )
58 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
81 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to this repository
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `main`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints.
13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14 |
15 | ## Contributor License Agreement ("CLA")
16 | In order to accept your pull request, we need you to submit a CLA. You only need
17 | to do this once to work on any of Meta's open source projects.
18 |
19 | Complete your CLA here:
20 |
21 | ## Issues
22 | We use GitHub issues to track public bugs. Please ensure your description is
23 | clear and has sufficient instructions to be able to reproduce the issue.
24 |
25 | Meta has a [bounty program](https://bugbounty.meta.com/) for the safe
26 | disclosure of security bugs. In those cases, please go through the process
27 | outlined on that page and do not file a public issue.
28 |
29 | ## License
30 | By contributing to this repository, you agree that your contributions will be licensed
31 | under the LICENSE file in the root directory of this source tree.
32 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/README.md:
--------------------------------------------------------------------------------
1 | # Automatic Data Curation for Self-Supervised Learning: A Clustering-Based Approach
2 |
3 | __*This repository has been adapted from the [ssl-data-curation](https://github.com/facebookresearch/ssl-data-curation) repository.
4 | We have added functionality to read embeddings from many numpy files and MDS datasets. We have also changed the multi-node slurm implementation to make use of torchrun on each node. The experimental scripts for the paper can be found in `exps/*/level1/slurm_script.s`.*__
5 |
6 | **[FAIR at Meta](https://ai.facebook.com/research/)**
7 |
8 | *Huy V. Vo,
9 | Vasil Khalidov,
10 | Timothée Darcet,
11 | Théo Moutakanni,
12 | Nikita Smetanin,
13 | Marc Szafraniec,
14 | Hugo Touvron,
15 | Camille Couprie,
16 | Maxime Oquab,
17 | Armand Joulin,
18 | Hervé Jégou,
19 | Patrick Labatut,
20 | Piotr Bojanowski*
21 |
22 | PyTorch implementation for the data curation pipeline with hierarchical k-means. For more detail, see the paper **[Automatic Data Curation for Self-Supervised Learning: A Clustering-Based Approach](https://arxiv.org/abs/2405.15613)**.
23 |
24 |
25 |
26 |
27 |
28 | ## Contents
29 | - [Installation](#installation)
30 | - [Running hierarchical k-means](#running-hierarchical-k-means)
31 | * [On small data](#on-small-data)
32 | * [On large data](#on-large-data)
33 | - [Notebook](#notebook)
34 | - [Contributing](#contributing)
35 | - [License](#license)
36 | - [Citation](#citation)
37 |
38 | ## Installation
39 | ```
40 | git clone git@github.com:facebookresearch/ssl-data-curation.git
41 | cd ssl-data-curation
42 | conda create -n ssl-data-curation python=3.10
43 | conda activate ssl-data-curation
44 | pip install -r requirements.txt
45 | ```
46 |
47 | ## Running hierarchical k-means
48 | ### On small data
49 | We provide below an example of 2-level hierarchical k-means on a small random toy dataset. We first run hierarchical k-means on the toy dataset, then sample 1000 points from it with hierarchical sampling. A visualisation is provided in [vis/notebook.ipynb](vis/notebook.ipynb).
50 | ```
51 | import torch
52 | import numpy as np
53 |
54 | from src.clusters import HierarchicalCluster
55 | from src import (
56 | hierarchical_kmeans_gpu as hkmg,
57 | hierarchical_sampling as hs
58 | )
59 |
60 | def make_ring(n, rmin, rmax):
61 | r = np.random.rand(n) * (rmax - rmin) + rmin
62 | alpha = np.random.rand(n) * 2 * np.pi
63 | return np.vstack([r * np.cos(alpha), r * np.sin(alpha)]).T
64 |
65 | data = np.concatenate([
66 | make_ring(20000, 0.7, 1.0) + np.array([-2.2, 1.]),
67 | make_ring(200, 0.7, 1.0) + np.array([0., 1.]),
68 | make_ring(1000, 0.7, 1.0) + np.array([2.2, 1.]),
69 | make_ring(500, 0.7, 1.0) + np.array([-1.2, 0.2]),
70 | make_ring(8000, 0.7, 1.0) + np.array([1.2, 0.2]),
71 | ])
72 |
73 | clusters = hkmg.hierarchical_kmeans_with_resampling(
74 | data=torch.tensor(data, device="cuda", dtype=torch.float32),
75 | n_clusters=[1000, 300],
76 | n_levels=2,
77 | sample_sizes=[15, 2],
78 | verbose=False,
79 | )
80 |
81 | cl = HierarchicalCluster.from_dict(clusters)
82 | sampled_indices = hs.hierarchical_sampling(cl, target_size=1000)
83 | ```
84 |
85 |
86 |
87 |
88 |
89 | ### On large data
90 | To launch hierarchical k-means on large data, we need to prepare a config file. We provide below an example illustrating how to launch 2-level hierarchical k-means on random embeddings using the config in [configs/2levels_random_embeddings.yaml](configs/2levels_random_embeddings.yaml).
91 | ```
92 | # Prepare the experiment
93 | cd ssl-data-curation
94 | mkdir -p data
95 | cd scripts
96 | python -c 'import numpy as np; np.save( "../data/100k_random.npy", np.random.randn(100000,256))'
97 | python hierarchical_kmeans_launcher.py \
98 | --exp_dir ../data/2levels_random_embeddings \
99 | --embeddings_path ../data/100k_random.npy \
100 | --config_file ../configs/2levels_random_embeddings.yaml
101 |
102 | cd ../data/2levels_random_embeddings
103 | # Launch with slurm
104 | bash launcher.sh
105 | # Launch locally if only 1 node is used
106 | # bash local_launcher.sh
107 |
108 | cd ssl-data-curation/scripts
109 | # Sampled indices will be saved in ssl-data-curation/data/2levels_random_embeddings/curated_datasets
110 | PYTHONPATH=.. python run_hierarchical_sampling.py \
111 | --clustering_path ../data/2levels_random_embeddings \
112 | --target_size 20000 \
113 | --save
114 | ```
115 |
116 | We also provide the config used for our web-based image data pool in [configs/4levels_web_based_images.yaml](configs/4levels_web_based_images.yaml).
117 |
118 | ## Notebook
119 | We provide a [notebook](vis/notebook.ipynb) to reproduce visualizations in the paper and show additional examples.
120 |
121 | ## Contributing
122 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
123 |
124 | ## License
125 | This code is CC-BY-NC 4.0 licensed, as found in [LICENSE](LICENSE).
126 |
127 | ## Citation
128 | If you find our work useful, please consider giving a star and a citation:
129 | ```
130 | @article{vo2024automatic,
131 | title={Automatic Data Curation for Self-Supervised Learning: A Clustering-Based Approach},
132 |   author={Vo, Huy V. and Khalidov, Vasil and Darcet, Timoth{\'e}e and Moutakanni, Th{\'e}o and Smetanin, Nikita and Szafraniec, Marc and Touvron, Hugo and Couprie, Camille and Oquab, Maxime and Joulin, Armand and J{\'e}gou, Herv{\'e} and Labatut, Patrick and Bojanowski, Piotr},
133 | journal={arXiv:2405.15613},
134 | year={2024},
135 | }
136 | ```
--------------------------------------------------------------------------------
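A note on consuming the output of the large-data example above: `run_hierarchical_sampling.py` writes the selected indices to `curated_datasets/` inside the clustering directory, with a file name built as `f"{n_levels}{sampling_strategy}_mul{multiplier}_{target_size}_balanced_selection.npy"` (see `scripts/run_hierarchical_sampling.py`). The sketch below shows one way to load that file for the README example; it assumes the run completed with the default strategy `r` and multiplier 1, and that it is executed from the repository root.

```python
import numpy as np

# Default strategy "r", multiplier 1, 2 levels, target size 20000
# -> 2r_mul1_20000_balanced_selection.npy (per the naming scheme in
# scripts/run_hierarchical_sampling.py).
indices = np.load(
    "data/2levels_random_embeddings/curated_datasets/2r_mul1_20000_balanced_selection.npy"
)

# Select the curated subset from the original random-embedding pool.
embeddings = np.load("data/100k_random.npy", mmap_mode="r")
curated = embeddings[np.sort(indices)]
print(curated.shape)  # (20000, 256)
```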
/define_domains/k-means-clustering/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/configs/1level_dcml.yaml:
--------------------------------------------------------------------------------
1 | # Number of levels in hierarchical k-means.
2 | n_levels: 1
3 | # Number of updates of centroids in the main k-means loop.
4 | n_iters: 50
5 | # Number of clusters in each level of hierarchical k-means.
6 | # This config runs a single k-means level over the DCLM embeddings;
7 | # no split step is used here (n_splits = 1), unlike the web-image
8 | # setup in 4levels_web_based_images.yaml.
9 | n_clusters:
10 | - 24
11 | # If > 1, run the level in two steps. First, k-means is executed once.
12 | # Then, each obtained cluster is split into "n_split" smaller clusters,
13 | # which are considered final and used in the subsequent level.
14 | n_splits:
15 | - 1
16 | # Number of resampling steps in each level.
17 | # For efficiency, we do not use resampling in the first level.
18 | n_resampling_steps:
19 | - 1
20 | # Number of data points sampled from each cluster in the resampling steps.
21 | # It is roughly half the average cluster size in each level.
22 | sample_size:
23 | - 1
24 | # Specified if running only on a subset of the data pool.
25 | # For example, we extract embeddings for all images in the data pool,
26 | # but run the curation pipeline only on a deduplicated subset.
27 | subset_indices_path: null
28 | checkpoint_period: 10_000
29 | dtype: float32
30 | high_precision: float32
31 | ngpus_per_node:
32 | - 8
33 | nnodes:
34 | - 4
35 | ncpus_per_gpu: 6
36 | sampling_strategy: r
37 | slurm_partition: pli-c
38 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/configs/2levels_random_embeddings.yaml:
--------------------------------------------------------------------------------
1 | # Number of levels in hierarchical k-means.
2 | n_levels: 2
3 | # Number of updates of centroids in the main k-means loop.
4 | n_iters: 50
5 | # Number of clusters in each level of hierarchical k-means.
6 | n_clusters:
7 | - 5000
8 | - 1000
9 | # If > 1, run the level in two steps. First, k-means is executed once.
11 | # Then, each obtained cluster is split into "n_split" smaller clusters,
11 | # which are considered final and used in the subsequent level.
12 | n_splits:
13 | - 1
14 | - 1
15 | # Number of resampling steps in each level.
16 | n_resampling_steps:
17 | - 10
18 | - 10
19 | # Number of data points sampled from each cluster in the resampling steps.
20 | # It is roughly half the average cluster size in each level.
21 | sample_size:
22 | - 10
23 | - 3
24 | # Specified if running only on a subset of the data pool.
25 | # For example, we extract embeddings for all images in the data pool,
26 | # but run the curation pipeline only on a deduplicated subset.
27 | subset_indices_path: null
28 | checkpoint_period: 1000
29 | dtype: float64
30 | high_precision: float64
31 | ngpus_per_node:
32 | - 2
33 | - 2
34 | nnodes:
35 | - 1
36 | - 1
37 | ncpus_per_gpu: 10
38 | sampling_strategy: c
39 | slurm_partition: null
40 |
--------------------------------------------------------------------------------
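The "roughly half the average cluster size" guidance in the `sample_size` comment can be checked against the 100k random-embeddings example that uses this config; the snippet below is only a back-of-the-envelope verification, not part of the pipeline.

```python
# Rough check of sample_size vs. average cluster size for this config
# (100k embeddings, 5000 level-1 clusters regrouped into 1000 level-2 clusters).
n_points = 100_000
n_clusters = [5000, 1000]
sample_size = [10, 3]

level1_avg = n_points / n_clusters[0]        # ~20 points per level-1 cluster
level2_avg = n_clusters[0] / n_clusters[1]   # ~5 level-1 centroids per level-2 cluster

print(level1_avg / 2, sample_size[0])  # 10.0 vs 10
print(level2_avg / 2, sample_size[1])  # 2.5 vs 3
```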
/define_domains/k-means-clustering/configs/4levels_web_based_images.yaml:
--------------------------------------------------------------------------------
1 | # Number of levels in hierarchical k-means.
2 | n_levels: 4
3 | # Number of updates of centroids in the main k-means loop.
4 | n_iters: 50
5 | # Number of clusters in each level of hierarchical k-means.
6 | # For efficiency in the first level, we run first a k-means
7 | # with 100k clusters, then split each cluster into 100
8 | # smaller ones to have 10M clusters.
9 | n_clusters:
10 | - 100_000
11 | - 500_000
12 | - 50_000
13 | - 10_000
14 | # If > 1, run the level in two steps. First, k-means is executed once.
16 | # Then, each obtained cluster is split into "n_split" smaller clusters,
16 | # which are considered final and used in the subsequent level.
17 | n_splits:
18 | - 100
19 | - 1
20 | - 1
21 | - 1
22 | # Number of resampling steps in each level.
23 | # For efficiency, we do not use resampling in the first level.
24 | n_resampling_steps:
25 | - 1
26 | - 10
27 | - 10
28 | - 10
29 | # Number of data points sampled from each cluster in the resampling steps.
30 | # It is roughly half the average cluster size in each level.
31 | sample_size:
32 | - 1
33 | - 10
34 | - 5
35 | - 3
36 | # Specified if running only on a subset of the data pool.
37 | # For example, we extract embeddings for all images in the data pool,
38 | # but run the curation pipeline only on a deduplicated subset.
39 | subset_indices_path: null
40 | checkpoint_period: 10_000
41 | dtype: float64
42 | high_precision: float64
43 | ngpus_per_node:
44 | - 8
45 | - 8
46 | - 8
47 | - 8
48 | nnodes:
49 | - 16
50 | - 2
51 | - 1
52 | - 1
53 | ncpus_per_gpu: 10
54 | sampling_strategy: c
55 | slurm_partition: null
56 |
--------------------------------------------------------------------------------
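As a quick check of the two-step split described in the `n_splits` comment (k-means first, then each cluster split into `n_splits` smaller ones that the next level consumes), the first level of this config indeed produces the 10M clusters mentioned above; the snippet only reproduces that arithmetic.

```python
# Effective number of final clusters per level when n_splits > 1.
n_clusters = [100_000, 500_000, 50_000, 10_000]
n_splits = [100, 1, 1, 1]

final_clusters = [k * s for k, s in zip(n_clusters, n_splits)]
print(final_clusters[0])  # 10000000 first-level clusters, as stated in the comments
```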
/define_domains/k-means-clustering/exps/dclm-1level-k118/level1/centroids.npy:
--------------------------------------------------------------------------------
1 | /scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k118/level1/step0/centroids.npy
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/exps/dclm-1level-k118/level1/slurm_script.s:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --nodes=4
4 | #SBATCH --gres=gpu:8
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --job-name=kmeans_level1
7 | #SBATCH --time=1-0
8 | #SBATCH --mem=800G
9 | #SBATCH --cpus-per-task=32
10 | #SBATCH --partition=pli-c
11 |
12 | EXPDIR=/scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k118/level1
13 | cd /scratch/gpfs/awettig/delve/k-means-clustering/scripts
14 |
15 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
16 |
17 | PYTHONPATH=.. \
18 | srun -N 4 --unbuffered --output="$EXPDIR"/logs/%j_%t_log.out --error="$EXPDIR"/logs/%j_%t_log.err torchrun \
19 | --nnodes=4 \
20 | --nproc_per_node=8 \
21 | --rdzv_backend=c10d \
22 | --rdzv_endpoint=$master_addr:56321 \
23 | run_distributed_kmeans.py \
24 | --use_torchrun \
25 | --data_path /scratch/gpfs/PLI/awettig/dclm/dclm-pool-1b-1x/deduplicated/embeds \
26 | --n_clusters 118 \
27 | --n_iters 50 \
28 | --chunk_size 1694915 \
29 | --dtype float32 \
30 | --high_precision float32 \
31 | --checkpoint_period 10000 \
32 | --exp_dir $EXPDIR \
33 | --n_steps 1 \
34 | --sample_size 1 \
35 | --do_not_sort_clusters \
36 | --held_out_shards 100 \
37 | --sampling_strategy r
38 |
--------------------------------------------------------------------------------
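Across the `exps/dclm-1level-k*` scripts, `--chunk_size` appears to be chosen so that `chunk_size * n_clusters` stays close to 2e8 (118 × 1694915, 24 × 8333333, 576 × 347222, and so on), presumably to keep the per-chunk distance computation within a fixed memory budget. The snippet below only illustrates this pattern observed in the scripts; it is not a rule enforced by the code.

```python
# chunk_size values from the dclm-1level-k* slurm scripts vs. a
# fixed-budget rule of thumb chunk_size ~= 2e8 / n_clusters.
BUDGET = 2e8  # assumed budget (distance-matrix entries per chunk)

used_chunk_sizes = {
    24: 8_333_333,
    118: 1_694_915,
    576: 347_222,
    2_822: 70_871,
    13_824: 14_467,
    67_723: 2_953,
}

for k, chunk in used_chunk_sizes.items():
    print(k, chunk, int(BUDGET // k))  # last two columns agree
```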
/define_domains/k-means-clustering/exps/dclm-1level-k13824/level1/centroids.npy:
--------------------------------------------------------------------------------
1 | /scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k13824/level1/step0/centroids.npy
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/exps/dclm-1level-k13824/level1/slurm_script.s:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --nodes=4
4 | #SBATCH --gres=gpu:8
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --job-name=kmeans_level1
7 | #SBATCH --time=1-0
8 | #SBATCH --mem=800G
9 | #SBATCH --cpus-per-task=32
10 | #SBATCH --partition=pli-c
11 |
12 | EXPDIR=/scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k13824/level1
13 | cd /scratch/gpfs/awettig/delve/k-means-clustering/scripts
14 |
15 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
16 |
17 | PYTHONPATH=.. \
18 | srun -N 4 --unbuffered --output="$EXPDIR"/logs/%j_%t_log.out --error="$EXPDIR"/logs/%j_%t_log.err torchrun \
19 | --nnodes=4 \
20 | --nproc_per_node=8 \
21 | --rdzv_backend=c10d \
22 | --rdzv_endpoint=$master_addr:56321 \
23 | run_distributed_kmeans.py \
24 | --use_torchrun \
25 | --data_path /scratch/gpfs/PLI/awettig/dclm/dclm-pool-1b-1x/deduplicated/embeds \
26 | --n_clusters 13824 \
27 | --n_iters 50 \
28 | --chunk_size 14467 \
29 | --dtype float32 \
30 | --high_precision float32 \
31 | --checkpoint_period 10000 \
32 | --exp_dir $EXPDIR \
33 | --n_steps 1 \
34 | --sample_size 1 \
35 | --do_not_sort_clusters \
36 | --held_out_shards 100 \
37 | --sampling_strategy r
38 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/exps/dclm-1level-k24/level1/centroids.npy:
--------------------------------------------------------------------------------
1 | /scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k24/level1/step0/centroids.npy
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/exps/dclm-1level-k24/level1/slurm_script.s:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --nodes=3
4 | #SBATCH --gres=gpu:8
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --job-name=kmeans_level1
7 | #SBATCH --time=1-0
8 | #SBATCH --mem=800G
9 | #SBATCH --cpus-per-task=32
10 | #SBATCH --partition=pli-c
11 |
12 | EXPDIR=/scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k24/level1
13 | cd /scratch/gpfs/awettig/delve/k-means-clustering/scripts
14 |
15 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
16 |
17 | PYTHONPATH=.. \
18 | srun -N 3 --unbuffered --output="$EXPDIR"/logs/%j_%t_log.out --error="$EXPDIR"/logs/%j_%t_log.err torchrun \
19 | --nnodes=3 \
20 | --nproc_per_node=8 \
21 | --rdzv_backend=c10d \
22 | --rdzv_endpoint=$master_addr:56321 \
23 | run_distributed_kmeans.py \
24 | --use_torchrun \
25 | --data_path /scratch/gpfs/PLI/awettig/dclm/dclm-pool-1b-1x/deduplicated/embeds \
26 | --n_clusters 24 \
27 | --n_iters 50 \
28 | --chunk_size 8333333 \
29 | --dtype float32 \
30 | --high_precision float32 \
31 | --checkpoint_period 10000 \
32 | --exp_dir $EXPDIR \
33 | --n_steps 1 \
34 | --sample_size 1 \
35 | --do_not_sort_clusters \
36 | --held_out_shards 100 \
37 | --sampling_strategy r
38 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/exps/dclm-1level-k2822/level1/centroids.npy:
--------------------------------------------------------------------------------
1 | /scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k2822/level1/step0/centroids.npy
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/exps/dclm-1level-k2822/level1/slurm_script.s:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --nodes=4
4 | #SBATCH --gres=gpu:8
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --job-name=kmeans_level1
7 | #SBATCH --time=1-0
8 | #SBATCH --mem=800G
9 | #SBATCH --cpus-per-task=32
10 | #SBATCH --partition=pli-c
11 |
12 | EXPDIR=/scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k2822/level1
13 | cd /scratch/gpfs/awettig/delve/k-means-clustering/scripts
14 |
15 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
16 |
17 | PYTHONPATH=.. \
18 | srun -N 4 --unbuffered --output="$EXPDIR"/logs/%j_%t_log.out --error="$EXPDIR"/logs/%j_%t_log.err torchrun \
19 | --nnodes=4 \
20 | --nproc_per_node=8 \
21 | --rdzv_backend=c10d \
22 | --rdzv_endpoint=$master_addr:56321 \
23 | run_distributed_kmeans.py \
24 | --use_torchrun \
25 | --data_path /scratch/gpfs/PLI/awettig/dclm/dclm-pool-1b-1x/deduplicated/embeds \
26 | --n_clusters 2822 \
27 | --n_iters 50 \
28 | --chunk_size 70871 \
29 | --dtype float32 \
30 | --high_precision float32 \
31 | --checkpoint_period 10000 \
32 | --exp_dir $EXPDIR \
33 | --n_steps 1 \
34 | --sample_size 1 \
35 | --do_not_sort_clusters \
36 | --held_out_shards 100 \
37 | --sampling_strategy r
38 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/exps/dclm-1level-k576/level1/centroids.npy:
--------------------------------------------------------------------------------
1 | /scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k576/level1/step0/centroids.npy
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/exps/dclm-1level-k576/level1/slurm_script.s:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --nodes=4
4 | #SBATCH --gres=gpu:8
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --job-name=kmeans_level1
7 | #SBATCH --time=1-0
8 | #SBATCH --mem=800G
9 | #SBATCH --cpus-per-task=32
10 | #SBATCH --partition=pli-c
11 |
12 | EXPDIR=/scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k576/level1
13 | cd /scratch/gpfs/awettig/delve/k-means-clustering/scripts
14 |
15 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
16 |
17 | PYTHONPATH=.. \
18 | srun -N 4 --unbuffered --output="$EXPDIR"/logs/%j_%t_log.out --error="$EXPDIR"/logs/%j_%t_log.err torchrun \
19 | --nnodes=4 \
20 | --nproc_per_node=8 \
21 | --rdzv_backend=c10d \
22 | --rdzv_endpoint=$master_addr:56321 \
23 | run_distributed_kmeans.py \
24 | --use_torchrun \
25 | --data_path /scratch/gpfs/PLI/awettig/dclm/dclm-pool-1b-1x/deduplicated/embeds \
26 | --n_clusters 576 \
27 | --n_iters 50 \
28 | --chunk_size 347222 \
29 | --dtype float32 \
30 | --high_precision float32 \
31 | --checkpoint_period 10000 \
32 | --exp_dir $EXPDIR \
33 | --n_steps 1 \
34 | --sample_size 1 \
35 | --do_not_sort_clusters \
36 | --held_out_shards 100 \
37 | --sampling_strategy r
38 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/exps/dclm-1level-k67723/level1/centroids.npy:
--------------------------------------------------------------------------------
1 | /scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k67723/level1/step0/centroids.npy
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/exps/dclm-1level-k67723/level1/slurm_script.s:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --nodes=4
4 | #SBATCH --gres=gpu:8
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --job-name=kmeans_level1
7 | #SBATCH --time=1-0
8 | #SBATCH --mem=800G
9 | #SBATCH --cpus-per-task=32
10 | #SBATCH --partition=pli-c
11 |
12 | EXPDIR=/scratch/gpfs/awettig/delve/k-means-clustering/exps/dclm-1level-k67723/level1
13 | cd /scratch/gpfs/awettig/delve/k-means-clustering/scripts
14 |
15 | master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
16 |
17 | PYTHONPATH=.. \
18 | srun -N 4 --unbuffered --output="$EXPDIR"/logs/%j_%t_log.out --error="$EXPDIR"/logs/%j_%t_log.err torchrun \
19 | --nnodes=4 \
20 | --nproc_per_node=8 \
21 | --rdzv_backend=c10d \
22 | --rdzv_endpoint=$master_addr:56321 \
23 | run_distributed_kmeans.py \
24 | --use_torchrun \
25 | --data_path /scratch/gpfs/PLI/awettig/dclm/dclm-pool-1b-1x/deduplicated/embeds \
26 | --n_clusters 67723 \
27 | --n_iters 50 \
28 | --chunk_size 2953 \
29 | --dtype float32 \
30 | --high_precision float32 \
31 | --checkpoint_period 10000 \
32 | --exp_dir $EXPDIR \
33 | --n_steps 1 \
34 | --sample_size 1 \
35 | --do_not_sort_clusters \
36 | --held_out_shards 100 \
37 | --sampling_strategy r
38 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/images/curation_pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeCreator/WebOrganizer/b3da665635be3ee6c51484509f0fa5699f24d28c/define_domains/k-means-clustering/images/curation_pipeline.png
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/images/toy_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CodeCreator/WebOrganizer/b3da665635be3ee6c51484509f0fa5699f24d28c/define_domains/k-means-clustering/images/toy_example.png
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu121
2 | torch==2.2
3 | matplotlib==3.8.2
4 | scipy==1.11.4
5 | numpy==1.24.4
6 | omegaconf
7 | scikit-learn>=1.5.0
8 | tqdm
9 | ipykernel
10 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/scripts/run_hierarchical_sampling.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | from argparse import ArgumentParser
9 | from pathlib import Path
10 |
11 | import numpy as np
12 |
13 | from src.clusters import HierarchicalCluster
14 | from src.utils import setup_logging
15 | from src.hierarchical_sampling import hierarchical_sampling
16 |
17 | logger = logging.getLogger("hkmeans")
18 |
19 | if __name__ == "__main__":
20 | parser = ArgumentParser()
21 | parser.add_argument("--save", action="store_true")
22 | parser.add_argument("--clustering_path", "-clus", type=str, required=True)
23 | parser.add_argument(
24 | "--target_size",
25 | type=int,
26 | required=True,
27 | help="Target size of the sampled set"
28 | )
29 | parser.add_argument(
30 | "--multiplier",
31 | "-m",
32 | type=int,
33 | default=1,
34 | help="Maximum number of times an image is selected"
35 | )
36 | parser.add_argument(
37 | "--sampling_strategy",
38 | "-ss",
39 | type=str,
40 | default="r",
41 | help='"r" for random, "c" for closest',
42 | )
43 | parser.add_argument(
44 | "--sort_indices",
45 | action="store_true",
46 | help="If true, sort indices in increasing order",
47 | )
48 | parser.add_argument(
49 | "--name_suffix",
50 | type=str,
51 | default="",
52 |         help="Suffix to add to the indices file name",
53 | )
54 | parser.add_argument(
55 | "--valid_indices_path",
56 | type=str,
57 | default=None,
58 | help=(
59 | "Path to .npy file containing valid indices of the base dataset. "
60 | "The clustering is computed only on these valid images."
61 | ),
62 | )
63 | parser.add_argument(
64 | "--cluster_fname",
65 | type=str,
66 | default="sorted_clusters.npy",
67 | help="name of files containing clusters",
68 | )
69 | parser.add_argument("--save_dir_name", type=str, default="curated_datasets")
70 |
71 | args = parser.parse_args()
72 | args.clustering_path = Path(args.clustering_path).resolve()
73 | setup_logging()
74 | logger.info(f"args: {args}")
75 |
76 | cl = HierarchicalCluster.from_file(
77 | cluster_path=args.clustering_path,
78 | cluster_fname=args.cluster_fname
79 | )
80 |
81 | sampled_indices = hierarchical_sampling(
82 | cl,
83 | args.target_size,
84 | args.multiplier,
85 | args.sampling_strategy,
86 | )
87 | if args.valid_indices_path is not None:
88 | valid_indices = np.load(args.valid_indices_path)
89 | assert len(valid_indices) == np.sum(
90 | [len(el) for el in cl.clusters[1]]
91 | ), "Number of images is not equal to valid_indices size"
92 | sampled_indices = valid_indices[sampled_indices]
93 |
94 | if args.sort_indices:
95 | sampled_indices = np.sort(sampled_indices)
96 |
97 | num_images = len(sampled_indices)
98 | logger.info(f"Number of selected data points: {num_images}")
99 |
100 | save_indices_path = Path(
101 | args.clustering_path,
102 | args.save_dir_name,
103 | f'{cl.n_levels}{args.sampling_strategy}_mul{args.multiplier}_'
104 | f'{args.target_size}_balanced_selection.npy'
105 | )
106 | if len(args.name_suffix) > 0:
107 | save_indices_path = Path(
108 | str(save_indices_path).replace(".npy", f"_{args.name_suffix}.npy")
109 | )
110 | logger.info(f"Indices will be saved to {str(save_indices_path.resolve())}")
111 | if args.save:
112 | Path(args.clustering_path, args.save_dir_name).mkdir(exist_ok=True)
113 | np.save(save_indices_path, sampled_indices)
114 | logger.info("Indices are saved!")
115 |
--------------------------------------------------------------------------------
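For interactive use, the core of the script above can be reproduced in a few lines of Python; this is just a condensed restatement of what the CLI does, assuming the repository root is on `PYTHONPATH` and that `clustering_exp/` is a hypothetical directory containing `level*/sorted_clusters.npy` files.

```python
import numpy as np

from src.clusters import HierarchicalCluster
from src.hierarchical_sampling import hierarchical_sampling

# Roughly equivalent to:
#   PYTHONPATH=.. python run_hierarchical_sampling.py \
#       --clustering_path clustering_exp --target_size 20000 --save
cl = HierarchicalCluster.from_file("clustering_exp")  # hypothetical path
sampled_indices = hierarchical_sampling(
    cl, target_size=20_000, multiplier=1, sampling_strategy="r"
)
np.save("sampled_indices.npy", np.sort(sampled_indices))
```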
/define_domains/k-means-clustering/scripts/split_clusters.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import sys
9 | from argparse import ArgumentParser
10 | import logging
11 | from pathlib import Path
12 | from tqdm import tqdm
13 |
14 | import numpy as np
15 | import torch
16 |
17 | from src.utils import setup_logging, MDSPseudoMemMap, MultiMemMap
18 |
19 | from src.dist_comm import (
20 | enable_distributed,
21 | get_global_rank,
22 | get_global_size,
23 | is_main_process,
24 | synchronize,
25 | )
26 | from src import distributed_kmeans_gpu as dkmg, kmeans_gpu as kmg
27 |
28 |
29 | logger = logging.getLogger("hkmeans")
30 |
31 |
32 | def split_clusters(
33 | data_path,
34 | subset_indices_path,
35 | clusters_path,
36 | n_splits,
37 | n_iters,
38 | dtype,
39 | high_precision,
40 | save_path,
41 | device="cuda",
42 | use_torchrun=False,
43 | checkpoint_period=10,
44 | verbose=False,
45 | ):
46 | enable_distributed(
47 | use_torchrun=use_torchrun,
48 | overwrite=True,
49 | )
50 |
51 | synchronize()
52 | logger.info("initial synchronized!")
53 |
54 |
55 | if os.path.isdir(data_path):
56 | X = MultiMemMap(data_path)
57 | else:
58 | X = np.load(data_path, mmap_mode="r")
59 |
60 | if subset_indices_path is not None:
61 | logger.info(f"Using subset with indices in {subset_indices_path}")
62 | subset_indices = np.load(subset_indices_path)
63 | X = dkmg.ExtendedNumpyMemMap(X, subset_indices)
64 | clusters = np.load(clusters_path, allow_pickle=True)
65 | n_clusters = len(clusters)
66 |
67 | part_indices = dkmg.get_part_indices(n_clusters, get_global_size())
68 | rank = get_global_rank()
69 |
70 | # load checkpoints if exist
71 | if Path(save_path, f"split_checkpoint_{rank}.npy").exists():
72 | ckpt = np.load(
73 | Path(save_path, f"split_checkpoint_{rank}.npy"), allow_pickle=True
74 | ).item()
75 | small_centroids = list(ckpt["small_centroids"])
76 | small_clusters = list(ckpt["small_clusters"])
77 | last_index = ckpt["last_index"]
78 | assert last_index - part_indices[rank] + 1 == len(small_centroids)
79 | else:
80 | small_centroids = []
81 | small_clusters = []
82 | last_index = part_indices[rank] - 1
83 |
84 | # run kmeans++ on clusters
85 | for cluster_idx in tqdm(
86 | range(last_index + 1, part_indices[rank + 1]),
87 | desc="Splitting pre-clusters",
88 | file=sys.stdout,
89 | bar_format="{l_bar}{bar}{r_bar}",
90 | ):
91 | if verbose:
92 | logger.info(f"Processing cluster {cluster_idx}")
93 | point_indices = np.sort(clusters[cluster_idx])
94 | if len(point_indices) > 0:
95 | point_feats = torch.tensor(X[point_indices], device=device, dtype=dtype)
96 | _small_centroids, _small_clusters, _, _ = kmg.kmeans(
97 | point_feats,
98 | min(n_splits, len(point_indices)),
99 | n_iters,
100 | chunk_size=-1,
101 | init_method="kmeans++",
102 | dist="l2",
103 | high_precision=high_precision,
104 | )
105 |
106 | _small_clusters = kmg.sort_cluster_by_distance(
107 | point_feats,
108 | _small_centroids,
109 | _small_clusters,
110 | device="cuda",
111 | dtype=dtype,
112 | )
113 | _small_clusters = [point_indices[el.astype(int)] for el in _small_clusters]
114 |
115 | non_empty_clusters = [len(el) > 0 for el in _small_clusters]
116 | _small_clusters = [el for el in _small_clusters if len(el) > 0]
117 | _small_centroids = _small_centroids[non_empty_clusters]
118 |
119 | small_centroids.append(_small_centroids.cpu().numpy())
120 | small_clusters += _small_clusters
121 |
122 | del point_feats
123 | if(
124 | cluster_idx % checkpoint_period == 0 or
125 | cluster_idx == part_indices[rank + 1] - 1
126 | ):
127 | np.save(
128 | Path(save_path, f"split_checkpoint_{rank}.npy"),
129 | {
130 | "small_centroids": small_centroids,
131 | "small_clusters": small_clusters,
132 | "last_index": cluster_idx,
133 | },
134 | )
135 | synchronize()
136 | logger.info("Gathering clusters")
137 | if is_main_process():
138 | centroids = []
139 | clusters = []
140 | for i in tqdm(
141 | range(get_global_size()),
142 |             desc="Gathering split clusters",
143 | file=sys.stdout,
144 | bar_format="{l_bar}{bar}{r_bar}",
145 | ):
146 | split_data = np.load(
147 | Path(save_path, f"split_checkpoint_{i}.npy"),
148 | allow_pickle=True
149 | ).item()
150 | small_centroids = np.concatenate(split_data["small_centroids"])
151 | small_clusters = split_data["small_clusters"]
152 | assert(
153 | len(small_centroids) == len(small_clusters)
154 | ), f"Inconsistent shape in split_checkpoint_{i}.npy"
155 | assert split_data["last_index"] == part_indices[i + 1] - 1
156 | centroids.append(small_centroids)
157 | clusters += small_clusters
158 | centroids = np.concatenate(centroids)
159 | clusters = np.array(clusters, dtype=object)
160 |
161 | logger.info("Saving centroids and clusters")
162 | np.save(Path(save_path, "centroids.npy"), centroids)
163 | np.save(Path(save_path, "sorted_clusters.npy"), clusters)
164 | logger.info("Cleaning checkpoints")
165 | for i in range(get_global_size()):
166 | Path(save_path, f"split_checkpoint_{i}.npy").unlink(missing_ok=True)
167 | logger.info("Finished split_clusters!")
168 |
169 | if __name__ == "__main__":
170 | parser = ArgumentParser()
171 | parser.add_argument("--data_path", type=str, required=True)
172 | parser.add_argument("--subset_indices_path", type=str, default=None)
173 | parser.add_argument("--clusters_path", type=str, required=True)
174 | parser.add_argument("--n_splits", type=int, required=True)
175 | parser.add_argument("--n_iters", type=int, required=True)
176 | parser.add_argument("--dtype", type=str, default="float32")
177 | parser.add_argument("--high_precision", type=str, default="float32")
178 | parser.add_argument("--save_path", type=str, required=True)
179 | parser.add_argument("--use_torchrun", action="store_true")
180 |
181 | args = parser.parse_args()
182 | setup_logging()
183 |
184 | def parse_dtype(dtype):
185 | if dtype == "float32":
186 | return torch.float32
187 | elif dtype == "float64":
188 | return torch.float64
189 | elif dtype == "float16":
190 | return torch.float16
191 | else:
192 |             raise ValueError(f"Value of dtype ({dtype}) not recognised")
193 |
194 | args.dtype = parse_dtype(args.dtype)
195 | args.high_precision = parse_dtype(args.high_precision)
196 |
197 | split_clusters(
198 | args.data_path,
199 | args.subset_indices_path,
200 | args.clusters_path,
201 | args.n_splits,
202 | args.n_iters,
203 | args.dtype,
204 | args.high_precision,
205 | args.save_path,
206 | "cuda",
207 | args.use_torchrun,
208 | )
209 | synchronize()
210 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from setuptools import find_packages, setup
8 |
9 |
10 | setup(
11 | name="ssl_data_curation",
12 | packages=find_packages(),
13 | )
14 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 |
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/src/clusters.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | from pathlib import Path
9 | import pickle
10 | from typing import Dict, List
11 |
12 | import numpy as np
13 |
14 |
15 | logger = logging.getLogger("hkmeans")
16 |
17 |
18 | def load_clusters_from_file(fpath):
19 | """
20 |     Utility to load clusters from different file formats.
21 | """
22 | if Path(fpath).suffix == ".pkl":
23 | with open(fpath, "rb") as f:
24 | return np.array(pickle.load(f), dtype=object)
25 | else:
26 | return np.load(Path(fpath), allow_pickle=True)
27 |
28 | class HierarchicalCluster:
29 | """
30 | Class representing a hierarchy of clusters returned by hierarchical k-means.
31 | """
32 | def __init__(self):
33 | self.cluster_path = None
34 | self.n_levels = None
35 | self.cluster_fname = None
36 | self.is_loaded = False
37 | self.is_processed = False
38 | self.n_clusters = {}
39 | self.clusters = {}
40 | self.flat_clusters = {}
41 | self.clusters_size = {}
42 | self.flat_clusters_size = {}
43 | self.size_order = {}
44 | self.flat_size_order = {}
45 |
46 | def load_clusters_from_file(self):
47 | for level in range(1, 1 + self.n_levels):
48 | self.clusters[level] = load_clusters_from_file(
49 | Path(
50 | self.cluster_path,
51 | f"level{level}",
52 | self.cluster_fname
53 | )
54 | )
55 | self.n_clusters[level] = len(self.clusters[level])
56 | self.is_loaded = True
57 |
58 | def process_clusters(self):
59 | if not self.is_loaded:
60 | raise RuntimeError("Clusters must be loaded before being processed")
61 | logger.info("Computing flat clusters")
62 | self.flat_clusters[1] = self.clusters[1]
63 | for level in range(2, 1 + self.n_levels):
64 | current_non_flat = self.clusters[level]
65 | prev_flat = self.flat_clusters[level - 1]
66 | self.flat_clusters[level] = np.array(
67 | [
68 | np.concatenate([prev_flat[el] for el in clus])
69 | if len(clus) > 0 else np.array([])
70 | for clus in current_non_flat
71 | ],
72 | dtype=object,
73 | )
74 |
75 | logger.info("Computing cluster length")
76 | for level, clus in self.clusters.items():
77 | self.clusters_size[level] = np.array([len(el) for el in clus])
78 |
79 | for level, clus in self.flat_clusters.items():
80 | self.flat_clusters_size[level] = np.array([len(el) for el in clus])
81 |
82 | logger.info("Sorting clusters by length")
83 | for level, clsize in self.clusters_size.items():
84 | self.size_order[level] = np.argsort(clsize)[::-1]
85 |
86 | for level, flat_clsize in self.flat_clusters_size.items():
87 | self.flat_size_order[level] = np.argsort(flat_clsize)[::-1]
88 |
89 | self.is_processed = True
90 |
91 | @staticmethod
92 | def from_file(
93 | cluster_path,
94 | cluster_fname="sorted_clusters.npy",
95 | ):
96 | """
97 | Method for reading hierarchical clusters from files
98 | """
99 | logger.info("Loading hierarchical clusters from file.")
100 | cl = HierarchicalCluster()
101 | cl.cluster_path = cluster_path
102 | cl.cluster_fname = cluster_fname
103 | cl.n_levels = 0
104 | while True:
105 | if Path(cl.cluster_path, f"level{cl.n_levels + 1}").exists():
106 | cl.n_levels += 1
107 | else:
108 | break
109 | cl.load_clusters_from_file()
110 | cl.process_clusters()
111 | return cl
112 |
113 | @staticmethod
114 | def from_dict(clusters: List[Dict]):
115 | """
116 | Read hierarchical clusters from a list of dictionaries.
117 |
118 | Parameters:
119 | clusters: List[Dict]
120 |                 Each element is a dictionary containing a field named "clusters".
121 | An example is the output of hierarchical_kmeans_gpu.hierarchical_kmeans
122 |
123 | Return:
124 |             An instance of HierarchicalCluster.
125 | """
126 | logger.info("Loading hierarchical clusters from dictionaries.")
127 | cl = HierarchicalCluster()
128 | cl.n_levels = len(clusters)
129 | for level in range(1, 1 + cl.n_levels):
130 | cl.clusters[level] = clusters[level - 1]["clusters"]
131 | cl.n_clusters[level] = len(cl.clusters[level])
132 | cl.is_loaded = True
133 | cl.process_clusters()
134 | return cl
135 |
--------------------------------------------------------------------------------
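The nesting convention used by `HierarchicalCluster` is easy to miss when reading the code: level-1 clusters hold data-point indices, level-k clusters hold indices of level-(k-1) clusters, and `flat_clusters` resolves each cluster back to data points. A toy illustration with made-up indices, assuming the repository root is on `PYTHONPATH`:

```python
import numpy as np

from src.clusters import HierarchicalCluster

# Level 1: three clusters over six data points.
level1 = np.array([np.array([0, 1]), np.array([2]), np.array([3, 4, 5])], dtype=object)
# Level 2: two clusters whose entries are level-1 cluster indices, not data points.
level2 = np.array([np.array([0, 2]), np.array([1])], dtype=object)

cl = HierarchicalCluster.from_dict([{"clusters": level1}, {"clusters": level2}])

print(cl.flat_clusters[2][0])    # [0 1 3 4 5] -> data points behind level-2 cluster 0
print(cl.flat_clusters_size[2])  # [5 1]
```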
/define_domains/k-means-clustering/src/hierarchical_kmeans_gpu.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 | import logging
9 | from tqdm import tqdm
10 |
11 | import torch
12 | import numpy as np
13 |
14 | from . import kmeans_gpu as kmg
15 |
16 |
17 | logger = logging.getLogger("hkmeans")
18 | MEMORY_LIMIT = 1e8
19 |
20 |
21 | def hierarchical_kmeans(
22 | data,
23 | n_clusters,
24 | n_levels,
25 | init_method="kmeans++",
26 | num_init=1,
27 | verbose=True
28 | ):
29 | """
30 | Run hierarchical k-means on data without resampling steps.
31 |
32 | Parameters:
33 | data: 2-D numpy array
34 | Data embeddings.
35 | n_clusters: List[int]
36 | Number of clusters for each level of hierarchical k-means
37 | n_levels: int
38 | Number of levels in hierarchical k-means.
39 |         init_method: str, default = "kmeans++"
40 |             Initialization method for k-means centroids.
41 |             Options are "kmeans++" and "random".
42 | num_init: int, default=1
43 | Number of re-initialization for each k-means run.
44 |
45 | Returns:
46 | List[dict], clustering results for each level of hierarchical k-means,
47 | including
48 | centroids: 2-D numpy array
49 | Centroids of clusters.
50 |         assignment: 1-D numpy array
51 | Mapping from data points to cluster indices.
52 | clusters: array of array
53 | pot: float
54 | K-means potential.
55 | """
56 | assert len(n_clusters) == n_levels
57 | logger.info(f"{n_levels}-level hierarchical kmeans")
58 | res = []
59 | for kmid in range(n_levels):
60 | logger.info(f"Level {kmid+1}")
61 | if kmid == 0:
62 | X = data
63 | else:
64 | X = res[kmid - 1]["centroids"]
65 | chunk_size = min(X.shape[0], int(MEMORY_LIMIT / n_clusters[kmid]))
66 | centroids, clusters, cluster_assignment, pot = kmg.kmeans(
67 | X,
68 | n_clusters=n_clusters[kmid],
69 | n_iters=50,
70 | chunk_size=chunk_size,
71 | num_init=num_init,
72 | init_method=init_method,
73 | dist="l2",
74 | high_precision=torch.float64,
75 | random_state=None,
76 | verbose=verbose
77 | )
78 | res.append(
79 | {
80 | "centroids": centroids,
81 | "assignment": cluster_assignment,
82 | "clusters": clusters,
83 | "pot": pot,
84 | }
85 | )
86 | return res
87 |
88 |
89 | def hierarchical_kmeans_with_resampling(
90 | data,
91 | n_clusters,
92 | n_levels,
93 | sample_sizes,
94 | n_resamples=10,
95 | init_method="kmeans++",
96 | num_init=1,
97 | sample_strategy="closest",
98 | verbose=True,
99 | ):
100 | """
101 |     Run hierarchical k-means on data with resampling steps.
102 |
103 | Parameters:
104 | data: 2-D numpy array
105 | Data embeddings.
106 | n_clusters: List[int]
107 | Number of clusters for each level of hierarchical k-means
108 | n_levels: int
109 | Number of levels in hierarchical k-means.
110 |         sample_sizes: List[int]
111 | Number of points to sample from each cluster in resampling steps.
112 | n_resamples: int
113 | Number of resampling steps in each level.
114 |         init_method: str, default = "kmeans++"
115 |             Initialization method for k-means centroids.
116 |             Options are "kmeans++" and "random".
117 | num_init: int, default=1
118 | Number of re-initialization for each k-means run.
119 |         sample_strategy: str, default = "closest"
120 | How to sample points from clusters in resampling steps.
121 | Options are "closest" and "random".
122 |
123 | Returns:
124 | List[dict], clustering results for each level of hierarchical k-means,
125 | including
126 | centroids: 2-D numpy array
127 | Centroids of clusters.
128 |         assignment: 1-D numpy array
129 | Mapping from data points to cluster indices.
130 | clusters: array of array
131 | pot: float
132 | K-means potential.
133 | """
134 | assert len(n_clusters) == n_levels
135 | assert len(sample_sizes) == n_levels
136 | logger.info(f"{n_levels}-level hierarchical kmeans")
137 | res = []
138 | for kmid in range(n_levels):
139 | logger.info(f"Level {kmid+1}")
140 | logger.info("Initial kmeans")
141 | if kmid == 0:
142 | X = data
143 | else:
144 | X = res[kmid - 1]["centroids"]
145 | chunk_size = min(X.shape[0], int(MEMORY_LIMIT / n_clusters[kmid]))
146 | logger.info("Running the initial k-means")
147 | centroids, clusters, cluster_assignment, _ = kmg.kmeans(
148 | X,
149 | n_clusters=n_clusters[kmid],
150 | n_iters=50,
151 | chunk_size=chunk_size,
152 | num_init=num_init,
153 | init_method=init_method,
154 | dist="l2",
155 | high_precision=torch.float64,
156 | random_state=None,
157 | verbose=verbose,
158 | )
159 | logger.info("Resampling-kmeans")
160 | if sample_sizes[kmid] > 1:
161 | _sample_size = sample_sizes[kmid]
162 | for _ in tqdm(
163 | range(n_resamples),
164 | desc="Hierarchical k-means resampling steps",
165 | file=sys.stdout,
166 | bar_format="{l_bar}{bar}{r_bar}",
167 | ):
168 | if sample_strategy == "closest":
169 | sorted_clusters = [
170 | _cluster[
171 | torch.argsort(
172 | torch.cdist(X[_cluster], centroids[i, None])
173 | .flatten()
174 | )
175 | .cpu()
176 | .numpy()
177 | ]
178 | for i, _cluster in enumerate(clusters)
179 | ]
180 | sampled_points = torch.concat(
181 | [
182 | X[_cluster[: _sample_size]]
183 | for _cluster in sorted_clusters
184 | ]
185 | )
186 | elif sample_strategy == "random":
187 | sampled_points = torch.concat(
188 | [
189 | X[
190 | np.random.choice(
191 | _cluster,
192 | min(len(_cluster), _sample_size),
193 | replace=False
194 | )
195 | ]
196 | for _cluster in clusters
197 | ]
198 | )
199 | else:
200 | raise ValueError(
201 | f"sample_strategy={sample_strategy} not supported!"
202 | )
203 | chunk_size = min(
204 | sampled_points.shape[0],
205 | int(MEMORY_LIMIT / n_clusters[kmid])
206 | )
207 | centroids, _, _, _ = kmg.kmeans(
208 | sampled_points,
209 | n_clusters=n_clusters[kmid],
210 | n_iters=50,
211 | chunk_size=chunk_size,
212 | num_init=num_init,
213 | init_method=init_method,
214 | dist="l2",
215 | high_precision=torch.float64,
216 | random_state=None,
217 | verbose=False
218 | )
219 | cluster_assignment = kmg.assign_clusters(
220 | centroids,
221 | X,
222 | "l2",
223 | chunk_size=chunk_size,
224 | verbose=False
225 | ).cpu().numpy()
226 | clusters = kmg.create_clusters_from_cluster_assignment(
227 | cluster_assignment,
228 | n_clusters[kmid]
229 | )
230 | res.append(
231 | {
232 | "centroids": centroids,
233 | "assignment": cluster_assignment,
234 | "clusters": clusters,
235 | "pot": -1,
236 | }
237 | )
238 | return res
239 |
--------------------------------------------------------------------------------
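Besides the resampling variant demonstrated in the README, the plain `hierarchical_kmeans` above takes the same kind of input; a minimal sketch on random data, assuming a CUDA device as in the README example and the repository root on `PYTHONPATH`:

```python
import numpy as np
import torch

from src import hierarchical_kmeans_gpu as hkmg

data = np.random.randn(10_000, 16)

# Two levels without resampling: 200 level-1 clusters regrouped into 50.
res = hkmg.hierarchical_kmeans(
    data=torch.tensor(data, device="cuda", dtype=torch.float32),
    n_clusters=[200, 50],
    n_levels=2,
    verbose=False,
)
print(res[0]["centroids"].shape)  # 200 level-1 centroids of dimension 16
print(len(res[1]["clusters"]))    # 50 level-2 clusters
```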
/define_domains/k-means-clustering/src/hierarchical_sampling.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 | import logging
9 | import random
10 |
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 | from src.clusters import HierarchicalCluster
15 |
16 |
17 | logger = logging.getLogger("hkmeans")
18 |
19 | def random_selection(clusters, valid_clusters, num_per_cluster):
20 | """
21 | Parameters:
22 | clusters: (num_cluster, ) np.array
23 | clusters[i] contain indices of points in cluster i
24 | valid_clusters: list or np.array
25 | indices of clusters that are considered
26 | num_per_cluster: int
27 | number of points selected from each cluster
28 |
29 | Returns:
30 | array containing indices of selected points
31 | """
32 | num_clusters = len(clusters)
33 | selected = [[]] * num_clusters
34 | for cluster_id in tqdm(
35 | valid_clusters,
36 | desc="Random sampling from clusters",
37 | file=sys.stdout,
38 | bar_format="{l_bar}{bar}{r_bar}",
39 | ):
40 | selected[cluster_id] = random.sample(
41 | list(clusters[cluster_id]), min(num_per_cluster, len(clusters[cluster_id]))
42 | )
43 | return np.concatenate(selected).astype(np.int64)
44 |
45 |
46 | def closest_to_centroid_selection(sorted_clusters, valid_clusters, num_per_cluster):
47 | """
48 | Parameters:
49 | sorted_clusters: (num_cluster, ) np.array
50 |             sorted_clusters[i] contains indices of points in cluster i
51 |             indices in sorted_clusters[i] are sorted in increasing distance from the centroid i
52 | valid_clusters: list or np.array
53 | indices of clusters that are considered
54 | num_per_cluster: int, number of points selected from each cluster
55 |
56 | Returns:
57 | array containing indices of selected points
58 | """
59 | num_clusters = len(sorted_clusters)
60 | selected = [[]] * num_clusters
61 | for cluster_id in tqdm(
62 | valid_clusters,
63 | desc="Closest-to-centroid sampling from clusters",
64 | file=sys.stdout,
65 | bar_format="{l_bar}{bar}{r_bar}",
66 | ):
67 | selected[cluster_id] = sorted_clusters[cluster_id][:num_per_cluster]
68 | return np.concatenate(selected).astype(np.int64)
69 |
70 |
71 | def _find_best_cut_left(arr, target):
72 | """
73 | Find integers x such that sum(min(x, arr)) best approximates target
74 | """
75 | if target < 0:
76 | raise ValueError(f"target {target} must be non-negative!")
77 | if np.min(arr) < 0:
78 | raise ValueError("arr has negative elements!")
79 | if np.sum(arr) <= target:
80 | return np.max(arr)
81 | left = 0
82 | right = np.max(arr)
83 | while right - left > 1:
84 | mid = (left + right) // 2
85 | sum_with_mid = np.sum(np.minimum(mid, arr))
86 | if sum_with_mid > target:
87 | right = mid
88 | elif sum_with_mid < target:
89 | left = mid
90 | else:
91 | return mid
92 | if np.sum(np.minimum(right, arr)) <= target:
93 | return right
94 | return left
95 |
96 |
97 | def find_subcluster_target_size(
98 | subcluster_sizes,
99 | target_size,
100 | multiplier,
101 | ):
102 | """
103 |     Given the target number of points to sample from a cluster,
104 | find number of points to sample from its subclusters.
105 | """
106 | if isinstance(subcluster_sizes, np.ndarray):
107 | arr = subcluster_sizes * multiplier
108 | else:
109 | arr = np.array(subcluster_sizes) * multiplier
110 | best_cut_left = _find_best_cut_left(arr, target_size)
111 | if best_cut_left == np.max(arr):
112 | return arr
113 | else:
114 | subcluster_target_sizes = np.minimum(best_cut_left, arr)
115 | remainder = target_size - subcluster_target_sizes.sum()
116 | candidates = np.where(arr > best_cut_left)[0]
117 | subcluster_target_sizes[np.random.choice(candidates, remainder, replace=False)] = best_cut_left + 1
118 | assert subcluster_target_sizes.sum() == target_size
119 | assert np.all(subcluster_target_sizes <= arr)
120 | return subcluster_target_sizes
121 |
122 |
123 | def recursive_hierarchical_sampling(
124 | clusters: HierarchicalCluster,
125 | level: int,
126 | target_size: int,
127 | cl_index: int,
128 | multiplier: int,
129 | sampling_strategy: str = "r",
130 | ):
131 | """
132 | Given a target number of points to sample from a cluster, return
133 |     the set of sampled points.
134 | """
135 | if level == 1:
136 | current_cluster = clusters.clusters[1][cl_index]
137 | current_cluster_size = clusters.clusters_size[1][cl_index]
138 | if current_cluster_size * multiplier <= target_size:
139 | return np.tile(current_cluster, multiplier)
140 | else:
141 | n_replicates = target_size // current_cluster_size
142 | replicates = np.tile(current_cluster, n_replicates)
143 | remaining_target = target_size - n_replicates * current_cluster_size
144 | if sampling_strategy == "r": # random
145 | remaining_samples = np.random.choice(
146 | current_cluster,
147 | remaining_target,
148 | replace=False,
149 | )
150 | elif sampling_strategy == "c": # "closest"
151 | remaining_samples = current_cluster[:remaining_target]
152 | else:
153 | raise ValueError(f"sampling_strategy={sampling_strategy} is not supported")
154 | return np.concatenate([replicates, remaining_samples])
155 | else:
156 | subcl_indices = clusters.clusters[level][cl_index]
157 | subcluster_sizes = clusters.flat_clusters_size[level - 1][subcl_indices]
158 | subcluster_target_sizes = find_subcluster_target_size(
159 | subcluster_sizes,
160 | target_size,
161 | multiplier,
162 | )
163 | samples = []
164 | for i, subcl_index in enumerate(subcl_indices):
165 | samples.append(
166 | recursive_hierarchical_sampling(
167 | clusters,
168 | level - 1,
169 | subcluster_target_sizes[i],
170 | subcl_index,
171 | multiplier,
172 | sampling_strategy,
173 | )
174 | )
175 | return np.concatenate(samples)
176 |
177 |
178 | def hierarchical_sampling(
179 | clusters: HierarchicalCluster,
180 | target_size: int,
181 | multiplier: int = 1,
182 | sampling_strategy: str = "r",
183 | ):
184 | """
185 |     Method for sampling hierarchically from a hierarchy of clusters.
186 | """
187 | if (not clusters.is_loaded) or (not clusters.is_processed):
188 | raise RuntimeError("HierarchicalCluster is not loaded or processed.")
189 | n_levels = clusters.n_levels
190 | cluster_target_sizes = find_subcluster_target_size(
191 | clusters.flat_clusters_size[n_levels],
192 | target_size,
193 | multiplier,
194 | )
195 | samples = []
196 | for cl_index in tqdm(
197 | range(len(clusters.clusters[n_levels])),
198 | desc="Hierarchical sampling from clusters",
199 | file=sys.stdout,
200 | bar_format="{l_bar}{bar}{r_bar}",
201 | ):
202 | samples.append(
203 | recursive_hierarchical_sampling(
204 | clusters,
205 | n_levels,
206 | cluster_target_sizes[cl_index],
207 | cl_index,
208 | multiplier,
209 | sampling_strategy,
210 | )
211 | )
212 | samples = np.concatenate(samples)
213 | return samples
214 |
--------------------------------------------------------------------------------
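The allocation performed by `_find_best_cut_left` and `find_subcluster_target_size` (cap every subcluster at a common per-cluster budget, then distribute any remainder among the larger subclusters) is easiest to follow on a small array; the numbers below are a worked example only, with the repository root assumed on `PYTHONPATH`.

```python
import numpy as np

from src.hierarchical_sampling import find_subcluster_target_size

# Four subclusters with sizes 5, 1, 8, 3 and a total budget of 10 points.
sizes = np.array([5, 1, 8, 3])
alloc = find_subcluster_target_size(sizes, target_size=10, multiplier=1)

# Every subcluster is capped at 3 points, while the size-1 subcluster is kept whole.
print(alloc, alloc.sum())  # [3 1 3 3] 10
```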
/define_domains/k-means-clustering/src/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 | import logging
9 | from pathlib import Path
10 |
11 | import numpy as np
12 | import torch
13 |
14 | from streaming import LocalDataset
15 |
16 |
17 | class MultiMemMap:
18 | def __init__(self, path: str, held_out_shards: int = 0):
19 | """
20 | Parameters:
21 |             path: str, directory containing the .npy embedding shards
22 |             held_out_shards: int, number of trailing shards (in sorted file order) to exclude
23 | """
24 | paths = sorted(Path(path).glob("*.npy"))
25 | self.shards = [
26 | np.load(path, mmap_mode="r")
27 | for path in paths[:len(paths)-held_out_shards]
28 | ]
29 | self.lengths = [
30 | len(shard) for shard in self.shards
31 | ]
32 | self.cum_lengths = np.cumsum(self.lengths)
33 | self.dim = self.shards[0].shape[-1]
34 | self.dtype = self.shards[0].dtype
35 |
36 | def __getitem__(self, ids):
37 | if isinstance(ids, int):
38 | return self.__getitem__([ids])[0]
39 | ids = np.arange(len(self))[ids]
40 |
41 | shard_idx = np.searchsorted(self.cum_lengths, ids, side='right')
42 | results = np.zeros((len(shard_idx), self.dim), dtype=self.dtype)
43 |
44 | for shard_id in np.unique(shard_idx):
45 | ids_mask = shard_idx == shard_id
46 | results[ids_mask] = self.shards[shard_id][ids[ids_mask] - self.cum_lengths[shard_id]]
47 | return results
48 |
49 | def __len__(self):
50 | return self.cum_lengths[-1]
51 |
52 | @property
53 | def shape(self):
54 | return (self.cum_lengths[-1], self.dim)
55 |
56 | def numpy(self):
57 | return self.__getitem__(slice(0, len(self)))
58 |
59 | def to_tensor(self, dtype, device):
60 | return torch.tensor(self.numpy(), device=device, dtype=dtype)
61 |
62 |
63 | class MDSPseudoMemMap(LocalDataset):
64 | def __init__(self, path: str, field="embedding"):
65 | """
66 | Parameters:
67 |             path: str, path to a local MDS dataset
68 |             field: str, name of the sample field that holds the embedding
69 | """
70 | super().__init__(path)
71 | self.field = field
72 |
73 | def __getitem__(self, ids):
74 | result = super().__getitem__(ids)
75 | if isinstance(result, dict):
76 | return result[self.field]
77 | elif isinstance(result[0], dict):
78 | return np.stack([r[self.field] for r in result])
79 | else:
80 | return np.stack(result)
81 |
82 | @property
83 | def shape(self):
84 | return (len(self), len(self[0]))
85 |
86 | def numpy(self):
87 | return self.__getitem__(slice(0, len(self)))
88 |
89 | def to_tensor(self, dtype, device):
90 | return torch.tensor(self.numpy(), device=device, dtype=dtype)
91 |
92 |
93 |
94 | def create_clusters_from_cluster_assignment(
95 | cluster_assignment: np.array,
96 | num_clusters: int,
97 | return_object_array: bool = True,
98 | ):
99 | """
100 |     Build per-cluster arrays of data-point indices from a flat cluster-assignment vector.
101 |     """
102 |     ID = np.argsort(cluster_assignment)
103 |     sorted_cluster_assignment = cluster_assignment[ID]
104 |     index_split = np.searchsorted(sorted_cluster_assignment, list(range(num_clusters)))
105 | clusters = np.split(ID, index_split[1:])
106 | if return_object_array:
107 | return np.array(clusters, dtype=object)
108 | else:
109 | return clusters
110 |
111 |
112 | def find_all_checkpoints(save_dir, pattern):
113 | """
114 | Parameters:
115 | pattern: str
116 |             checkpoint file name format containing a %d iteration placeholder,
117 | e.g., kmpp_checkpoint_%d.pth
118 | """
119 | save_dir = Path(save_dir)
120 | ckpt_list = [str(el.stem) for el in save_dir.glob(pattern.replace("%d", "*"))]
121 | ckpt_list = [int(el.split("_")[-1]) for el in ckpt_list]
122 | ckpt_list = sorted(ckpt_list)
123 | return [Path(save_dir, pattern % el) for el in ckpt_list]
124 |
125 |
126 | def get_last_valid_checkpoint(save_dir, pattern):
127 | """
128 |     Find the path to the most recent checkpoint that loads without error.
129 | """
130 | ckpt_list = find_all_checkpoints(save_dir, pattern)
131 | for ckpt_path in ckpt_list[::-1]:
132 | try:
133 | if ".pth" in pattern:
134 | _ = torch.load(ckpt_path, map_location="cpu")
135 | elif ".npy" in pattern:
136 | _ = np.load(ckpt_path)
137 | else:
138 | raise ValueError("Pattern not recognized!")
139 | return ckpt_path
140 | except Exception:
141 | continue
142 | return None
143 |
144 |
145 | def _delete_old_checkpoint(
146 | save_dir, current_iter, checkpoint_period, max_num_checkpoints, pattern
147 | ):
148 | Path(
149 | save_dir, pattern % (current_iter - checkpoint_period * max_num_checkpoints)
150 | ).unlink(missing_ok=True)
151 |
152 |
153 | def setup_logging(
154 | *,
155 | name: str = None,
156 | level: int = logging.INFO,
157 | capture_warnings: bool = True,
158 | ) -> None:
159 | """
160 |     Basic logger setup: attach a stdout stream handler with a compact prefix format.
161 | """
162 | logging.captureWarnings(capture_warnings)
163 |
164 | logger = logging.getLogger(name)
165 | logger.setLevel(level)
166 |
167 | if logger.hasHandlers():
168 | return
169 |
170 | fmt_prefix = (
171 | "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] "
172 | )
173 | fmt_message = "%(message)s"
174 | fmt = fmt_prefix + fmt_message
175 | datefmt = "%Y%m%d %H:%M:%S"
176 | formatter = logging.Formatter(fmt=fmt, datefmt=datefmt)
177 |
178 | handler = logging.StreamHandler(sys.stdout)
179 | handler.setLevel(level)
180 | handler.setFormatter(formatter)
181 |
182 | logger.propagate = False
183 | logger.addHandler(handler)
184 | return
185 |
--------------------------------------------------------------------------------
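A minimal usage sketch for the utilities above, assuming the repository root is on PYTHONPATH; the shard directory and the toy assignment vector are hypothetical, not taken from the repo:

    import numpy as np
    from src.utils import MultiMemMap, create_clusters_from_cluster_assignment

    # Memory-map a directory of *.npy embedding shards and read rows lazily.
    emb = MultiMemMap("/path/to/embedding_shards")   # hypothetical shard directory
    print(emb.shape)                                 # (total_rows, embedding_dim)
    rows = emb[np.array([0, 5, 10])]                 # rows gathered across shard boundaries

    # Turn a flat assignment vector into per-cluster index arrays.
    assignment = np.array([2, 0, 0, 1, 2])
    clusters = create_clusters_from_cluster_assignment(assignment, num_clusters=3)
    # clusters[0] == array([1, 2]); clusters[1] == array([3]); clusters[2] == array([0, 4])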
/define_domains/k-means-clustering/vis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
--------------------------------------------------------------------------------
/define_domains/k-means-clustering/vis/generalized_kmeans_1d.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 |
9 | import numpy as np
10 | import torch
11 | from tqdm import tqdm
12 | from sklearn.utils import check_random_state
13 |
14 | from src import kmeans_gpu as kmg
15 | from src.utils import create_clusters_from_cluster_assignment
16 |
17 |
18 | def l2_squared_power(x, xi, n):
19 | """
20 | Compute L_2 ^ (2 * n) distance
21 | """
22 | return (x - xi) ** (2 * n)
23 |
24 |
25 | def l2_squared_power_der(x, xi, n):
26 | """
27 | Compute the derivative of L_2 ^ (2 * n) distance
28 | """
29 | return 2 * n * (x - xi) ** (2 * n - 1)
30 |
31 |
32 | def l2_squared_power_der2(x, xi, n):
33 | """
34 | Compute second-order derivative of L_2 ^ (2 * n) distance
35 | """
36 | return 2 * n * (2 * n - 1) * (x - xi) ** (2 * n - 2)
37 |
38 |
39 | def kmeans_plusplus(
40 | X,
41 | n_clusters,
42 | x_squared_norms,
43 | dist,
44 | power=1,
45 | random_state=None,
46 | n_local_trials=None,
47 | save_running_results=False,
48 | high_precision=torch.float32,
49 | verbose=False,
50 | ):
51 | """
52 | Computational component for initialization of n_clusters by
53 | k-means++. Prior validation of data is assumed.
54 | Parameters
55 | ----------
56 | X : torch.tensor of shape (n_samples, n_features)
57 | The data to pick seeds for.
58 | n_clusters : int
59 | The number of seeds to choose.
60 | x_squared_norms : torch.tensor (n_samples,)
61 | Squared Euclidean norm of each data point.
62 | dist: str
63 | Type of distance function. Options are "l2" or "cos".
64 | power: int
65 | Distance is L_2 ^ (2 * power).
66 | random_state : RandomState instance
67 | The generator used to initialize the centers.
68 |         See :term:`Glossary <random_state>`.
69 | n_local_trials : int, default=None
70 | The number of seeding trials for each center (except the first),
71 | of which the one reducing inertia the most is greedily chosen.
72 | Set to None to make the number of trials depend logarithmically
73 | on the number of seeds (2+log(k)); this is the default.
74 | save_running_results: bool, default=False
75 | Whether to save temporary results during execution.
76 | high_precision: torch.Type
77 | type for high-precision computations.
78 |     verbose: bool, default=False; if True, display a progress bar during seeding.
79 |
80 | Returns
81 | -------
82 | centers : torch.tensor of shape (n_clusters, n_features)
83 | The initial centers for k-means.
84 | indices : ndarray of shape (n_clusters,)
85 | The index location of the chosen centers in the data array X. For a
86 | given index and center, X[index] = center.
87 | """
88 |     # check_random_state accepts None, an int seed, or a RandomState instance
89 |     random_state = check_random_state(random_state)
90 |
91 | n_samples, n_features = X.shape
92 |
93 | centers = torch.empty((n_clusters, n_features), dtype=X.dtype).to(X.device)
94 | pots = torch.empty((n_clusters,), device=X.device, dtype=high_precision)
95 |
96 | # Set the number of local seeding trials if none is given
97 | if n_local_trials is None:
98 | n_local_trials = 2 + int(np.log(n_clusters))
99 |
100 | # Pick first center randomly and track index of point
101 | center_id = random_state.randint(n_samples)
102 | indices = np.full(n_clusters, -1, dtype=int)
103 | centers[0] = X[center_id]
104 | indices[0] = center_id
105 |
106 | # Initialize list of closest distances and calculate current potential
107 | closest_dist_sq = (
108 | kmg.compute_distance(X[center_id, None], X, x_squared_norms, dist)[0].type(
109 | high_precision
110 | )
111 | ** power
112 | )
113 | current_pot = closest_dist_sq.sum()
114 | pots[0] = current_pot
115 |
116 | # Pick the remaining n_clusters-1 points
117 | if verbose:
118 | iterates = tqdm(
119 | range(1, n_clusters),
120 |             desc="Generalized kmeans++ initialization",
121 | file=sys.stdout,
122 | bar_format="{l_bar}{bar}{r_bar}",
123 | )
124 | else:
125 | iterates = range(1, n_clusters)
126 | for c in iterates:
127 | # Choose center candidates by sampling with probability proportional
128 | # to the distance to the closest existing center
129 | rand_vals = (
130 | torch.tensor(random_state.uniform(size=n_local_trials)).to(
131 | current_pot.device
132 | )
133 | * current_pot
134 | )
135 | candidate_ids = torch.searchsorted(
136 | torch.cumsum(closest_dist_sq, dim=0), rand_vals
137 | )
138 | # numerical imprecision can result in a candidate_id out of range
139 | torch.clip(candidate_ids, None, closest_dist_sq.shape[0] - 1, out=candidate_ids)
140 |
141 | # Compute distances to center candidates
142 | distance_to_candidates = (
143 | kmg.compute_distance(X[candidate_ids], X, x_squared_norms, dist).type(
144 | high_precision
145 | )
146 | ** power
147 | )
148 |
149 | # update closest distances squared and potential for each candidate
150 | torch.minimum(
151 | closest_dist_sq, distance_to_candidates, out=distance_to_candidates
152 | )
153 | candidates_pot = distance_to_candidates.sum(dim=1)
154 |
155 | # Decide which candidate is the best
156 | best_candidate = torch.argmin(candidates_pot)
157 | current_pot = candidates_pot[best_candidate]
158 | closest_dist_sq = distance_to_candidates[best_candidate]
159 | best_candidate = candidate_ids[best_candidate]
160 |
161 | # Permanently add best center candidate found in local tries
162 | centers[c] = X[best_candidate]
163 | indices[c] = best_candidate
164 | pots[c] = current_pot
165 |
166 | if save_running_results and c % 1000 == 0:
167 | np.save(
168 | "kmpp_running_results.npy",
169 | {"centers": centers.cpu().numpy(), "indices": indices, "iter": c},
170 | )
171 |
172 | return centers, indices
173 |
174 |
175 | def compute_centroids(X, n, n_iters=5, method="newton", verbose=False):
176 | """
177 |     Compute the centroid of a set of 1-D points under the distortion f(c) = sum_i (c - x_i)^(2n),
178 |     using Newton's method, c <- c - f'(c) / f''(c), initialized at the ordinary (L_2^2) mean.
179 | """
180 | if method == "newton":
181 | # Initialize the centroid with L_2^2 means.
182 | c = X.mean()
183 | if len(X) == 1:
184 | return c
185 | for _ in range(n_iters):
186 | if verbose:
187 | f = torch.sum(l2_squared_power(c, X, n))
188 | print(f, end=", ")
189 | der_f = torch.sum(l2_squared_power_der(c, X, n))
190 | der2_f = torch.sum(l2_squared_power_der2(c, X, n))
191 | if der_f == 0:
192 | break
193 | c -= der_f / der2_f
194 | return c
195 | else:
196 | raise ValueError("Method not supported!")
197 |
198 |
199 | def assign_clusters(X, centers, chunk_size=-1):
200 | """
201 | Assign points to centroids.
202 | """
203 | cluster_assignment = (
204 | kmg.assign_clusters(centers, X, "l2", chunk_size=chunk_size, verbose=False)
205 | .cpu()
206 | .numpy()
207 | )
208 | clusters = create_clusters_from_cluster_assignment(cluster_assignment, len(centers))
209 | return clusters
210 |
211 |
212 | def update_centroids(X, clusters, n):
213 | """
214 | Update centroids based on the new clusters after reassignment.
215 | """
216 | n_clusters = len(clusters)
217 | centers = torch.zeros((n_clusters, 1), device=X.device, dtype=X.dtype)
218 | for cid in range(n_clusters):
219 | if len(clusters[cid]) > 0:
220 | centers[cid, 0] = compute_centroids(X[clusters[cid]], n).item()
221 | return centers
222 |
223 |
224 | def generalized_kmeans_1d(
225 | X, n_clusters, n, n_iters=50, init_method="k-means++", chunk_size=-1
226 | ):
227 | """
228 | Run generalized k-means with distance L_2 ^ (2 * n)
229 | """
230 | assert X.ndim == 2
231 | # initialize
232 | if init_method == "k-means++":
233 | x_squared_norms = torch.linalg.vector_norm(X, dim=1) ** 2
234 | centers, _ = kmeans_plusplus(X, n_clusters, x_squared_norms, "l2", n)
235 | else:
236 | centers = X[np.random.choice(len(X), n_clusters, replace=False), :]
237 | clusters = assign_clusters(X, centers, chunk_size=chunk_size)
238 | for _ in tqdm(
239 | range(n_iters),
240 | desc="Generalized kmeans iterations",
241 | file=sys.stdout,
242 | bar_format="{l_bar}{bar}{r_bar}",
243 | ):
244 | centers = update_centroids(X, clusters, n)
245 | clusters = assign_clusters(X, centers, chunk_size=chunk_size)
246 | return centers, clusters
247 |
--------------------------------------------------------------------------------
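A minimal usage sketch for the 1-D generalized k-means above. The data, number of clusters, and power n are illustrative; the import assumes the repository root is on PYTHONPATH, and depending on how the kmeans_gpu helpers are configured the tensors may need to live on a GPU:

    import torch
    from vis.generalized_kmeans_1d import generalized_kmeans_1d

    # 1000 scalar values as an (n_samples, 1) tensor, e.g. per-document scores.
    X = torch.rand(1000, 1)

    # Cluster into 8 groups under the L_2^(2*n) distortion with n = 2
    # (n = 1 recovers standard squared-Euclidean k-means).
    centers, clusters = generalized_kmeans_1d(X, n_clusters=8, n=2, n_iters=20)

    # centers: tensor of shape (8, 1); clusters[i]: indices of points assigned to center i.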
/define_domains/prompt_classify.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | OUTLINES_CACHE_DIR=/tmp/outlines python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000
4 | python prompt_classify