├── snip_dedup ├── __init__.py ├── cli.py ├── _cli_helper.py ├── snip_download.py ├── snip_index.py └── snip_compress.py ├── ruff.toml ├── hatch.toml ├── LICENSE ├── pyproject.toml ├── docs └── tuto_build_index.md ├── retrieve_dup_urls_demo.py ├── .github └── workflows │ └── ci.yml ├── .gitignore └── README.md /snip_dedup/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | # Config file for the ruff linting tool 2 | # https://github.com/charliermarsh/ruff 3 | 4 | # Rules to ignore 5 | ignore = [ 6 | "E501", # line length violations 7 | ] 8 | -------------------------------------------------------------------------------- /snip_dedup/cli.py: -------------------------------------------------------------------------------- 1 | """cli entry point""" 2 | 3 | from snip_dedup.snip_download import snip_download 4 | from snip_dedup.snip_compress import snip_compress 5 | from snip_dedup.snip_index import snip_index 6 | import fire 7 | 8 | 9 | def main(): 10 | """Main entry point""" 11 | fire.Fire( 12 | { 13 | "download": snip_download, 14 | "compress": snip_compress, 15 | "index": snip_index, 16 | } 17 | ) 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /hatch.toml: -------------------------------------------------------------------------------- 1 | # Config file specific to the hatch project manager. 2 | # Contains build rules and other scripts. 3 | 4 | # Build source distribution 5 | [build.targets.sdist] 6 | exclude = [ 7 | "/.github", 8 | "/docs", 9 | ] 10 | 11 | # Build wheel distribution 12 | [build.targets.wheel] 13 | packages = ["snip_dedup"] 14 | 15 | # Check linting and formatting 16 | # These scripts need "pyright", "ruff" and "black" to be installed 17 | [envs.default.scripts] 18 | check = "pyright --warnings snip_dedup/" 19 | lint = "ruff check snip_dedup/" 20 | format = "black --check snip_dedup/" 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Ryan Webster, Matthieu Pizenberg 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "snip-dedup" 7 | version = "0.0.4" 8 | description = 'SNIP: compact index for large dataset' 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | license = "MIT" 12 | keywords = ["snip", "deduplicate", "index", "laion", "machine learning", "computer vision", "dataset"] 13 | authors = [ 14 | { name = "Ryan Webster", email = "rwebstr@gmail.com" }, 15 | { name = "Matthieu Pizenberg", email = "matthieu.pizenberg@gmail.com" }, 16 | ] 17 | classifiers = [ 18 | "Environment :: GPU :: NVIDIA CUDA", 19 | "License :: OSI Approved :: MIT License", 20 | "Programming Language :: Python :: 3", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | ] 23 | dependencies = [ 24 | "fastparquet === 2023.2.0", 25 | "fire == 0.5.*", 26 | "numpy >= 1.24.2, < 2.0", 27 | "pandas >= 1.5.3, < 2.0", 28 | "requests >= 2.28.2, < 3.0", 29 | "torch >= 1.13.1, < 2.0", 30 | "faiss-gpu >= 1.7.2, < 2.0" 31 | ] 32 | 33 | [project.urls] 34 | Documentation = "https://github.com/ryanwebster90/snip-dedup#readme" 35 | Source = "https://github.com/ryanwebster90/snip-dedup" 36 | 37 | [project.scripts] 38 | snip = "snip_dedup.cli:main" 39 | -------------------------------------------------------------------------------- /snip_dedup/_cli_helper.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import string 3 | 4 | 5 | def validate_parts(parts): 6 | if type(parts) is int: 7 | sys.exit( 8 | "Single value is not accepted for --parts as it's ambiguous between wanting only that exact part or that number of parts starting from 0. Please use a range instead like 0:2" 9 | ) 10 | parts_bounds = parts.split(":") 11 | try: 12 | parts_bounds = [int(part) for part in parts_bounds] 13 | except Exception: 14 | sys.exit( 15 | f'The parts pattern "{parts}" is not valid. It should be a valid range such as "0:2" or "14:42"' 16 | ) 17 | if len(parts_bounds) == 0: 18 | sys.exit("The --parts argument cannot be empty") 19 | elif len(parts_bounds) == 1: 20 | sys.exit( 21 | "Single value is not accepted for --parts as it's ambiguous between wanting only that exact part or that number of parts starting from 0. Please use a range instead like 0:2" 22 | ) 23 | elif len(parts_bounds) > 2: 24 | sys.exit( 25 | "Ranges with more than 2 parts, such as 0:2:14 are not valid for --parts. Please limit yourself with simple ranges such as 0:14" 26 | ) 27 | start_part, end_part = parts_bounds 28 | if start_part < 0 or end_part < 0: 29 | sys.exit("Only positive integers are allowed for --parts, such as 0:14") 30 | if end_part <= start_part: 31 | sys.exit( 32 | 'The --parts argument must be of the shape "s:e" with s < e, such as "0:1" or "14:42". The "e" bound is excluded.' 33 | ) 34 | return start_part, end_part 35 | 36 | 37 | def validate_part_format(pattern): 38 | format_variables = [ 39 | tup[1] for tup in string.Formatter().parse(pattern) if tup[1] is not None 40 | ] 41 | if len(format_variables) != 1 or format_variables[0] != "part": 42 | sys.exit( 43 | f'Your pattern "{pattern}" is not valid as it should contain the "part" variable such as "{{part:04d}}.npy".' 
44 |         )
45 | 
--------------------------------------------------------------------------------
/snip_dedup/snip_download.py:
--------------------------------------------------------------------------------
 1 | """snip download"""
 2 | 
 3 | import requests
 4 | import os
 5 | import os.path
 6 | import fire
 7 | import numpy as np
 8 | import pandas as pd
 9 | 
10 | 
11 | def snip_download(outfolder="data/downloaded", start=0, end=2313, dl_dedup_set=True):
12 |     """Download and deduplicate a dataset.
13 | 
14 |     Parameters
15 |     ----------
16 |     outfolder : str, optional
17 |         Where to put the downloaded metadata
18 |     start : int, optional
19 |         Start index of the metadata
20 |     end : int, optional
21 |         End index of the metadata (exclusive)
22 |     dl_dedup_set : bool, optional
23 |         Whether to (re-)download the dedup set (~2 GB)
24 |     """
25 |     metadata_dir = os.path.join(outfolder, "metadata")
26 |     dedup_set_path = os.path.join(
27 |         outfolder, "is_dup_mlp_1024_128_gelu_snn_2layer_notext.npy"
28 |     )
29 |     os.makedirs(metadata_dir, exist_ok=True)
30 | 
31 |     if dl_dedup_set:
32 |         print("downloading dedup set...")
33 |         url = "https://huggingface.co/datasets/fraisdufour/snip-dedup/resolve/main/is_dup_mlp_1024_128_gelu_snn_2layer_notext.npy"
34 |         response = requests.get(url)
35 |         open(dedup_set_path, "wb").write(response.content)
36 | 
37 |     is_dup_all = np.load(dedup_set_path).ravel()
38 |     abs_ind = 0
39 |     for n in range(start, end):
40 |         print(f"downloading metadata file {n}/{end}")
41 |         url = f"https://huggingface.co/datasets/laion/laion2b-en-vit-h-14-embeddings/resolve/main/metadata/metadata_{n:04d}.parquet"
42 |         response = requests.get(url)
43 |         parquet_path = os.path.join(metadata_dir, f"metadata_{n:04d}.parquet")
44 |         open(parquet_path, "wb").write(response.content)
45 | 
46 |         # perform the deduplication
47 |         md = pd.read_parquet(parquet_path)
48 |         chunk_len = len(md.index)  # number of rows before filtering
49 |         non_dup_chunk = is_dup_all[abs_ind : abs_ind + chunk_len]
50 |         # take only non-dupped (uniques)
51 |         non_dup_chunk = np.logical_not(non_dup_chunk)
52 | 
53 |         # make sure there is at least one unique
54 |         non_dup_chunk[0] = True
55 |         md = md[non_dup_chunk]
56 | 
57 |         # overwrite metadata
58 |         md.to_parquet(parquet_path)
59 |         abs_ind += chunk_len  # advance by the original chunk length to stay aligned with is_dup_all
60 | 
61 | 
62 | if __name__ == "__main__":
63 |     fire.Fire(snip_download)
64 | 
--------------------------------------------------------------------------------
/docs/tuto_build_index.md:
--------------------------------------------------------------------------------
 1 | # Build your own SNIP index from CLIP features
 2 | 
 3 | This tutorial is a friendly tour of the `snip` commands used to build an index.
 4 | For this tutorial, we use the `laion-2b-en-vit-l-14` dataset.
 5 | 
 6 | Let's start by creating a dedicated virtual environment for this tutorial.
 7 | 
 8 | ```sh
 9 | # Create and activate a virtual environment
10 | mkdir snip_index_tuto
11 | cd snip_index_tuto
12 | python -m venv snip_tuto_venv
13 | source snip_tuto_venv/bin/activate # adapt to your OS/shell
14 | ```
15 | 
16 | Let's continue by installing the `snip` command from the `snip-dedup` package. 
17 | 18 | ```sh 19 | # Install snip 20 | pip install snip-dedup 21 | snip --help 22 | ``` 23 | 24 | Alright, now we need to download the pre-required files for this tutorial: 25 | 26 | - The `laion-2b-en-vit-l-14` CLIP embeddings 27 | - The SNIP corresponding model 28 | - The SNIP base index for that model 29 | 30 | ```sh 31 | # Create directory structure for required files to download 32 | mkdir laion-2b-en-vit-l-14 33 | cd laion-2b-en-vit-l-14 34 | mkdir snip_models 35 | mkdir clip_feats 36 | 37 | # Download the CLIP features 38 | cd clip_feats 39 | for i in $(seq -f "%04g" 0 1); do curl -fLO "https://huggingface.co/datasets/laion/laion2b-en-vit-l-14-embeddings/resolve/main/img_emb/img_emb_$i.npy"; done 40 | 41 | # Download the SNIP model and base index 42 | cd ../snip_models 43 | curl -fLO https://huggingface.co/datasets/fraisdufour/snip-dedup/resolve/main/models/snip_vitl14_128_deep.pth 44 | curl -fLO https://huggingface.co/datasets/fraisdufour/snip-dedup/resolve/main/index/snip_vitl14_deep_IVFPQ_M4_base.index 45 | ``` 46 | 47 | We are now ready to use the `snip` commands. 48 | We start by compressing the CLIP features with SNIP. 49 | 50 | ```sh 51 | # Compress CLIP features with SNIP 52 | cd .. 53 | snip compress --help # display the help for the snip compress command 54 | snip compress \ 55 | --snip_model_path snip_models/snip_vitl14_128_deep.pth \ 56 | --parts 0:2 \ 57 | --clip_feats clip_feats/img_emb_{part:04d}.npy \ 58 | --snip_feats_out snip_feats/{part:04d}.npy 59 | ``` 60 | 61 | Finally, after compressing with SNIP, we can build our index. 62 | Since the index is much smaller than the features, we can group multiple parts in each index shard. 63 | In this example, we group them by 2 with `--shard_size 2` 64 | 65 | ```sh 66 | # Build the index for the SNIP features 67 | snip index --help # display the help for the snip index command 68 | snip index \ 69 | --parts 0:2 \ 70 | --snip_feats snip_feats/{part:04d}.npy \ 71 | --snip_base_index_path snip_models/snip_vitl14_deep_IVFPQ_M4_base.index \ 72 | --index_outdir snip_index \ 73 | --shard_size 2 74 | # will build file snip_index/0000_0001.index 75 | ``` 76 | 77 | That's it! 78 | You now have a sharded compressed SNIP index for the laion2b-vit-l-14 model. 79 | -------------------------------------------------------------------------------- /snip_dedup/snip_index.py: -------------------------------------------------------------------------------- 1 | """snip index""" 2 | 3 | import sys 4 | import os 5 | import os.path 6 | import fire 7 | import numpy as np 8 | import faiss 9 | 10 | from . import _cli_helper 11 | 12 | 13 | def snip_index( 14 | parts="0:2", 15 | snip_feats="snip_feats/{part:04d}.npy", 16 | snip_base_index_path="snip_models/snip_vitl14_deep_IVFPQ_M4_base.index", 17 | index_outdir="snip_index", 18 | shard_size=1, 19 | ): 20 | """Build a sharded index from SNIP compressed features 21 | 22 | Parameters 23 | ---------- 24 | parts : str 25 | Parts to index, using a slice notation, such as 0:2 or 14:42 26 | snip_feats : str 27 | Pattern referencing the path to the SNIP features parts. 28 | You are expected to use the "part" variable with formatting options, such as "{part:03d}.npy" which will be replaced by "001.npy" when part==1. 29 | snip_base_index_path : str 30 | Path to the base index, might be something like: snip_models/snip_vitl14_deep_IVFPQ_M4_base.index 31 | index_outdir : str 32 | Directory where the computed index parts will be saved. 33 | shard_size : int 34 | Number of SNIP parts to group per index shard. 
35 | Since the index is much smaller than the features, we can pack many feature parts in a single index shard. 36 | """ 37 | # Check that the path for the base index passed as argument is valid 38 | if not os.path.isfile(snip_base_index_path): 39 | sys.exit( 40 | f'The base index file "{snip_base_index_path}" does not exist or is not readable.' 41 | ) 42 | 43 | # Check that the parts argument is correct 44 | start_part, end_part = _cli_helper.validate_parts(parts) 45 | 46 | # Check that the SNIP features exist 47 | _cli_helper.validate_part_format(snip_feats) 48 | 49 | # Check that the starting SNIP feature part is a multiple of the index shard size. 50 | # Otherwise, it's probably an off-by-one mistake 51 | if start_part % shard_size != 0: 52 | sys.exit( 53 | f"WARNING: your starting SNIP part ({start_part}) is not a multiple of your packing argument ({shard_size}). You might be doing a mistake so please double check." 54 | ) 55 | 56 | # TODO: add option for cpu (will be quite slow however) 57 | res = faiss.StandardGpuResources() 58 | 59 | # Create the output directory for computed index shards 60 | os.makedirs(index_outdir, exist_ok=True) 61 | 62 | # Group parts into shards 63 | parts_range = list(range(start_part, end_part)) 64 | grouped_parts = [ 65 | parts_range[i : i + shard_size] for i in range(0, len(parts_range), shard_size) 66 | ] 67 | 68 | # Load SNIP base index 69 | base_index = faiss.read_index(snip_base_index_path) 70 | 71 | # Compute an index for each SNIP parts group 72 | for parts in grouped_parts: 73 | index = faiss.index_cpu_to_gpu(res, 0, base_index) 74 | for snip_part_id in parts: 75 | print(f"Indexing SNIP part {snip_part_id} ...") 76 | snip_part = np.load(snip_feats.format(part=snip_part_id)) 77 | index.add(snip_part) 78 | group_str = "_".join([f"{id:04d}" for id in parts]) 79 | print(f"Writing index for parts {parts} ...") 80 | faiss.write_index( 81 | faiss.index_gpu_to_cpu(index), 82 | os.path.join(index_outdir, f"{group_str}.index"), 83 | ) 84 | 85 | 86 | if __name__ == "__main__": 87 | fire.Fire(snip_index) 88 | -------------------------------------------------------------------------------- /snip_dedup/snip_compress.py: -------------------------------------------------------------------------------- 1 | """snip compress""" 2 | import os 3 | import os.path 4 | import sys 5 | import fire 6 | import numpy as np 7 | import torch 8 | 9 | from . import _cli_helper 10 | 11 | 12 | # compute features over chunks 13 | @torch.no_grad() 14 | def compute_feats_for_chunk(net, chunk, batch_size=256): 15 | feats = [] 16 | for b in range(0, chunk.shape[0], batch_size): 17 | end_ind = min(b + batch_size, chunk.shape[0]) 18 | batch = chunk[b:end_ind, :] 19 | feats += [net(torch.from_numpy(batch).float().cuda()).cpu().numpy()] 20 | feats = np.concatenate(feats, axis=0) 21 | return feats 22 | 23 | 24 | def snip_compress( 25 | snip_model_path="snip_models/snip_vitl14_128_deep.pth", 26 | parts="0:2", 27 | clip_feats="clip_feats/{part:04d}.npy", 28 | snip_feats_out="snip_feats/{part:04d}.npy", 29 | ): 30 | """Compress frozen CLIP features with SNIP 31 | 32 | Parameters 33 | ---------- 34 | snip_model_path : str 35 | Path to the SNIP model file to use, might be something like: snip_models/snip_vitl14_128_deep.pth 36 | parts : str 37 | Parts to compress, using a slice notation, such as 0:2 or 14:42 38 | clip_feats : str 39 | Pattern referencing the path to the CLIP features parts. 
40 | You are expected to use the "part" variable with formatting options, such as "{part:03d}.npy" which will be replaced by "001.npy" when part==1. 41 | snip_feats_out : str 42 | Pattern referencing the path to the SNIP compressed features parts that will be computed. 43 | You are expected to use the "part" variable with formatting options, such as "{part:03d}.npy" which will be replaced by "001.npy" when part==1. 44 | """ 45 | # Check that the SNIP model path passed as argument is valid 46 | if not os.path.isfile(snip_model_path): 47 | sys.exit( 48 | f'The SNIP model file "{snip_model_path}" does not exist or is not readable.' 49 | ) 50 | 51 | # Check that the parts argument is correct 52 | start_part, end_part = _cli_helper.validate_parts(parts) 53 | 54 | # Check that the CLIP features exist 55 | _cli_helper.validate_part_format(clip_feats) 56 | for part in range(start_part, end_part): 57 | clip_part_path = clip_feats.format(part=part) 58 | if not os.path.isfile(clip_part_path): 59 | sys.exit( 60 | f'The CLIP file for part {part} does not exist: "{clip_part_path}"' 61 | ) 62 | 63 | # Create directory for the computed SNIP parts 64 | _cli_helper.validate_part_format(snip_feats_out) 65 | try: 66 | first_snip_part_path = snip_feats_out.format(part=start_part) 67 | snip_out_parent_dir = os.path.dirname(first_snip_part_path) 68 | os.makedirs(snip_out_parent_dir, exist_ok=True) 69 | except Exception: 70 | sys.exit( 71 | f'Something is wrong with the output file paths specified for --snip_feats_out "{snip_feats_out}"' 72 | ) 73 | 74 | # Load SNIP net 75 | net = torch.load(snip_model_path).eval().cuda() 76 | 77 | # Compute SNIP features for all parts 78 | print( 79 | f"Start computing SNIP features for parts {start_part} to {end_part} (excluded)" 80 | ) 81 | for part in range(start_part, end_part): 82 | print(f" Computing SNIP features for part {part} ...") 83 | # this is normally the bottleneck 84 | clip_part_path = clip_feats.format(part=part) 85 | clip_part = np.load(clip_part_path) 86 | # compute SNIP features 87 | snip_feats = compute_feats_for_chunk(net, clip_part) 88 | snip_part_path = snip_feats_out.format(part=part) 89 | np.save(snip_part_path, snip_feats) 90 | 91 | 92 | if __name__ == "__main__": 93 | fire.Fire(snip_compress) 94 | -------------------------------------------------------------------------------- /retrieve_dup_urls_demo.py: -------------------------------------------------------------------------------- 1 | import faiss 2 | import numpy as np 3 | import glob 4 | import time 5 | import torch 6 | import fire 7 | def abs_ind_to_feat_file(abs_ind, cum_sz, feat_files=None): 8 | inds = np.argwhere(abs_ind - cum_sz >= 0) 9 | last_ind = inds[-1].item() 10 | ind_offset = cum_sz[last_ind] 11 | local_ind = abs_ind - ind_offset 12 | if feat_files is not None: 13 | ff = feat_files[last_ind] 14 | else: 15 | ff=None 16 | return ff,last_ind,local_ind 17 | 18 | def get_cum_sz(feat_files): 19 | cum_sz = [0] 20 | for feat in feat_files: 21 | cum_sz += [cum_sz[-1] + np.load(feat,mmap_mode='r').shape[0]] 22 | cum_sz = np.array(cum_sz).astype('int') 23 | return cum_sz 24 | 25 | def get_emb(ff,local_ind): 26 | return np.load(ff,mmap_mode='r')[local_ind,:] 27 | 28 | def retrieve_duplicate_urls(feats_path, metadata_path,net_path,index_path='mlp_1024_128_gelu_snn_2layer_notext_l2b_vith14_merged.index',cum_sz_file='cum_sz_feats.npy',dup_file = 'is_dup_mlp_1024_128_gelu_snn_2layer_notext.npy'): 29 | 30 | index = faiss.read_index(index_path) 31 | index.nprobe = 1 32 | print('index loaded') 33 | 34 | 
net = torch.load(net_path).eval().cuda() 35 | 36 | import pandas as pd 37 | feat_files = sorted(glob.glob(feats_path + '*npy')) 38 | md = sorted(glob.glob(metadata_path + '*.parquet')) 39 | 40 | is_dup_all = np.load(dup_file) 41 | cum_sz = get_cum_sz(feat_files) 42 | 43 | n_eval = 1000 44 | inds = np.argwhere(is_dup_all).ravel() 45 | 46 | r_sample = np.random.randint(0,inds.shape[0], (n_eval,)) 47 | inds = inds[r_sample] 48 | md_text = open('duplicate_url_pairs.txt','a+') 49 | 50 | thresh_raw = 1e-1 51 | all_tf = np.full( (n_eval,),False,dtype=bool) 52 | for ii,k in enumerate(inds): 53 | ff,li,lci = abs_ind_to_feat_file(k,cum_sz,feat_files) 54 | if li < len(md): 55 | try: 56 | # certain metadata entries throw errors, not sure why 57 | url = list(pd.read_parquet(md[li])["url"])[lci] 58 | except Exception: 59 | url = None 60 | else: 61 | # note this won't happen if you have all the metadata 62 | continue 63 | 64 | raw_feat = get_emb(ff,lci).reshape(1,-1) 65 | 66 | with torch.no_grad(): 67 | feat_snip = net(torch.from_numpy(raw_feat).float().cuda()).cpu().numpy() 68 | 69 | d,i = index.search(feat_snip,6) 70 | nn = i[0,1] 71 | if nn == k: 72 | print('same nn retrieved, skipping...') 73 | # Note, this does not effect our de-dup precision 74 | # but just an artifact of bitwise duplicates, will be fixed later 75 | 76 | # only fetch for metadata, if you have all feature file syou can enable gt computation 77 | ff,li,lci = abs_ind_to_feat_file(nn, cum_sz, None) 78 | 79 | # if you have all the feats, go ahead and compute "ground truth" 80 | # was used for our precision calculation (see paper) 81 | # nn_feat = get_emb(ff,lci) 82 | # mse = ((raw_feat - nn_feat)**2).sum() 83 | 84 | # is_dup = 'gt dup' if mse < thresh_raw else 'nondup' 85 | if li < len(md): 86 | try: 87 | if url is not None: 88 | # md_text.write(is_dup + f'\n') 89 | url1 = list(pd.read_parquet(md[li])["url"])[lci] 90 | if url1 is not None: 91 | md_text.write(url + '\n') 92 | md_text.write(url1 + '\n') 93 | md_text.write('\n') 94 | except Exception: 95 | print('failed to parquet') 96 | 97 | 98 | if __name__ == "__main__": 99 | fire.Fire(retrieve_duplicate_urls) 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [ main ] 7 | 8 | env: 9 | PYTHON_VERSION: '3.8' 10 | TOOLING_CACHE_KEY: '006' # increment to force new cache creation 11 | PIPX_HOME: /home/runner/.local/pipx 12 | PIPX_BIN_DIR: /home/runner/.local/bin 13 | 14 | jobs: 15 | tooling-install: 16 | name: Install and cache tooling for all other jobs 17 | runs-on: ubuntu-latest 18 | steps: 19 | # install python 20 | - uses: actions/setup-python@v4 21 | with: 22 | python-version: ${{ env.PYTHON_VERSION }} 23 | 24 | # cache tooling installation 25 | - uses: actions/cache@v3 26 | with: 27 | path: | 28 | ~/.cache/pip 29 | ~/.local/pipx 30 | ~/.local/bin 31 | key: tooling-${{ env.TOOLING_CACHE_KEY }} 32 | 33 | # install tooling via pipx to have them isolated 34 | - name: Install pipx & tooling 35 | run: | 36 | python -m pip install --user pipx 37 | python -m pipx ensurepath 38 | pipx install hatch 39 | pipx install pyright 40 | pipx install ruff 41 | 42 | build: 43 | name: Build snip-dedup package 44 | needs: tooling-install 45 | runs-on: 
ubuntu-latest 46 | steps: 47 | # git clone the repository 48 | - uses: actions/checkout@v3 49 | 50 | # install python 51 | - uses: actions/setup-python@v4 52 | with: 53 | python-version: ${{ env.PYTHON_VERSION }} 54 | 55 | # restore tooling installation cache 56 | - uses: actions/cache/restore@v3 57 | with: 58 | path: | 59 | ~/.cache/pip 60 | ~/.local/pipx 61 | ~/.local/bin 62 | key: tooling-${{ env.TOOLING_CACHE_KEY }} 63 | restore-keys: tooling 64 | 65 | # install hatch via pipx (where we cached it) 66 | - name: Install hatch 67 | run: | 68 | python -m pip install --user pipx 69 | python -m pipx ensurepath 70 | pipx install hatch 71 | 72 | # cache dependencies 73 | - uses: actions/cache@v3 74 | with: 75 | path: ~/.local/share/hatch 76 | key: build-${{ hashFiles('pyproject.toml', 'hatch.toml') }} 77 | 78 | # build the python package 79 | - name: Build the snip-dedup package 80 | run: hatch build 81 | 82 | check-format: 83 | name: Check that code is formatted with Black 84 | runs-on: ubuntu-latest 85 | steps: 86 | - uses: actions/checkout@v3 87 | - uses: psf/black@stable 88 | with: 89 | src: "./snip_dedup/" 90 | 91 | check-pyright: 92 | name: Check that the code passes all the pyright checks 93 | needs: tooling-install 94 | runs-on: ubuntu-latest 95 | steps: 96 | # git clone the repository 97 | - uses: actions/checkout@v3 98 | 99 | # install python 100 | - uses: actions/setup-python@v4 101 | with: 102 | python-version: ${{ env.PYTHON_VERSION }} 103 | 104 | # restore tooling installation cache 105 | - uses: actions/cache/restore@v3 106 | with: 107 | path: | 108 | ~/.cache/pip 109 | ~/.local/pipx 110 | ~/.local/bin 111 | key: tooling-${{ env.TOOLING_CACHE_KEY }} 112 | restore-keys: tooling 113 | 114 | # install pyright via pipx (where we cached it) 115 | - name: Install pyright 116 | run: | 117 | python -m pip install --user pipx 118 | python -m pipx ensurepath 119 | pipx install hatch 120 | pipx install pyright 121 | 122 | # cache dependencies 123 | - uses: actions/cache@v3 124 | with: 125 | path: ~/.local/share/hatch 126 | key: pyright-${{ hashFiles('pyproject.toml', 'hatch.toml') }} 127 | 128 | # check code with pyright 129 | - name: Check the package code with pyright 130 | run: hatch run check 131 | 132 | check-lint: 133 | name: Check that the code pass the linter checks 134 | runs-on: ubuntu-latest 135 | steps: 136 | - uses: actions/checkout@v3 137 | - uses: actions/setup-python@v4 138 | with: 139 | python-version: ${{ env.PYTHON_VERSION }} 140 | - run: pip install ruff 141 | - run: ruff check --format=github snip_dedup/ 142 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Default data folder 2 | data 3 | 4 | # Created by https://www.toptal.com/developers/gitignore/api/windows,linux,osx,python 5 | # Edit at https://www.toptal.com/developers/gitignore?templates=windows,linux,osx,python 6 | 7 | ### Linux ### 8 | *~ 9 | 10 | # temporary files which can be created if a process still has a handle open of a deleted file 11 | .fuse_hidden* 12 | 13 | # KDE directory preferences 14 | .directory 15 | 16 | # Linux trash folder which might appear on any partition or disk 17 | .Trash-* 18 | 19 | # .nfs files are created when an open file is removed but is still being accessed 20 | .nfs* 21 | 22 | ### OSX ### 23 | # General 24 | .DS_Store 25 | .AppleDouble 26 | .LSOverride 27 | 28 | # Icon must end with two \r 29 | Icon 30 | 31 | 32 | # Thumbnails 33 | ._* 34 | 35 | # 
Files that might appear in the root of a volume 36 | .DocumentRevisions-V100 37 | .fseventsd 38 | .Spotlight-V100 39 | .TemporaryItems 40 | .Trashes 41 | .VolumeIcon.icns 42 | .com.apple.timemachine.donotpresent 43 | 44 | # Directories potentially created on remote AFP share 45 | .AppleDB 46 | .AppleDesktop 47 | Network Trash Folder 48 | Temporary Items 49 | .apdisk 50 | 51 | ### Python ### 52 | # Byte-compiled / optimized / DLL files 53 | __pycache__/ 54 | *.py[cod] 55 | *$py.class 56 | 57 | # C extensions 58 | *.so 59 | 60 | # Distribution / packaging 61 | .Python 62 | build/ 63 | develop-eggs/ 64 | dist/ 65 | downloads/ 66 | eggs/ 67 | .eggs/ 68 | lib/ 69 | lib64/ 70 | parts/ 71 | sdist/ 72 | var/ 73 | wheels/ 74 | share/python-wheels/ 75 | *.egg-info/ 76 | .installed.cfg 77 | *.egg 78 | MANIFEST 79 | 80 | # PyInstaller 81 | # Usually these files are written by a python script from a template 82 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 83 | *.manifest 84 | *.spec 85 | 86 | # Installer logs 87 | pip-log.txt 88 | pip-delete-this-directory.txt 89 | 90 | # Unit test / coverage reports 91 | htmlcov/ 92 | .tox/ 93 | .nox/ 94 | .coverage 95 | .coverage.* 96 | .cache 97 | nosetests.xml 98 | coverage.xml 99 | *.cover 100 | *.py,cover 101 | .hypothesis/ 102 | .pytest_cache/ 103 | cover/ 104 | 105 | # Translations 106 | *.mo 107 | *.pot 108 | 109 | # Django stuff: 110 | *.log 111 | local_settings.py 112 | db.sqlite3 113 | db.sqlite3-journal 114 | 115 | # Flask stuff: 116 | instance/ 117 | .webassets-cache 118 | 119 | # Scrapy stuff: 120 | .scrapy 121 | 122 | # Sphinx documentation 123 | docs/_build/ 124 | 125 | # PyBuilder 126 | .pybuilder/ 127 | target/ 128 | 129 | # Jupyter Notebook 130 | .ipynb_checkpoints 131 | 132 | # IPython 133 | profile_default/ 134 | ipython_config.py 135 | 136 | # pyenv 137 | # For a library or package, you might want to ignore these files since the code is 138 | # intended to run in multiple environments; otherwise, check them in: 139 | # .python-version 140 | 141 | # pipenv 142 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 143 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 144 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 145 | # install all needed dependencies. 146 | #Pipfile.lock 147 | 148 | # poetry 149 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 150 | # This is especially recommended for binary packages to ensure reproducibility, and is more 151 | # commonly ignored for libraries. 152 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 153 | #poetry.lock 154 | 155 | # pdm 156 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 157 | #pdm.lock 158 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 159 | # in version control. 160 | # https://pdm.fming.dev/#use-with-ide 161 | .pdm.toml 162 | 163 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 164 | __pypackages__/ 165 | 166 | # Celery stuff 167 | celerybeat-schedule 168 | celerybeat.pid 169 | 170 | # SageMath parsed files 171 | *.sage.py 172 | 173 | # Environments 174 | .env 175 | .venv 176 | env/ 177 | venv/ 178 | ENV/ 179 | env.bak/ 180 | venv.bak/ 181 | 182 | # Spyder project settings 183 | .spyderproject 184 | .spyproject 185 | 186 | # Rope project settings 187 | .ropeproject 188 | 189 | # mkdocs documentation 190 | /site 191 | 192 | # mypy 193 | .mypy_cache/ 194 | .dmypy.json 195 | dmypy.json 196 | 197 | # Pyre type checker 198 | .pyre/ 199 | 200 | # pytype static type analyzer 201 | .pytype/ 202 | 203 | # Cython debug symbols 204 | cython_debug/ 205 | 206 | # PyCharm 207 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 208 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 209 | # and can be added to the global gitignore or merged into this file. For a more nuclear 210 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 211 | #.idea/ 212 | 213 | ### Python Patch ### 214 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 215 | poetry.toml 216 | 217 | # ruff 218 | .ruff_cache/ 219 | 220 | ### Windows ### 221 | # Windows thumbnail cache files 222 | Thumbs.db 223 | Thumbs.db:encryptable 224 | ehthumbs.db 225 | ehthumbs_vista.db 226 | 227 | # Dump file 228 | *.stackdump 229 | 230 | # Folder config file 231 | [Dd]esktop.ini 232 | 233 | # Recycle Bin used on file shares 234 | $RECYCLE.BIN/ 235 | 236 | # Windows Installer files 237 | *.cab 238 | *.msi 239 | *.msix 240 | *.msm 241 | *.msp 242 | 243 | # Windows shortcuts 244 | *.lnk 245 | 246 | # End of https://www.toptal.com/developers/gitignore/api/windows,linux,osx,python 247 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # snip-dedup 2 | 3 | [![PyPI - Version](https://img.shields.io/pypi/v/snip-dedup.svg?logo=pypi&label=PyPI&logoColor=gold)](https://pypi.org/project/snip-dedup/) 4 | [![linting - Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v0.json)](https://github.com/charliermarsh/ruff) 5 | [![format - Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 6 | [![license - MIT](https://img.shields.io/badge/license-MIT-9400d3.svg)](https://spdx.org/licenses/) 7 | [![license - MIT](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1nKccWcCz566qDg3AohTV-zjBn7u_INnG?usp=sharing) 8 | 9 | ## This repo is a WIP 10 | 11 | You no longer can filter the LAION dataset to remove duplicates, as LAION disabled the webdataset on huggingface. I'll focus on adding some functionality for deduplication for future webdatasets using clip features. 12 | 13 | - [ ] Compress features using pretrained SNIP networks (for ViT-H-14, ViT-L14, ViT-B-32) 14 | - [x] Read our research paper 15 | - [ ] Train SNIP on your CLIP features 16 | - [ ] Run a de-duplication of your dataset using our de-dup code 17 | 18 | SNIP is a technique to compress CLIP features. It is competitive with previous works for large scale retrieval of deep features, and has some nice properties for multi-modal features. 
Read more about it [here](https://arxiv.org/abs/2303.12733).
19 | 
20 | We used SNIP together with the faiss library to deduplicate a billion-scale dataset, and found a high level of duplication (roughly 700M / 2 billion). This webdataset is no longer being distributed by LAION.
21 | 
22 | ## Install
23 | 
24 | ```sh
25 | pip install --upgrade snip-dedup
26 | ```
27 | 
28 | ## Usage
29 | 
30 | ```sh
31 | # List available commands
32 | snip --help
33 | snip download --help
34 | 
35 | # Download and deduplicate the first 10 shards of the dataset
36 | snip download --start 0 --end 10
37 | ```
38 | 
39 | Then, you may download the (deduplicated) laion2b images with the awesome [img2dataset](https://github.com/rom1504/img2dataset).
40 | 
41 | See the colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1nKccWcCz566qDg3AohTV-zjBn7u_INnG?usp=sharing) for a demo of search.
42 | 
43 | ## What is a Duplicate?
44 | 
45 | In our first iteration, we merely marked duplicates pairwise and removed one sample from each duplicate pair (the above code downloads a binary array marking the samples to remove). In our latest run, we recorded the entire adjacency matrix of duplication. For instance, suppose SNIP has labeled feature $k$ as a duplicate of feature $j$. Then $A[k,j] = A[j,k] = 1$ in the adjacency matrix. We're currently having trouble computing the full connected components of this matrix, see [this issue](https://github.com/ryanwebster90/snip-dedup/issues/7#issue-1639736690).
46 | 
47 | If you allow connected components with only one node, then counting the number of "unique" samples amounts to keeping one sample per component: with $N$ nodes grouped into $|\mathcal{C}|$ components, there are $D := N - |\mathcal{C}|$ duplicates.
48 | 
49 | ### Approximate CCs of Duplicates
50 | 
51 | Currently, we have an approximation of the connected components (CCs) of the duplicates. During the de-duplication, we label nodes as follows. Suppose we are at node $n$; the pseudocode for one labeling step is:
52 | ```python
53 | label = np.arange(0, N)
54 | ...
55 | d, i = index.search(feats[n, :], k)
56 | dups = get_dups(d, i)  # use adaptive threshold on ADC (see paper)
57 | label[dups] = resolve_labels_one_step(dups)
58 | ```
59 | where `N` is the number of nodes (2B for LAION-2B). Here `resolve_labels_one_step` simply rewrites any node that is still unlabeled to point to the current node $n$. This can be thought of as building a tree. We then connect nodes with common ancestors with a fixed-point iteration:
60 | ```python
61 | while True:
62 |     label = label[label]
63 | ```
64 | 
65 | The labels produced by the above loop can be found on Hugging Face: [vitl14_labels](https://huggingface.co/datasets/fraisdufour/snip-dedup/resolve/main/representatives/representatives_vitl14_fixed_pt.npy).
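To make the counting concrete, here is a minimal, self-contained sketch of the same idea. The `duplicate_pairs` array and the `find_root` helper are illustrative only (they are not part of this repository): the toy pairs stand in for the duplicates returned by the index search, the loop resolves labels to a fixed point (a convergent form of the `label = label[label]` iteration above), and the final count is $D = N - |\mathcal{C}|$.

```python
import numpy as np

# Hypothetical toy input: each row (i, j) is a detected duplicate pair.
duplicate_pairs = np.array([[0, 3], [3, 7], [2, 5]])
N = 10  # total number of nodes (samples)

# Every node starts labeled by itself.
label = np.arange(N)


def find_root(label, x):
    """Follow ancestor labels up to the root, compressing the path on the way."""
    while label[x] != x:
        label[x] = label[label[x]]
        x = label[x]
    return x


# Union step: point the larger root at the smaller one, building the ancestor tree.
for i, j in duplicate_pairs:
    ri, rj = find_root(label, i), find_root(label, j)
    label[max(ri, rj)] = min(ri, rj)

# Fixed point: iterate label = label[label] until the labeling stops changing.
while True:
    new_label = label[label]
    if np.array_equal(new_label, label):
        break
    label = new_label

num_components = np.unique(label).size  # |C|, singleton components included
num_duplicates = N - num_components     # D := N - |C|
print(num_components, num_duplicates)   # -> 7 3
```

In the real pipeline, the pairs come from the faiss index search with the adaptive ADC threshold rather than from a precomputed array; this sketch only illustrates the label-resolution and counting logic.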
66 | 
67 | Other:
68 | 
69 | [cumulative sizes of features (for indexing sharded files)](https://drive.google.com/file/d/1OdVt5rjYw55XfMhsQSdqcVOP7lG2qj4W/view?usp=sharing)
70 | 
71 | ## Finding images overfit by Stable Diffusion
72 | 
73 | By analyzing the most duplicated images, we have found several more images copied verbatim by Stable Diffusion, posing a copyright problem:
74 | 
75 | ![sylvester_overfit](https://user-images.githubusercontent.com/2905865/225423740-e0befaba-cb74-44bf-9a64-f5dd9cbd4c33.jpeg)
76 | ![hopped up logo](https://user-images.githubusercontent.com/2905865/225423836-7c64428b-6782-4452-8d29-1628dc192c6c.jpeg)
77 | 
78 | 
79 | ## Note on False positives
80 | We noticed that many images labeled as duplicates by SNIP but not by the raw features are in fact near-duplicates, for example:
81 | 
82 | ![Chess1](https://en.chessok.net/uploads/posts/2017-09/1506718434_knight-on-the-left-1.nc3.jpg)
83 | ![Chess2](https://m.media-amazon.com/images/I/51jNRpWUCjL.jpg)
84 | 
85 | You may check a list of (randomly sampled) detected duplicate pairs [here](https://docs.google.com/spreadsheets/d/1Eq46U3MbTXzNoLCvnHLcw64X3bWE3ZE8zMJVQU9_gCg/edit?usp=sharing).
86 | 
87 | 
88 | ## Semantic Search
89 | 
90 | You may use the compressed features to do semantic search with faiss (see, for instance, the clip-retrieval repository).
91 | 
92 | ## Contribute
93 | 
94 | Contributions are welcome.
95 | Usually, the best way is first to open an issue to discuss things.
96 | 
97 | This Python project uses the [`hatch`][hatch] project manager.
98 | Dependencies are specified inside the `pyproject.toml` file, and build configs inside the `hatch.toml` file.
99 | As such, you can enter the isolated development environment with `hatch shell` from inside the repository.
100 | 
101 | The code should be documented following the [Numpy docstring standard][docstring].
102 | 
103 | To avoid silly mistakes, the code is checked with [pyright][pyright].
104 | To ensure consistent styling, all Python code is formatted with [black][black] and we use the [ruff][ruff] linter.
105 | Note that these tools can usually be installed in your editor, such as VS Code, to view the checks directly in the code.
106 | Once you have installed them (suggested via [pipx][pipx]), you can check that the code is consistent with:
107 | 
108 | ```sh
109 | hatch run check # check for mistakes via static analysis with pyright
110 | black --check snip_dedup/ # check formatting of all python files
111 | ruff check snip_dedup/ # check linting rules
112 | ```
113 | 
114 | STILL TODO:
115 | 
116 | - [ ] add docs / tutorial
117 | - [ ] add tests
118 | - [ ] check max file size on CI to prevent pushing data
119 | - [ ] auto-publish GitHub action; example at https://github.com/ofek/hatch-showcase/blob/master/.github/workflows/build.yml
120 | 
121 | [hatch]: https://github.com/pypa/hatch
122 | [pyright]: https://github.com/microsoft/pyright
123 | [black]: https://github.com/psf/black
124 | [ruff]: https://github.com/charliermarsh/ruff
125 | [pipx]: https://github.com/pypa/pipx
126 | [docstring]: https://numpydoc.readthedocs.io/en/latest/format.html
127 | 
128 | ## Citation
129 | ```
130 | @misc{webster2023deduplication,
131 |       title={On the De-duplication of LAION-2B},
132 |       author={Ryan Webster and Julien Rabin and Loic Simon and Frederic Jurie},
133 |       year={2023},
134 |       eprint={2303.12733},
135 |       archivePrefix={arXiv},
136 |       primaryClass={cs.CV}
137 | }
138 | ```
139 | 
140 | 
--------------------------------------------------------------------------------