├── .gitignore ├── LICENSE ├── README.md ├── assets └── model.png ├── inference.py ├── requirements.txt ├── scripts ├── inference_m2kr_large.sh └── inference_m2kr_small.sh ├── src ├── models │ ├── __init__.py │ ├── contrastive_loss │ │ ├── __init__.py │ │ ├── disco_clip.py │ │ ├── gather_utils.py │ │ └── modeling_contrastive_loss.py │ ├── ret │ │ ├── __init__.py │ │ ├── configuration_ret.py │ │ └── modeling_ret.py │ └── retriever │ │ ├── __init__.py │ │ ├── configuration_retriever.py │ │ └── modeling_retriever.py └── utils.py └── third_party ├── LICENSE ├── LoTTE.md ├── MANIFEST.in ├── README.md ├── colbert ├── __init__.py ├── data │ ├── __init__.py │ ├── collection.py │ ├── dataset.py │ ├── examples.py │ ├── queries.py │ └── ranking.py ├── distillation │ ├── ranking_scorer.py │ └── scorer.py ├── evaluation │ ├── __init__.py │ ├── load_model.py │ ├── loaders.py │ └── metrics.py ├── index.py ├── index_updater.py ├── indexer.py ├── indexing │ ├── __init__.py │ ├── codecs │ │ ├── __init__.py │ │ ├── decompress_residuals.cpp │ │ ├── decompress_residuals.cu │ │ ├── packbits.cpp │ │ ├── packbits.cu │ │ ├── residual.py │ │ ├── residual_embeddings.py │ │ └── residual_embeddings_strided.py │ ├── collection_encoder.py │ ├── collection_indexer.py │ ├── index_manager.py │ ├── index_saver.py │ ├── loaders.py │ └── utils.py ├── infra │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── base_config.py │ │ ├── config.py │ │ ├── core_config.py │ │ └── settings.py │ ├── launcher.py │ ├── provenance.py │ ├── run.py │ └── utilities │ │ ├── annotate_em.py │ │ ├── create_triples.py │ │ └── minicorpus.py ├── modeling │ ├── __init__.py │ ├── base_colbert.py │ ├── checkpoint.py │ ├── colbert.py │ ├── hf_colbert.py │ ├── reranker │ │ ├── __init__.py │ │ ├── electra.py │ │ └── tokenizer.py │ ├── segmented_maxsim.cpp │ └── tokenization │ │ ├── __init__.py │ │ ├── doc_tokenization.py │ │ ├── query_tokenization.py │ │ └── utils.py ├── parameters.py ├── ranking │ └── __init__.py ├── search │ ├── __init__.py │ ├── candidate_generation.py │ ├── decompress_residuals.cpp │ ├── filter_pids.cpp │ ├── index_loader.py │ ├── index_storage.py │ ├── segmented_lookup.cpp │ ├── strided_tensor.py │ └── strided_tensor_core.py ├── searcher.py ├── tests │ ├── e2e_test.py │ ├── index_coalesce_test.py │ ├── index_updater_test.py │ └── tokenizers_test.py ├── trainer.py ├── training │ ├── __init__.py │ ├── eager_batcher.py │ ├── lazy_batcher.py │ ├── rerank_batcher.py │ ├── training.py │ └── utils.py ├── utilities │ ├── annotate_em.py │ ├── create_triples.py │ └── minicorpus.py └── utils │ ├── __init__.py │ ├── amp.py │ ├── coalesce.py │ ├── distributed.py │ ├── logging.py │ ├── parser.py │ ├── runs.py │ └── utils.py ├── colbert_ai.egg-info ├── PKG-INFO ├── SOURCES.txt ├── dependency_links.txt ├── requires.txt └── top_level.txt ├── conda_env.yml ├── conda_env_cpu.yml ├── server.py ├── setup.py └── utility ├── __init__.py ├── evaluate ├── __init__.py ├── annotate_EM.py ├── annotate_EM_helpers.py ├── evaluate_lotte_rankings.py └── msmarco_passages.py ├── preprocess ├── __init__.py ├── docs2passages.py └── queries_split.py ├── rankings ├── __init__.py ├── dev_subsample.py ├── merge.py ├── split_by_offset.py ├── split_by_queries.py └── tune.py ├── supervision ├── __init__.py ├── self_training.py └── triples.py └── utils ├── __init__.py ├── dpr.py ├── qa_loaders.py └── save_metadata.py /.gitignore: -------------------------------------------------------------------------------- 1 | .scratch/ 2 | .vscode/ 3 | __pycache__/ 
-------------------------------------------------------------------------------- /assets/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/assets/model.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.5.2 2 | certifi==2025.1.31 3 | charset-normalizer==3.4.1 4 | filelock==3.13.1 5 | fsspec==2024.6.1 6 | huggingface-hub==0.29.3 7 | idna==3.10 8 | Jinja2==3.1.4 9 | MarkupSafe==2.1.5 10 | mpmath==1.3.0 11 | networkx==3.3 12 | numpy==1.26.4 13 | packaging==24.2 14 | pandas==2.1.4 15 | pillow==11.0.0 16 | python-dateutil==2.9.0.post0 17 | pytz==2025.1 18 | PyYAML==6.0.2 19 | regex==2024.11.6 20 | requests==2.32.3 21 | safetensors==0.5.3 22 | six==1.17.0 23 | sympy==1.13.1 24 | tokenizers==0.20.3 25 | tqdm==4.67.1 26 | transformers==4.45.0 27 | typing_extensions==4.12.2 28 | tzdata==2025.1 29 | ujson==5.10.0 30 | urllib3==2.3.0 -------------------------------------------------------------------------------- /scripts/inference_m2kr_large.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=inference_m2kr_large 3 | #SBATCH --output= 4 | #SBATCH --error= 5 | #SBATCH --open-mode=truncate 6 | #SBATCH --partition= 7 | #SBATCH --account= 8 | #SBATCH --nodes=1 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=4 11 | #SBATCH --mem=128G 12 | #SBATCH --cpus-per-task=8 13 | #SBATCH --array=0-3 14 | #SBATCH --time=00:30:00 15 | 16 | # tested on 4 NVIDIA A100-SXM-64GB 17 | 18 | conda activate ret 19 | cd ~/ReT 20 | export PYTHONPATH=. 
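# NOTE (hedged sketch, not part of the original SLURM header): outside a SLURM cluster,
# the same two-step pipeline below (index the passage collection, then search it with the
# query set) can be launched directly. All "/path/to/..." values are placeholders; every
# flag mirrors the srun invocations at the bottom of this script.
#
#   python inference.py --action index \
#       --dataset_path /path/to/okvqa_passages_test.jsonl \
#       --image_root_path /path/to/images \
#       --checkpoint_path aimagelab/ReT-CLIP-ViT-L-14 \
#       --root_path /path/to/experiments \
#       --experiment_name ReT-CLIP-ViT-L-14 \
#       --index_name okvqa \
#       --index_bsize 128
#
#   python inference.py --action search \
#       --dataset_path /path/to/okvqa_test.jsonl \
#       --dataset_passages_path /path/to/okvqa_passages_test.jsonl \
#       --image_root_path /path/to/images \
#       --checkpoint_path aimagelab/ReT-CLIP-ViT-L-14 \
#       --root_path /path/to/experiments \
#       --experiment_name ReT-CLIP-ViT-L-14 \
#       --index_name okvqa \
#       --index_bsize 128 \
#       --num_docs_to_retrieve 500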
21 | 22 | export TRANSFORMERS_VERBOSITY=info 23 | export TOKENIZERS_PARALLELISM=false 24 | export COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True 25 | 26 | DATASET_NAMES=( 27 | "okvqa" 28 | "infoseek" 29 | "evqa" 30 | "wit" 31 | "llava" 32 | "kvqa" 33 | "oven" 34 | "iglue" 35 | ) 36 | 37 | JSONL_ROOT_PATH= 38 | DATASET_PATHS=( 39 | "${JSONL_ROOT_PATH}/okvqa_test.jsonl" 40 | "${JSONL_ROOT_PATH}/infoseek_test.jsonl" 41 | "${JSONL_ROOT_PATH}/evqa_test_m2kr.jsonl" 42 | "${JSONL_ROOT_PATH}/wit_test.jsonl" 43 | "${JSONL_ROOT_PATH}/llava_test.jsonl" 44 | "${JSONL_ROOT_PATH}/kvqa_test.jsonl" 45 | "${JSONL_ROOT_PATH}/oven_test.jsonl" 46 | "${JSONL_ROOT_PATH}/iglue_test.jsonl" 47 | ) 48 | 49 | DATASET_PASSAGES_PATHS=( 50 | "${JSONL_ROOT_PATH}/okvqa_passages_test.jsonl" 51 | "${JSONL_ROOT_PATH}/infoseek_passages_test.jsonl" 52 | "${JSONL_ROOT_PATH}/evqa_passages_test.jsonl" 53 | "${JSONL_ROOT_PATH}/wit_passages_test.jsonl" 54 | "${JSONL_ROOT_PATH}/llava_passages_test.jsonl" 55 | "${JSONL_ROOT_PATH}/kvqa_passages_test.jsonl" 56 | "${JSONL_ROOT_PATH}/oven_passages_test.jsonl" 57 | "${JSONL_ROOT_PATH}/iglue_passages_test.jsonl" 58 | ) 59 | 60 | IMAGE_ROOT_PATH= 61 | 62 | model_name="ReT-CLIP-ViT-L-14" 63 | checkpoint_path="aimagelab/${model_name}" 64 | root_path= 65 | dataset_path="${DATASET_PATHS[$SLURM_ARRAY_TASK_ID]}" 66 | dataset_passages_path="${DATASET_PASSAGES_PATHS[$SLURM_ARRAY_TASK_ID]}" 67 | experiment_name="${model_name}" 68 | index_name="${DATASET_NAMES[$SLURM_ARRAY_TASK_ID]}" 69 | 70 | echo "DATASET PATH: ${dataset_path}" 71 | echo "DATASET PASSAGES PATH: ${dataset_passages_path}" 72 | 73 | srun -c $SLURM_CPUS_PER_TASK --mem $SLURM_MEM_PER_NODE \ 74 | python inference.py \ 75 | --action index \ 76 | --dataset_path $dataset_passages_path \ 77 | --image_root_path $IMAGE_ROOT_PATH \ 78 | --checkpoint_path $checkpoint_path \ 79 | --root_path $root_path \ 80 | --experiment_name $experiment_name \ 81 | --index_name $index_name \ 82 | --index_bsize 128 83 | 84 | srun -c $SLURM_CPUS_PER_TASK --mem $SLURM_MEM_PER_NODE \ 85 | python inference.py \ 86 | --action search \ 87 | --dataset_path $dataset_path \ 88 | --dataset_passages_path $dataset_passages_path \ 89 | --image_root_path $IMAGE_ROOT_PATH \ 90 | --checkpoint_path $checkpoint_path \ 91 | --root_path $root_path \ 92 | --experiment_name $experiment_name \ 93 | --index_name $index_name \ 94 | --index_bsize 128 \ 95 | --num_docs_to_retrieve 500 96 | -------------------------------------------------------------------------------- /scripts/inference_m2kr_small.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=inference_m2kr_small 3 | #SBATCH --output= 4 | #SBATCH --error= 5 | #SBATCH --open-mode=truncate 6 | #SBATCH --partition= 7 | #SBATCH --account= 8 | #SBATCH --nodes=1 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=1 11 | #SBATCH --mem=128G 12 | #SBATCH --cpus-per-task=8 13 | #SBATCH --array=4-7 14 | #SBATCH --time=00:30:00 15 | 16 | # tested on 1 NVIDIA A100-SXM-64GB 17 | 18 | conda activate ret 19 | cd ~/ReT 20 | export PYTHONPATH=. 
21 | 22 | export TRANSFORMERS_VERBOSITY=info 23 | export TOKENIZERS_PARALLELISM=false 24 | export COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True 25 | 26 | DATASET_NAMES=( 27 | "okvqa" 28 | "infoseek" 29 | "evqa" 30 | "wit" 31 | "llava" 32 | "kvqa" 33 | "oven" 34 | "iglue" 35 | ) 36 | 37 | JSONL_ROOT_PATH= 38 | DATASET_PATHS=( 39 | "${JSONL_ROOT_PATH}/okvqa_test.jsonl" 40 | "${JSONL_ROOT_PATH}/infoseek_test.jsonl" 41 | "${JSONL_ROOT_PATH}/evqa_test_m2kr.jsonl" 42 | "${JSONL_ROOT_PATH}/wit_test.jsonl" 43 | "${JSONL_ROOT_PATH}/llava_test.jsonl" 44 | "${JSONL_ROOT_PATH}/kvqa_test.jsonl" 45 | "${JSONL_ROOT_PATH}/oven_test.jsonl" 46 | "${JSONL_ROOT_PATH}/iglue_test.jsonl" 47 | ) 48 | 49 | DATASET_PASSAGES_PATHS=( 50 | "${JSONL_ROOT_PATH}/okvqa_passages_test.jsonl" 51 | "${JSONL_ROOT_PATH}/infoseek_passages_test.jsonl" 52 | "${JSONL_ROOT_PATH}/evqa_passages_test.jsonl" 53 | "${JSONL_ROOT_PATH}/wit_passages_test.jsonl" 54 | "${JSONL_ROOT_PATH}/llava_passages_test.jsonl" 55 | "${JSONL_ROOT_PATH}/kvqa_passages_test.jsonl" 56 | "${JSONL_ROOT_PATH}/oven_passages_test.jsonl" 57 | "${JSONL_ROOT_PATH}/iglue_passages_test.jsonl" 58 | ) 59 | 60 | IMAGE_ROOT_PATH= 61 | 62 | model_name="ReT-CLIP-ViT-L-14" 63 | checkpoint_path="aimagelab/${model_name}" 64 | root_path= 65 | dataset_path="${DATASET_PATHS[$SLURM_ARRAY_TASK_ID]}" 66 | dataset_passages_path="${DATASET_PASSAGES_PATHS[$SLURM_ARRAY_TASK_ID]}" 67 | experiment_name="${model_name}" 68 | index_name="${DATASET_NAMES[$SLURM_ARRAY_TASK_ID]}" 69 | 70 | echo "DATASET PATH: ${dataset_path}" 71 | echo "DATASET PASSAGES PATH: ${dataset_passages_path}" 72 | 73 | srun -c $SLURM_CPUS_PER_TASK --mem $SLURM_MEM_PER_NODE \ 74 | python inference.py \ 75 | --action index \ 76 | --dataset_path $dataset_passages_path \ 77 | --image_root_path $IMAGE_ROOT_PATH \ 78 | --checkpoint_path $checkpoint_path \ 79 | --root_path $root_path \ 80 | --experiment_name $experiment_name \ 81 | --index_name $index_name \ 82 | --index_bsize 128 83 | 84 | srun -c $SLURM_CPUS_PER_TASK --mem $SLURM_MEM_PER_NODE \ 85 | python inference.py \ 86 | --action search \ 87 | --dataset_path $dataset_path \ 88 | --dataset_passages_path $dataset_passages_path \ 89 | --image_root_path $IMAGE_ROOT_PATH \ 90 | --checkpoint_path $checkpoint_path \ 91 | --root_path $root_path \ 92 | --experiment_name $experiment_name \ 93 | --index_name $index_name \ 94 | --index_bsize 128 \ 95 | --num_docs_to_retrieve 500 96 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ret import * 2 | from .retriever import * -------------------------------------------------------------------------------- /src/models/contrastive_loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .modeling_contrastive_loss import contrastive_loss, ContrastiveLossOutput -------------------------------------------------------------------------------- /src/models/contrastive_loss/disco_clip.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/IDEA-Research/DisCo-CLIP/blob/main/disco/gather.py 2 | 3 | import os 4 | import torch 5 | import torch.distributed as dist 6 | 7 | class DisCoGather(torch.autograd.Function): 8 | """An autograd function that performs allgather on a tensor.""" 9 | 10 | @staticmethod 11 | def forward(ctx, tensor): 12 | if not torch.distributed.is_initialized(): 13 | raise 
"torch.distributed is not initialized" 14 | 15 | world_size = torch.distributed.get_world_size() 16 | ctx.bs = tensor.shape[0] 17 | ctx.rank = torch.distributed.get_rank() 18 | 19 | gathered_tensors = [ 20 | torch.zeros_like(tensor) for _ in range(world_size) 21 | ] 22 | torch.distributed.all_gather(gathered_tensors, tensor) 23 | 24 | gathered_tensors = torch.cat(gathered_tensors, dim=0) 25 | gathered_tensors.requires_grad_(True) 26 | 27 | return gathered_tensors 28 | 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | # do not remove this contiguous 32 | torch.distributed.all_reduce(grad_output.contiguous(), op=torch.distributed.ReduceOp.SUM) 33 | return grad_output[ctx.bs*ctx.rank:ctx.bs*(ctx.rank+1)] 34 | 35 | 36 | def Gather(tensor): 37 | return DisCoGather.apply(tensor) -------------------------------------------------------------------------------- /src/models/contrastive_loss/gather_utils.py: -------------------------------------------------------------------------------- 1 | # from LAVIS codebase: https://github.com/salesforce/LAVIS/blob/main/lavis/models/base_model.py#L202 2 | import torch 3 | from torch.distributed.nn.functional import all_gather as all_gather_with_grad_torch_functional 4 | 5 | 6 | class GatherLayer(torch.autograd.Function): 7 | """ 8 | Gather tensors from all workers with support for backward propagation: 9 | This implementation does not cut the gradients as torch.distributed.all_gather does. 10 | """ 11 | 12 | @staticmethod 13 | def forward(ctx, x): 14 | output = [ 15 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 16 | ] 17 | torch.distributed.all_gather(output, x) 18 | return tuple(output) 19 | 20 | @staticmethod 21 | def backward(ctx, *grads): 22 | all_gradients = torch.stack(grads) 23 | torch.distributed.all_reduce(all_gradients) 24 | return all_gradients[torch.distributed.get_rank()] 25 | 26 | 27 | def all_gather_with_grad(tensors): 28 | """ 29 | Performs all_gather operation on the provided tensors. 30 | Graph remains connected for backward grad computation. 
31 | """ 32 | # Queue the gathered tensors 33 | world_size = torch.distributed.get_world_size() 34 | # There is no need for reduction in the single-proc case 35 | if world_size == 1: 36 | return tensors 37 | 38 | # tensor_all = GatherLayer.apply(tensors) 39 | tensor_all = GatherLayer.apply(tensors) 40 | 41 | return torch.cat(tensor_all, dim=0) 42 | 43 | 44 | def all_gather_with_grad_torch(tensors): 45 | """ 46 | Official version of all_gather with grads in torch.distributed.nn.functional 47 | """ 48 | tensor_all = all_gather_with_grad_torch_functional(tensors) 49 | return torch.cat(tensor_all, dim=0) 50 | 51 | 52 | def all_gather(x): 53 | output = [ 54 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size()) 55 | ] 56 | torch.distributed.all_gather(output, x) 57 | return torch.cat(output, dim=0) -------------------------------------------------------------------------------- /src/models/ret/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_ret import RetConfig, RetLayerStrategy 2 | from .modeling_ret import RetModel, RetModelOutput 3 | from transformers import AutoConfig, AutoModel 4 | 5 | __all__ = [ 6 | 'RetLayerStrategy', 7 | 'RetConfig', 8 | 'RetModel', 9 | 'RetModelOutput' 10 | ] 11 | 12 | AutoConfig.register(RetConfig.model_type, RetConfig) 13 | AutoModel.register(RetConfig, RetModel) -------------------------------------------------------------------------------- /src/models/ret/configuration_ret.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union 2 | from transformers import PretrainedConfig, AutoConfig 3 | from enum import Enum 4 | 5 | class RetLayerStrategy(str, Enum): 6 | CLIP_VIT_B = 'clip_vit_b' 7 | CLIP_VIT_L = 'clip_vit_l' 8 | OPENCLIP_VIT_H = 'openclip_vit_h' 9 | OPENCLIP_VIT_G = 'openclip_vit_g' 10 | 11 | 12 | _RET_LAYER_STRATEGY_MAPPING = { 13 | RetLayerStrategy.CLIP_VIT_B: { 14 | 'text': tuple(range(12)), 15 | 'vision': tuple(range(12)) 16 | }, 17 | RetLayerStrategy.CLIP_VIT_L: { 18 | 'text': list(range(12)), 19 | 'vision': list(range(24))[::2] 20 | }, 21 | RetLayerStrategy.OPENCLIP_VIT_H: { 22 | 'text': (1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23), 23 | 'vision': (0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 31) 24 | }, 25 | RetLayerStrategy.OPENCLIP_VIT_G: { 26 | 'text': (0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 31), 27 | 'vision': (2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47) 28 | } 29 | } 30 | 31 | 32 | class RetConfig(PretrainedConfig): 33 | model_type = 'ret' 34 | is_composition = False 35 | 36 | def __init__( 37 | self, 38 | text_config: Optional[Union[PretrainedConfig, Dict]] = None, 39 | vision_config: Optional[Union[PretrainedConfig, Dict]] = None, 40 | is_text_frozen: bool = True, 41 | is_vision_frozen: bool = True, 42 | late_proj_output_size: int = 128, 43 | layer_strategy: str = RetLayerStrategy.CLIP_VIT_B, 44 | use_pooler_features: bool = False, 45 | num_queries: int = 32, 46 | hidden_size: int = 1024, 47 | dropout_p: float = 0.05, 48 | attention_dropout: float = 0.05, 49 | activation_fn: str = 'gelu', 50 | input_gate_bias_prior: float = 0.0, 51 | forget_gate_bias_prior: float = 0.0, 52 | use_tanh: bool = False, 53 | **kwargs 54 | ): 55 | super().__init__(**kwargs) 56 | 57 | if isinstance(text_config, dict): 58 | text_config = AutoConfig.for_model(text_config.pop('model_type'), **text_config) 59 | self.text_config = text_config 60 | 61 | if isinstance(vision_config, dict): 
62 | vision_config = AutoConfig.for_model(vision_config.pop('model_type'), **vision_config) 63 | self.vision_config = vision_config 64 | 65 | self.is_text_frozen = is_text_frozen 66 | self.is_vision_frozen = is_vision_frozen 67 | self.late_proj_output_size = late_proj_output_size 68 | self.layer_strategy = layer_strategy 69 | self.use_pooler_features = use_pooler_features 70 | 71 | # recurrent cell 72 | self.num_queries = num_queries 73 | self.hidden_size = hidden_size 74 | self.intermediate_size = hidden_size * 4 75 | self.num_attention_heads = hidden_size // 64 76 | self.dropout_p = dropout_p 77 | self.attention_dropout = attention_dropout 78 | self.activation_fn = activation_fn 79 | self.input_gate_bias_prior = input_gate_bias_prior 80 | self.forget_gate_bias_prior = forget_gate_bias_prior 81 | self.use_tanh = use_tanh 82 | 83 | @property 84 | def n_rec_layers(self): 85 | return len(_RET_LAYER_STRATEGY_MAPPING[self.layer_strategy]['text']) 86 | 87 | @property 88 | def text_layer_idxs(self): 89 | return _RET_LAYER_STRATEGY_MAPPING[self.layer_strategy]['text'] 90 | 91 | @property 92 | def vision_layer_idxs(self): 93 | return _RET_LAYER_STRATEGY_MAPPING[self.layer_strategy]['vision'] 94 | 95 | @property 96 | def text_hidden_size(self): 97 | if self.text_config is None: 98 | return 0 99 | else: 100 | return self.text_config.hidden_size 101 | 102 | @property 103 | def vision_hidden_size(self): 104 | if self.vision_config is None: 105 | return 0 106 | else: 107 | return self.vision_config.hidden_size -------------------------------------------------------------------------------- /src/models/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from .configuration_retriever import RetrieverConfig 2 | from .modeling_retriever import RetrieverModel 3 | from transformers import AutoConfig, AutoModel 4 | 5 | __all__ = [ 6 | 'RetrieverConfig', 7 | 'RetrieverModel' 8 | ] 9 | 10 | AutoConfig.register(RetrieverConfig.model_type, RetrieverConfig) 11 | AutoModel.register(RetrieverConfig, RetrieverModel) -------------------------------------------------------------------------------- /src/models/retriever/configuration_retriever.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Union 2 | from transformers import PretrainedConfig 3 | from src.models.ret import RetConfig 4 | 5 | 6 | class RetrieverConfig(PretrainedConfig): 7 | model_type = 'retriever' 8 | is_composition = False 9 | 10 | def __init__( 11 | self, 12 | query_config: Optional[Union[RetConfig, Dict]] = None, 13 | passage_config: Optional[Union[RetConfig, Dict]] = None, 14 | fg_loss: bool = True, 15 | simmetric_loss: bool = True, 16 | share_query_passage_models: bool = False, 17 | share_text_models: bool = True, 18 | share_vision_models: bool = True, 19 | **kwargs 20 | ): 21 | super().__init__(**kwargs) 22 | 23 | if isinstance(query_config, RetConfig): 24 | query_config = query_config 25 | elif isinstance(query_config, dict): 26 | query_config = RetConfig.from_dict(query_config) 27 | self.query_config = query_config 28 | 29 | if isinstance(passage_config, RetConfig): 30 | passage_config = passage_config 31 | elif isinstance(passage_config, dict): 32 | passage_config = RetConfig.from_dict(passage_config) 33 | 34 | if share_query_passage_models: 35 | self.passage_config = self.query_config 36 | else: 37 | self.passage_config = passage_config 38 | 39 | self.fg_loss = fg_loss 40 | self.simmetric_loss = simmetric_loss 41 | 
self.share_query_passage_models = share_query_passage_models 42 | self.share_text_models = share_text_models 43 | self.share_vision_models = share_vision_models -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import logging 3 | import os 4 | 5 | logging.set_verbosity_info() 6 | logging.enable_explicit_format() 7 | 8 | 9 | def get_logger(): 10 | return logging.get_logger("transformers") 11 | 12 | 13 | def get_additive_attn_mask(binary_attn_mask, dtype): 14 | ret = torch.where(binary_attn_mask.bool(), 0, torch.finfo(dtype).min).to(dtype) 15 | return ret 16 | 17 | 18 | def is_debug(): 19 | return int(os.getenv("DEBUG", 0)) == 1 -------------------------------------------------------------------------------- /third_party/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019, 2020 Stanford Future Data Systems 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /third_party/LoTTE.md: -------------------------------------------------------------------------------- 1 | ## LoTTE dataset 2 | 3 | The Long-Tail Topic-stratified Evaluation (LoTTE) benchmark includes 12 domain-specific datasets derived from StackExchange questions and answers. Datasets span topics including writing, recreation, science, technology, and lifestyle. LoTTE includes two sets of queries: the first set is comprised of search-based queries from the GooAQ dataset, while the second set is comprised of forum-based queries taken directly from StackExchange. 4 | 5 | The dataset can be downloaded from this link: [https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/lotte.tar.gz](https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/lotte.tar.gz) 6 | 7 | The dataset is organized as follows: 8 | ``` 9 | |-- lotte 10 | |-- writing 11 | |-- dev 12 | |-- collection.tsv 13 | |-- metadata.jsonl 14 | |-- questions.search.tsv 15 | |-- qas.search.jsonl 16 | |-- questions.forum.tsv 17 | |-- qas.forum.jsonl 18 | |-- test 19 | |-- collection.tsv 20 | |-- ... 21 | |-- recreation 22 | |-- ... 23 | |-- ... 
24 | ``` 25 | Here is a description of each file's contents: 26 | - `collection.tsv`: A list of passages where each line is of the form by `[pid]\t[text]` 27 | - `metadata.jsonl`: A list of JSON dictionaries for each question where each line is of the form: 28 | ``` 29 | { 30 | "dataset": dataset, 31 | "question_id": question_id, 32 | "post_ids": [post_id_1, post_id_2, ..., post_id_n], 33 | "scores": [score_1, score_2, ..., score_n], 34 | "post_urls": [url_1, url_2, ..., url_n], 35 | "post_authors": [author_1, author_2, ..., author_n], 36 | "post_author_urls": [url_1, url_2, ..., url_n], 37 | "question_author": question_author, 38 | "question_author_url", question_author_url 39 | } 40 | ``` 41 | - `questions.search.tsv`: A list of search-based questions of the form `[qid]\t[text]` 42 | - `qas.search.jsonl`: A list of JSON dictionaries for each search-based question's answer data of the form: 43 | 44 | ``` 45 | { 46 | "qid": qid, 47 | "query": query, 48 | "answer_pids": answer_pids 49 | } 50 | ``` 51 | - `questions.forum.tsv`: A list of forum-based questions 52 | - `qas.forum.tsv`: A list of JSON dictionaries for each forum-based question's answer data 53 | 54 | We also include a script to evaluate LoTTE rankings: `evaluate_lotte_rankings.py`. Each rankings file must be in a tsv format with each line of the form `[qid]\t[pid]\t[rank]\t[score]`. Note that `qid`s must be in sequential order starting from 0, and `rank`s must be in sequential order starting from 1. The rankings directory must have the following structure: 55 | ``` 56 | |--rankings 57 | |-- dev 58 | |-- writing.search.ranking.tsv 59 | |-- writing.forum.ranking.tsv 60 | |-- recreation.search.ranking.tsv 61 | |-- recreation.forum.ranking.tsv 62 | |-- science.search.ranking.tsv 63 | |-- science.forum.ranking.tsv 64 | |-- technology.search.ranking.tsv 65 | |-- technology.forum.ranking.tsv 66 | |-- lifestyle.search.ranking.tsv 67 | |-- lifestyle.forum.ranking.tsv 68 | |-- pooled.search.ranking.tsv 69 | |-- pooled.forum.ranking.tsv 70 | |-- test 71 | |-- writing.search.ranking.tsv 72 | |-- ... 73 | ``` 74 | Note that the file names must match exactly, though if some files are missing the script will print partial results. 
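For reference, the success metric itself is straightforward to reproduce: a query counts as a success at depth `k` if at least one of its `answer_pids` appears among its top-`k` ranked passages. The sketch below uses hypothetical local paths and is independent of the official script, which remains the reference implementation; it computes Success@k directly from a rankings file and the corresponding `qas.*.jsonl` file:
```
import ujson
from collections import defaultdict

def success_at_k(rankings_path, qas_path, k=5):
    # qid -> set of gold answer pids, read from qas.search.jsonl / qas.forum.jsonl
    answers = {}
    with open(qas_path) as f:
        for line in f:
            qa = ujson.loads(line)
            answers[int(qa["qid"])] = set(qa["answer_pids"])

    # qid -> pids ranked within the top k, read from a "[qid]\t[pid]\t[rank]\t[score]" file
    topk = defaultdict(set)
    with open(rankings_path) as f:
        for line in f:
            qid, pid, rank, _score = line.strip().split("\t")
            if int(rank) <= k:
                topk[int(qid)].add(int(pid))

    hits = sum(1 for qid, gold in answers.items() if gold & topk[qid])
    return 100.0 * hits / len(answers)

print(success_at_k("rankings/test/writing.search.ranking.tsv",
                   "lotte/writing/test/qas.search.jsonl", k=5))
```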
An example usage of the script is as follows: 75 | ``` 76 | python evaluate_lotte_rankings.py --k 5 --split test --data_path /path/to/lotte --rankings_path /path/to/rankings 77 | ``` 78 | This will produce the following output (numbers taken from the ColBERTv2 evaluation): 79 | ``` 80 | [query_type=search, dataset=writing] Success@5: 80.1 81 | [query_type=search, dataset=recreation] Success@5: 72.3 82 | [query_type=search, dataset=science] Success@5: 56.7 83 | [query_type=search, dataset=technology] Success@5: 66.1 84 | [query_type=search, dataset=lifestyle] Success@5: 84.7 85 | [query_type=search, dataset=pooled] Success@5: 71.6 86 | 87 | [query_type=forum, dataset=writing] Success@5: 76.3 88 | [query_type=forum, dataset=recreation] Success@5: 70.8 89 | [query_type=forum, dataset=science] Success@5: 46.1 90 | [query_type=forum, dataset=technology] Success@5: 53.6 91 | [query_type=forum, dataset=lifestyle] Success@5: 76.9 92 | [query_type=forum, dataset=pooled] Success@5: 63.4 93 | ``` 94 | 95 | -------------------------------------------------------------------------------- /third_party/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include colbert/indexing/codecs/*.cpp 2 | include colbert/indexing/codecs/*.cu -------------------------------------------------------------------------------- /third_party/colbert/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | from .indexer import Indexer 3 | from .searcher import Searcher 4 | from .index_updater import IndexUpdater 5 | 6 | from .modeling.checkpoint import Checkpoint 7 | -------------------------------------------------------------------------------- /third_party/colbert/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .collection import * 2 | from .queries import * 3 | 4 | from .ranking import * 5 | from .examples import * 6 | -------------------------------------------------------------------------------- /third_party/colbert/data/collection.py: -------------------------------------------------------------------------------- 1 | 2 | # Could be .tsv or .json. The latter always allows more customization via optional parameters. 3 | # I think it could be worth doing some kind of parallel reads too, if the file exceeds 1 GiBs. 4 | # Just need to use a datastructure that shares things across processes without too much pickling. 5 | # I think multiprocessing.Manager can do that! 6 | 7 | import os 8 | import itertools 9 | 10 | from third_party.colbert.evaluation.loaders import load_collection 11 | from third_party.colbert.infra.run import Run 12 | 13 | 14 | class Collection: 15 | def __init__(self, path=None, data=None): 16 | self.path = path 17 | self.data = data or self._load_file(path) 18 | 19 | def __iter__(self): 20 | # TODO: If __data isn't there, stream from disk! 21 | return self.data.__iter__() 22 | 23 | def __getitem__(self, item): 24 | # TODO: Load from disk the first time this is called. Unless self.data is already not None. 25 | return self.data[item] 26 | 27 | def __len__(self): 28 | # TODO: Load here too. Basically, let's make data a property function and, on first call, either load or get __data. 
29 | return len(self.data) 30 | 31 | def _load_file(self, path): 32 | self.path = path 33 | return self._load_tsv(path) if path.endswith('.tsv') else self._load_jsonl(path) 34 | 35 | def _load_tsv(self, path): 36 | return load_collection(path) 37 | 38 | def _load_jsonl(self, path): 39 | raise NotImplementedError() 40 | 41 | def provenance(self): 42 | return self.path 43 | 44 | def toDict(self): 45 | return {'provenance': self.provenance()} 46 | 47 | def save(self, new_path): 48 | assert new_path.endswith('.tsv'), "TODO: Support .json[l] too." 49 | assert not os.path.exists(new_path), new_path 50 | 51 | with Run().open(new_path, 'w') as f: 52 | # TODO: expects content to always be a string here; no separate title! 53 | for pid, content in enumerate(self.data): 54 | content = f'{pid}\t{content}\n' 55 | f.write(content) 56 | 57 | return f.name 58 | 59 | def enumerate(self, rank): 60 | for _, offset, passages in self.enumerate_batches(rank=rank): 61 | for idx, passage in enumerate(passages): 62 | yield (offset + idx, passage) 63 | 64 | def enumerate_batches(self, rank, chunksize=None): 65 | assert rank is not None, "TODO: Add support for the rank=None case." 66 | 67 | chunksize = chunksize or self.get_chunksize() 68 | 69 | offset = 0 70 | iterator = iter(self) 71 | 72 | for chunk_idx, owner in enumerate(itertools.cycle(range(Run().nranks))): 73 | L = [line for _, line in zip(range(chunksize), iterator)] 74 | 75 | if len(L) > 0 and owner == rank: 76 | yield (chunk_idx, offset, L) 77 | 78 | offset += len(L) 79 | 80 | if len(L) < chunksize: 81 | return 82 | 83 | def get_chunksize(self): 84 | return min(25_000, 1 + len(self) // Run().nranks) # 25k is great, 10k allows things to reside on GPU?? 85 | 86 | @classmethod 87 | def cast(cls, obj): 88 | if type(obj) is str: 89 | return cls(path=obj) 90 | 91 | if type(obj) is list: 92 | return cls(data=obj) 93 | 94 | if type(obj) is cls: 95 | return obj 96 | 97 | assert False, f"obj has type {type(obj)} which is not compatible with cast()" 98 | 99 | 100 | # TODO: Look up path in some global [per-thread or thread-safe] list. 101 | -------------------------------------------------------------------------------- /third_party/colbert/data/dataset.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Not just the corpus, but also an arbitrary number of query sets, indexed by name in a dictionary/dotdict. 4 | # And also query sets with top-k PIDs. 5 | # QAs too? TripleSets too? 
6 | 7 | 8 | class Dataset: 9 | def __init__(self): 10 | pass 11 | 12 | def select(self, key): 13 | # Select the {corpus, queryset, tripleset, rankingset} determined by uniqueness or by key and return a "unique" dataset (e.g., for key=train) 14 | pass 15 | -------------------------------------------------------------------------------- /third_party/colbert/data/examples.py: -------------------------------------------------------------------------------- 1 | from third_party.colbert.infra.run import Run 2 | import os 3 | import ujson 4 | 5 | from third_party.colbert.utils.utils import print_message 6 | from third_party.colbert.infra.provenance import Provenance 7 | from third_party.utility.utils.save_metadata import get_metadata_only 8 | 9 | 10 | class Examples: 11 | def __init__(self, path=None, data=None, nway=None, provenance=None): 12 | self.__provenance = provenance or path or Provenance() 13 | self.nway = nway 14 | self.path = path 15 | self.data = data or self._load_file(path) 16 | 17 | def provenance(self): 18 | return self.__provenance 19 | 20 | def toDict(self): 21 | return self.provenance() 22 | 23 | def _load_file(self, path): 24 | nway = self.nway + 1 if self.nway else self.nway 25 | examples = [] 26 | 27 | with open(path) as f: 28 | for line in f: 29 | example = ujson.loads(line)[:nway] 30 | examples.append(example) 31 | 32 | return examples 33 | 34 | def tolist(self, rank=None, nranks=None): 35 | """ 36 | NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling. 37 | In particular, each subset is perfectly represented in every batch! However, since we never 38 | repeat passes over the data, we never repeat any particular triple, and the split across 39 | nodes is random (since the underlying file is pre-shuffled), there's no concern here. 40 | """ 41 | 42 | if rank or nranks: 43 | assert rank in range(nranks), (rank, nranks) 44 | return [self.data[idx] for idx in range(0, len(self.data), nranks)] # if line_idx % nranks == rank 45 | 46 | return list(self.data) 47 | 48 | def save(self, new_path): 49 | assert 'json' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too." 50 | 51 | print_message(f"#> Writing {len(self.data) / 1000_000.0}M examples to {new_path}") 52 | 53 | with Run().open(new_path, 'w') as f: 54 | for example in self.data: 55 | ujson.dump(example, f) 56 | f.write('\n') 57 | 58 | output_path = f.name 59 | print_message(f"#> Saved examples with {len(self.data)} lines to {f.name}") 60 | 61 | with Run().open(f'{new_path}.meta', 'w') as f: 62 | d = {} 63 | d['metadata'] = get_metadata_only() 64 | d['provenance'] = self.provenance() 65 | line = ujson.dumps(d, indent=4) 66 | f.write(line) 67 | 68 | return output_path 69 | 70 | @classmethod 71 | def cast(cls, obj, nway=None): 72 | if type(obj) is str: 73 | return cls(path=obj, nway=nway) 74 | 75 | if isinstance(obj, list): 76 | return cls(data=obj, nway=nway) 77 | 78 | if type(obj) is cls: 79 | assert nway is None, nway 80 | return obj 81 | 82 | assert False, f"obj has type {type(obj)} which is not compatible with cast()" 83 | -------------------------------------------------------------------------------- /third_party/colbert/data/queries.py: -------------------------------------------------------------------------------- 1 | from third_party.colbert.infra.run import Run 2 | import os 3 | import ujson 4 | 5 | from third_party.colbert.evaluation.loaders import load_queries 6 | 7 | # TODO: Look up path in some global [per-thread or thread-safe] list. 
8 | # TODO: path could be a list of paths...? But then how can we tell it's not a list of queries.. 9 | 10 | 11 | class Queries: 12 | def __init__(self, path=None, data=None): 13 | self.path = path 14 | 15 | if data: 16 | assert isinstance(data, dict), type(data) 17 | self._load_data(data) or self._load_file(path) 18 | 19 | def __len__(self): 20 | return len(self.data) 21 | 22 | def __iter__(self): 23 | return iter(self.data.items()) 24 | 25 | def provenance(self): 26 | return self.path 27 | 28 | def toDict(self): 29 | return {'provenance': self.provenance()} 30 | 31 | def _load_data(self, data): 32 | if data is None: 33 | return None 34 | 35 | self.data = {} 36 | self._qas = {} 37 | 38 | for qid, content in data.items(): 39 | if isinstance(content, dict): 40 | self.data[qid] = content['question'] 41 | self._qas[qid] = content 42 | else: 43 | self.data[qid] = content 44 | 45 | if len(self._qas) == 0: 46 | del self._qas 47 | 48 | return True 49 | 50 | def _load_file(self, path): 51 | if not path.endswith('.json'): 52 | self.data = load_queries(path) 53 | return True 54 | 55 | # Load QAs 56 | self.data = {} 57 | self._qas = {} 58 | 59 | with open(path) as f: 60 | for line in f: 61 | qa = ujson.loads(line) 62 | 63 | assert qa['qid'] not in self.data 64 | self.data[qa['qid']] = qa['question'] 65 | self._qas[qa['qid']] = qa 66 | 67 | return self.data 68 | 69 | def qas(self): 70 | return dict(self._qas) 71 | 72 | def __getitem__(self, key): 73 | return self.data[key] 74 | 75 | def keys(self): 76 | return self.data.keys() 77 | 78 | def values(self): 79 | return self.data.values() 80 | 81 | def items(self): 82 | return self.data.items() 83 | 84 | def save(self, new_path): 85 | assert new_path.endswith('.tsv') 86 | assert not os.path.exists(new_path), new_path 87 | 88 | with Run().open(new_path, 'w') as f: 89 | for qid, content in self.data.items(): 90 | content = f'{qid}\t{content}\n' 91 | f.write(content) 92 | 93 | return f.name 94 | 95 | def save_qas(self, new_path): 96 | assert new_path.endswith('.json') 97 | assert not os.path.exists(new_path), new_path 98 | 99 | with open(new_path, 'w') as f: 100 | for qid, qa in self._qas.items(): 101 | qa['qid'] = qid 102 | f.write(ujson.dumps(qa) + '\n') 103 | 104 | def _load_tsv(self, path): 105 | raise NotImplementedError 106 | 107 | def _load_jsonl(self, path): 108 | raise NotImplementedError 109 | 110 | @classmethod 111 | def cast(cls, obj): 112 | if type(obj) is str: 113 | return cls(path=obj) 114 | 115 | if isinstance(obj, dict) or isinstance(obj, list): 116 | return cls(data=obj) 117 | 118 | if type(obj) is cls: 119 | return obj 120 | 121 | assert False, f"obj has type {type(obj)} which is not compatible with cast()" 122 | 123 | 124 | # class QuerySet: 125 | # def __init__(self, *paths, renumber=False): 126 | # self.paths = paths 127 | # self.original_queries = [load_queries(path) for path in paths] 128 | 129 | # if renumber: 130 | # self.queries = flatten([q.values() for q in self.original_queries]) 131 | # self.queries = {idx: text for idx, text in enumerate(self.queries)} 132 | 133 | # else: 134 | # self.queries = {} 135 | 136 | # for queries in self.original_queries: 137 | # assert len(set.intersection(set(queries.keys()), set(self.queries.keys()))) == 0, \ 138 | # "renumber=False requires non-overlapping query IDs" 139 | 140 | # self.queries.update(queries) 141 | 142 | # assert len(self.queries) == sum(map(len, self.original_queries)) 143 | 144 | # def todict(self): 145 | # return dict(self.queries) 146 | 147 | # def tolist(self): 148 | # 
return list(self.queries.values()) 149 | 150 | # def query_sets(self): 151 | # return self.original_queries 152 | 153 | # def split_rankings(self, rankings): 154 | # assert type(rankings) is list 155 | # assert len(rankings) == len(self.queries) 156 | 157 | # sub_rankings = [] 158 | # offset = 0 159 | # for source in self.original_queries: 160 | # sub_rankings.append(rankings[offset:offset+len(source)]) 161 | # offset += len(source) 162 | 163 | # return sub_rankings 164 | -------------------------------------------------------------------------------- /third_party/colbert/data/ranking.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import ujson 4 | from third_party.colbert.infra.provenance import Provenance 5 | 6 | from third_party.colbert.infra.run import Run 7 | from third_party.colbert.utils.utils import print_message, groupby_first_item 8 | from third_party.utility.utils.save_metadata import get_metadata_only 9 | 10 | 11 | def numericize(v): 12 | if '.' in v: 13 | return float(v) 14 | 15 | return int(v) 16 | 17 | 18 | def load_ranking(path): # works with annotated and un-annotated ranked lists 19 | print_message("#> Loading the ranked lists from", path) 20 | 21 | with open(path) as f: 22 | return [list(map(numericize, line.strip().split('\t'))) for line in f] 23 | 24 | 25 | class Ranking: 26 | def __init__(self, path=None, data=None, metrics=None, provenance=None): 27 | self.__provenance = provenance or path or Provenance() 28 | self.data = self._prepare_data(data or self._load_file(path)) 29 | 30 | def provenance(self): 31 | return self.__provenance 32 | 33 | def toDict(self): 34 | return {'provenance': self.provenance()} 35 | 36 | def _prepare_data(self, data): 37 | # TODO: Handle list of lists??? 38 | if isinstance(data, dict): 39 | self.flat_ranking = [(qid, *rest) for qid, subranking in data.items() for rest in subranking] 40 | return data 41 | 42 | self.flat_ranking = data 43 | return groupby_first_item(tqdm.tqdm(self.flat_ranking)) 44 | 45 | def _load_file(self, path): 46 | return load_ranking(path) 47 | 48 | def todict(self): 49 | return dict(self.data) 50 | 51 | def tolist(self): 52 | return list(self.flat_ranking) 53 | 54 | def items(self): 55 | return self.data.items() 56 | 57 | def _load_tsv(self, path): 58 | raise NotImplementedError 59 | 60 | def _load_jsonl(self, path): 61 | raise NotImplementedError 62 | 63 | def save(self, new_path): 64 | assert 'tsv' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too." 
65 | 66 | with Run().open(new_path, 'w') as f: 67 | for items in self.flat_ranking: 68 | line = '\t'.join(map(lambda x: str(int(x) if type(x) is bool else x), items)) + '\n' 69 | f.write(line) 70 | 71 | output_path = f.name 72 | print_message(f"#> Saved ranking of {len(self.data)} queries and {len(self.flat_ranking)} lines to {f.name}") 73 | 74 | with Run().open(f'{new_path}.meta', 'w') as f: 75 | d = {} 76 | d['metadata'] = get_metadata_only() 77 | d['provenance'] = self.provenance() 78 | line = ujson.dumps(d, indent=4) 79 | f.write(line) 80 | 81 | return output_path 82 | 83 | @classmethod 84 | def cast(cls, obj): 85 | if type(obj) is str: 86 | return cls(path=obj) 87 | 88 | if isinstance(obj, dict) or isinstance(obj, list): 89 | return cls(data=obj) 90 | 91 | if type(obj) is cls: 92 | return obj 93 | 94 | assert False, f"obj has type {type(obj)} which is not compatible with cast()" 95 | -------------------------------------------------------------------------------- /third_party/colbert/distillation/ranking_scorer.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | import ujson 3 | 4 | from collections import defaultdict 5 | 6 | from third_party.colbert.utils.utils import print_message, zipstar 7 | from third_party.utility.utils.save_metadata import get_metadata_only 8 | 9 | from third_party.colbert.infra import Run 10 | from third_party.colbert.data import Ranking 11 | from third_party.colbert.infra.provenance import Provenance 12 | from third_party.colbert.distillation.scorer import Scorer 13 | 14 | 15 | class RankingScorer: 16 | def __init__(self, scorer: Scorer, ranking: Ranking): 17 | self.scorer = scorer 18 | self.ranking = ranking.tolist() 19 | self.__provenance = Provenance() 20 | 21 | print_message(f"#> Loaded ranking with {len(self.ranking)} qid--pid pairs!") 22 | 23 | def provenance(self): 24 | return self.__provenance 25 | 26 | def run(self): 27 | print_message(f"#> Starting..") 28 | 29 | qids, pids, *_ = zipstar(self.ranking) 30 | distillation_scores = self.scorer.launch(qids, pids) 31 | 32 | scores_by_qid = defaultdict(list) 33 | 34 | for qid, pid, score in tqdm.tqdm(zip(qids, pids, distillation_scores)): 35 | scores_by_qid[qid].append((score, pid)) 36 | 37 | with Run().open('distillation_scores.json', 'w') as f: 38 | for qid in tqdm.tqdm(scores_by_qid): 39 | obj = (qid, scores_by_qid[qid]) 40 | f.write(ujson.dumps(obj) + '\n') 41 | 42 | output_path = f.name 43 | print_message(f'#> Saved the distillation_scores to {output_path}') 44 | 45 | with Run().open(f'{output_path}.meta', 'w') as f: 46 | d = {} 47 | d['metadata'] = get_metadata_only() 48 | d['provenance'] = self.provenance() 49 | line = ujson.dumps(d, indent=4) 50 | f.write(line) 51 | 52 | return output_path 53 | -------------------------------------------------------------------------------- /third_party/colbert/distillation/scorer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tqdm 3 | 4 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 5 | 6 | from third_party.colbert.infra.launcher import Launcher 7 | from third_party.colbert.infra import Run, RunConfig 8 | from third_party.colbert.modeling.reranker.electra import ElectraReranker 9 | from third_party.colbert.utils.utils import flatten 10 | 11 | 12 | DEFAULT_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2' 13 | 14 | 15 | class Scorer: 16 | def __init__(self, queries, collection, model=DEFAULT_MODEL, maxlen=180, bsize=256): 
17 | self.queries = queries 18 | self.collection = collection 19 | self.model = model 20 | 21 | self.maxlen = maxlen 22 | self.bsize = bsize 23 | 24 | def launch(self, qids, pids): 25 | launcher = Launcher(self._score_pairs_process, return_all=True) 26 | outputs = launcher.launch(Run().config, qids, pids) 27 | 28 | return flatten(outputs) 29 | 30 | def _score_pairs_process(self, config, qids, pids): 31 | assert len(qids) == len(pids), (len(qids), len(pids)) 32 | share = 1 + len(qids) // config.nranks 33 | offset = config.rank * share 34 | endpos = (1 + config.rank) * share 35 | 36 | return self._score_pairs(qids[offset:endpos], pids[offset:endpos], show_progress=(config.rank < 1)) 37 | 38 | def _score_pairs(self, qids, pids, show_progress=False): 39 | tokenizer = AutoTokenizer.from_pretrained(self.model) 40 | model = AutoModelForSequenceClassification.from_pretrained(self.model).cuda() 41 | 42 | assert len(qids) == len(pids), (len(qids), len(pids)) 43 | 44 | scores = [] 45 | 46 | model.eval() 47 | with torch.inference_mode(): 48 | with torch.cuda.amp.autocast(): 49 | for offset in tqdm.tqdm(range(0, len(qids), self.bsize), disable=(not show_progress)): 50 | endpos = offset + self.bsize 51 | 52 | queries_ = [self.queries[qid] for qid in qids[offset:endpos]] 53 | passages_ = [self.collection[pid] for pid in pids[offset:endpos]] 54 | 55 | features = tokenizer(queries_, passages_, padding='longest', truncation=True, 56 | return_tensors='pt', max_length=self.maxlen).to(model.device) 57 | 58 | scores.append(model(**features).logits.flatten()) 59 | 60 | scores = torch.cat(scores) 61 | scores = scores.tolist() 62 | 63 | Run().print(f'Returning with {len(scores)} scores') 64 | 65 | return scores 66 | 67 | 68 | # LONG-TERM TODO: This can be sped up by sorting by length in advance. 
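# Hedged usage sketch (not part of the original file): Scorer is normally driven by
# RankingScorer (see distillation/ranking_scorer.py). With placeholder paths and a
# single GPU, the flow looks roughly like this:
#
#     from third_party.colbert.data import Queries, Collection, Ranking
#     from third_party.colbert.distillation.ranking_scorer import RankingScorer
#
#     with Run().context(RunConfig(nranks=1, experiment='distillation')):
#         queries = Queries(path='/path/to/queries.tsv')
#         collection = Collection(path='/path/to/collection.tsv')
#         ranking = Ranking(path='/path/to/ranking.tsv')
#
#         scorer = Scorer(queries, collection, bsize=256)
#         scores_path = RankingScorer(scorer, ranking).run()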
69 | -------------------------------------------------------------------------------- /third_party/colbert/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/evaluation/__init__.py -------------------------------------------------------------------------------- /third_party/colbert/evaluation/load_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | import torch 4 | import random 5 | 6 | from collections import defaultdict, OrderedDict 7 | 8 | from third_party.colbert.parameters import DEVICE 9 | from third_party.colbert.modeling.colbert import ColBERT 10 | from third_party.colbert.utils.utils import print_message, load_checkpoint 11 | 12 | 13 | def load_model(args, do_print=True): 14 | colbert = ColBERT.from_pretrained('bert-base-uncased', 15 | query_maxlen=args.query_maxlen, 16 | doc_maxlen=args.doc_maxlen, 17 | dim=args.dim, 18 | similarity_metric=args.similarity, 19 | mask_punctuation=args.mask_punctuation) 20 | colbert = colbert.to(DEVICE) 21 | 22 | print_message("#> Loading model checkpoint.", condition=do_print) 23 | 24 | checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print) 25 | 26 | colbert.eval() 27 | 28 | return colbert, checkpoint 29 | -------------------------------------------------------------------------------- /third_party/colbert/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import ujson 2 | 3 | from collections import defaultdict 4 | from third_party.colbert.utils.runs import Run 5 | 6 | 7 | class Metrics: 8 | def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None): 9 | self.results = {} 10 | self.mrr_sums = {depth: 0.0 for depth in mrr_depths} 11 | self.recall_sums = {depth: 0.0 for depth in recall_depths} 12 | self.success_sums = {depth: 0.0 for depth in success_depths} 13 | self.total_queries = total_queries 14 | 15 | self.max_query_idx = -1 16 | self.num_queries_added = 0 17 | 18 | def add(self, query_idx, query_key, ranking, gold_positives): 19 | self.num_queries_added += 1 20 | 21 | assert query_key not in self.results 22 | assert len(self.results) <= query_idx 23 | assert len(set(gold_positives)) == len(gold_positives) 24 | assert len(set([pid for _, pid, _ in ranking])) == len(ranking) 25 | 26 | self.results[query_key] = ranking 27 | 28 | positives = [i for i, (_, pid, _) in enumerate(ranking) if pid in gold_positives] 29 | 30 | if len(positives) == 0: 31 | return 32 | 33 | for depth in self.mrr_sums: 34 | first_positive = positives[0] 35 | self.mrr_sums[depth] += (1.0 / (first_positive+1.0)) if first_positive < depth else 0.0 36 | 37 | for depth in self.success_sums: 38 | first_positive = positives[0] 39 | self.success_sums[depth] += 1.0 if first_positive < depth else 0.0 40 | 41 | for depth in self.recall_sums: 42 | num_positives_up_to_depth = len([pos for pos in positives if pos < depth]) 43 | self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives) 44 | 45 | def print_metrics(self, query_idx): 46 | for depth in sorted(self.mrr_sums): 47 | print("MRR@" + str(depth), "=", self.mrr_sums[depth] / (query_idx+1.0)) 48 | 49 | for depth in sorted(self.success_sums): 50 | print("Success@" + str(depth), "=", self.success_sums[depth] / (query_idx+1.0)) 51 | 52 | for depth in sorted(self.recall_sums): 
53 | print("Recall@" + str(depth), "=", self.recall_sums[depth] / (query_idx+1.0)) 54 | 55 | def log(self, query_idx): 56 | assert query_idx >= self.max_query_idx 57 | self.max_query_idx = query_idx 58 | 59 | Run.log_metric("ranking/max_query_idx", query_idx, query_idx) 60 | Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx) 61 | 62 | for depth in sorted(self.mrr_sums): 63 | score = self.mrr_sums[depth] / (query_idx+1.0) 64 | Run.log_metric("ranking/MRR." + str(depth), score, query_idx) 65 | 66 | for depth in sorted(self.success_sums): 67 | score = self.success_sums[depth] / (query_idx+1.0) 68 | Run.log_metric("ranking/Success." + str(depth), score, query_idx) 69 | 70 | for depth in sorted(self.recall_sums): 71 | score = self.recall_sums[depth] / (query_idx+1.0) 72 | Run.log_metric("ranking/Recall." + str(depth), score, query_idx) 73 | 74 | def output_final_metrics(self, path, query_idx, num_queries): 75 | assert query_idx + 1 == num_queries 76 | assert num_queries == self.total_queries 77 | 78 | if self.max_query_idx < query_idx: 79 | self.log(query_idx) 80 | 81 | self.print_metrics(query_idx) 82 | 83 | output = defaultdict(dict) 84 | 85 | for depth in sorted(self.mrr_sums): 86 | score = self.mrr_sums[depth] / (query_idx+1.0) 87 | output['mrr'][depth] = score 88 | 89 | for depth in sorted(self.success_sums): 90 | score = self.success_sums[depth] / (query_idx+1.0) 91 | output['success'][depth] = score 92 | 93 | for depth in sorted(self.recall_sums): 94 | score = self.recall_sums[depth] / (query_idx+1.0) 95 | output['recall'][depth] = score 96 | 97 | with open(path, 'w') as f: 98 | ujson.dump(output, f, indent=4) 99 | f.write('\n') 100 | 101 | 102 | def evaluate_recall(qrels, queries, topK_pids): 103 | if qrels is None: 104 | return 105 | 106 | assert set(qrels.keys()) == set(queries.keys()) 107 | recall_at_k = [len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) / max(1.0, len(qrels[qid])) 108 | for qid in qrels] 109 | recall_at_k = sum(recall_at_k) / len(qrels) 110 | recall_at_k = round(recall_at_k, 3) 111 | print("Recall @ maximum depth =", recall_at_k) 112 | 113 | 114 | # TODO: If implicit qrels are used (for re-ranking), warn if a recall metric is requested + add an asterisk to output. 115 | -------------------------------------------------------------------------------- /third_party/colbert/index.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # TODO: This is the loaded index, underneath a searcher. 
4 | 5 | 6 | """ 7 | ## Operations: 8 | 9 | index = Index(index='/path/to/index') 10 | index.load_to_memory() 11 | 12 | batch_of_pids = [2324,32432,98743,23432] 13 | index.lookup(batch_of_pids, device='cuda:0') -> (N, doc_maxlen, dim) 14 | 15 | index.iterate_over_parts() 16 | 17 | """ 18 | -------------------------------------------------------------------------------- /third_party/colbert/indexer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | import torch.multiprocessing as mp 5 | 6 | from third_party.colbert.infra.run import Run 7 | from third_party.colbert.infra.config import ColBERTConfig, RunConfig 8 | from third_party.colbert.infra.launcher import Launcher 9 | 10 | from third_party.colbert.utils.utils import create_directory, print_message 11 | 12 | from third_party.colbert.indexing.collection_indexer import encode 13 | from transformers import AutoConfig 14 | from src.utils import is_debug 15 | 16 | class Indexer: 17 | def __init__(self, checkpoint, config=None, verbose: int = 3): 18 | """ 19 | Use Run().context() to choose the run's configuration. They are NOT extracted from `config`. 20 | """ 21 | 22 | self.index_path = None 23 | self.verbose = verbose 24 | self.checkpoint = checkpoint 25 | self.checkpoint_config = ColBERTConfig.load_from_checkpoint(checkpoint) 26 | if config is None: 27 | self.checkpoint_config = AutoConfig.from_pretrained(checkpoint) 28 | else: 29 | self.checkpoint_config = config 30 | 31 | self.config = ColBERTConfig.from_existing(self.checkpoint_config, config, Run().config) 32 | self.configure(checkpoint=checkpoint) 33 | 34 | def configure(self, **kw_args): 35 | self.config.configure(**kw_args) 36 | 37 | def get_index(self): 38 | return self.index_path 39 | 40 | def erase(self, force_silent: bool = False): 41 | assert self.index_path is not None 42 | directory = self.index_path 43 | deleted = [] 44 | 45 | for filename in sorted(os.listdir(directory)): 46 | filename = os.path.join(directory, filename) 47 | 48 | delete = filename.endswith(".json") 49 | delete = delete and ('metadata' in filename or 'doclen' in filename or 'plan' in filename) 50 | delete = delete or filename.endswith(".pt") 51 | 52 | if delete: 53 | deleted.append(filename) 54 | 55 | if len(deleted): 56 | if not force_silent and not is_debug(): 57 | print_message(f"#> Will delete {len(deleted)} files already at {directory} in 20 seconds...") 58 | time.sleep(20) 59 | 60 | for filename in deleted: 61 | os.remove(filename) 62 | 63 | return deleted 64 | 65 | def index(self, name, collection, overwrite=False): 66 | assert overwrite in [True, False, 'reuse', 'resume', "force_silent_overwrite"] 67 | 68 | self.configure(collection=collection, index_name=name, resume=overwrite=='resume') 69 | # Note: The bsize value set here is ignored internally. Users are encouraged 70 | # to supply their own batch size for indexing by using the index_bsize parameter in the ColBERTConfig. 
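        #       A hedged illustration (values and paths are placeholders, not defaults):
        #           config = ColBERTConfig(index_bsize=128)
        #           indexer = Indexer(checkpoint='/path/to/checkpoint', config=config)
        #           indexer.index(name='my.index', collection='/path/to/collection.tsv')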
71 | self.configure(bsize=4, partitions=None) 72 | 73 | self.index_path = self.config.index_path_ 74 | index_does_not_exist = (not os.path.exists(self.config.index_path_)) 75 | 76 | assert (overwrite in [True, 'reuse', 'resume', "force_silent_overwrite"]) or index_does_not_exist, self.config.index_path_ 77 | create_directory(self.config.index_path_) 78 | 79 | if overwrite == 'force_silent_overwrite': 80 | self.erase(force_silent=True) 81 | elif overwrite is True: 82 | self.erase() 83 | 84 | if index_does_not_exist or overwrite != 'reuse': 85 | self.__launch(collection) 86 | 87 | return self.index_path 88 | 89 | def __launch(self, collection): 90 | launcher = Launcher(encode) 91 | if self.config.nranks == 1 and self.config.avoid_fork_if_possible: 92 | shared_queues = [] 93 | shared_lists = [] 94 | launcher.launch_without_fork(self.config, collection, shared_lists, shared_queues, self.verbose) 95 | 96 | return 97 | 98 | manager = mp.Manager() 99 | shared_lists = [manager.list() for _ in range(self.config.nranks)] 100 | shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)] 101 | 102 | # Encodes collection into index using the CollectionIndexer class 103 | self.verbose = True 104 | launcher.launch(self.config, collection, shared_lists, shared_queues, self.verbose) 105 | -------------------------------------------------------------------------------- /third_party/colbert/indexing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/indexing/__init__.py -------------------------------------------------------------------------------- /third_party/colbert/indexing/codecs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/indexing/codecs/__init__.py -------------------------------------------------------------------------------- /third_party/colbert/indexing/codecs/decompress_residuals.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor decompress_residuals_cuda( 4 | const torch::Tensor binary_residuals, const torch::Tensor bucket_weights, 5 | const torch::Tensor reversed_bit_map, 6 | const torch::Tensor bucket_weight_combinations, const torch::Tensor codes, 7 | const torch::Tensor centroids, const int dim, const int nbits); 8 | 9 | torch::Tensor decompress_residuals( 10 | const torch::Tensor binary_residuals, const torch::Tensor bucket_weights, 11 | const torch::Tensor reversed_bit_map, 12 | const torch::Tensor bucket_weight_combinations, const torch::Tensor codes, 13 | const torch::Tensor centroids, const int dim, const int nbits) { 14 | // Add input verification 15 | return decompress_residuals_cuda( 16 | binary_residuals, bucket_weights, reversed_bit_map, 17 | bucket_weight_combinations, codes, centroids, dim, nbits); 18 | } 19 | 20 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 21 | m.def("decompress_residuals_cpp", &decompress_residuals, 22 | "Decompress residuals"); 23 | } 24 | -------------------------------------------------------------------------------- /third_party/colbert/indexing/codecs/decompress_residuals.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | __global__ void 
decompress_residuals_kernel( 9 | const uint8_t* binary_residuals, 10 | const torch::PackedTensorAccessor32 11 | bucket_weights, 12 | const torch::PackedTensorAccessor32 13 | reversed_bit_map, 14 | const torch::PackedTensorAccessor32 15 | bucket_weight_combinations, 16 | const torch::PackedTensorAccessor32 codes, 17 | const torch::PackedTensorAccessor32 18 | centroids, 19 | const int n, const int dim, const int nbits, const int packed_size, 20 | at::Half* output) { 21 | const int packed_dim = (int)(dim * nbits / packed_size); 22 | const int i = blockIdx.x; 23 | const int j = threadIdx.x; 24 | 25 | if (i >= n) return; 26 | if (j >= dim * nbits / packed_size) return; 27 | 28 | const int code = codes[i]; 29 | 30 | uint8_t x = binary_residuals[i * packed_dim + j]; 31 | x = reversed_bit_map[x]; 32 | int output_idx = (int)(j * packed_size / nbits); 33 | for (int k = 0; k < packed_size / nbits; k++) { 34 | assert(output_idx < dim); 35 | const int bucket_weight_idx = bucket_weight_combinations[x][k]; 36 | output[i * dim + output_idx] = bucket_weights[bucket_weight_idx]; 37 | output[i * dim + output_idx] += centroids[code][output_idx]; 38 | output_idx++; 39 | } 40 | } 41 | 42 | torch::Tensor decompress_residuals_cuda( 43 | const torch::Tensor binary_residuals, const torch::Tensor bucket_weights, 44 | const torch::Tensor reversed_bit_map, 45 | const torch::Tensor bucket_weight_combinations, const torch::Tensor codes, 46 | const torch::Tensor centroids, const int dim, const int nbits) { 47 | auto options = torch::TensorOptions() 48 | .dtype(torch::kFloat16) 49 | .device(torch::kCUDA, 0) 50 | .requires_grad(false); 51 | torch::Tensor output = 52 | torch::zeros({(int)binary_residuals.size(0), (int)dim}, options); 53 | 54 | // TODO: Set this automatically? 55 | const int packed_size = 8; 56 | 57 | const int threads = dim / (packed_size / nbits); 58 | const int blocks = 59 | (binary_residuals.size(0) * binary_residuals.size(1)) / threads; 60 | 61 | decompress_residuals_kernel<<>>( 62 | binary_residuals.data(), 63 | bucket_weights 64 | .packed_accessor32(), 65 | reversed_bit_map 66 | .packed_accessor32(), 67 | bucket_weight_combinations 68 | .packed_accessor32(), 69 | codes.packed_accessor32(), 70 | centroids.packed_accessor32(), 71 | binary_residuals.size(0), dim, nbits, packed_size, 72 | output.data()); 73 | 74 | return output; 75 | } 76 | -------------------------------------------------------------------------------- /third_party/colbert/indexing/codecs/packbits.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | torch::Tensor packbits_cuda(const torch::Tensor residuals); 4 | 5 | torch::Tensor packbits(const torch::Tensor residuals) { 6 | return packbits_cuda(residuals); 7 | } 8 | 9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 10 | m.def("packbits_cpp", &packbits, "Pack bits"); 11 | } 12 | 13 | -------------------------------------------------------------------------------- /third_party/colbert/indexing/codecs/packbits.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #define FULL_MASK 0xffffffff 9 | 10 | __global__ void packbits_kernel( 11 | const uint8_t* residuals, 12 | uint8_t* packed_residuals, 13 | const int residuals_size) { 14 | const int i = blockIdx.x; 15 | const int j = threadIdx.x; 16 | 17 | assert(blockDim.x == 32); 18 | 19 | const int residuals_idx = i * blockDim.x + j; 20 | if (residuals_idx >= residuals_size) { 21 
| return; 22 | } 23 | 24 | const int packed_residuals_idx = residuals_idx / 8; 25 | 26 | 27 | uint32_t mask = __ballot_sync(FULL_MASK, residuals[residuals_idx]); 28 | 29 | mask = __brev(mask); 30 | 31 | if (residuals_idx % 32 == 0) { 32 | for (int k = 0; k < 4; k++) { 33 | packed_residuals[packed_residuals_idx + k] = 34 | (mask >> (8 * (4 - k - 1))) & 0xff; 35 | } 36 | } 37 | } 38 | 39 | torch::Tensor packbits_cuda(const torch::Tensor residuals) { 40 | auto options = torch::TensorOptions() 41 | .dtype(torch::kUInt8) 42 | .device(torch::kCUDA, residuals.device().index()) 43 | .requires_grad(false); 44 | assert(residuals.size(0) % 32 == 0); 45 | torch::Tensor packed_residuals = torch::zeros({int(residuals.size(0) / 8)}, options); 46 | 47 | const int threads = 32; 48 | const int blocks = std::ceil(residuals.size(0) / (float) threads); 49 | 50 | packbits_kernel<<>>( 51 | residuals.data(), 52 | packed_residuals.data(), 53 | residuals.size(0) 54 | ); 55 | 56 | return packed_residuals; 57 | } 58 | -------------------------------------------------------------------------------- /third_party/colbert/indexing/codecs/residual_embeddings_strided.py: -------------------------------------------------------------------------------- 1 | # from third_party.colbert.indexing.codecs.residual import ResidualCodec 2 | import third_party.colbert.indexing.codecs.residual_embeddings as residual_embeddings 3 | 4 | from third_party.colbert.search.strided_tensor import StridedTensor 5 | 6 | class ResidualEmbeddingsStrided: 7 | def __init__(self, codec, embeddings, doclens): 8 | self.codec = codec 9 | self.codes = embeddings.codes 10 | self.residuals = embeddings.residuals 11 | self.use_gpu = self.codec.use_gpu 12 | 13 | self.codes_strided = StridedTensor(self.codes, doclens, use_gpu=self.use_gpu) 14 | self.residuals_strided = StridedTensor(self.residuals, doclens, use_gpu=self.use_gpu) 15 | 16 | def lookup_pids(self, passage_ids, out_device='cuda'): 17 | codes_packed, codes_lengths = self.codes_strided.lookup(passage_ids)#.as_packed_tensor() 18 | residuals_packed, _ = self.residuals_strided.lookup(passage_ids)#.as_packed_tensor() 19 | 20 | embeddings_packed = self.codec.decompress(residual_embeddings.ResidualEmbeddings(codes_packed, residuals_packed)) 21 | 22 | return embeddings_packed, codes_lengths 23 | 24 | def lookup_codes(self, passage_ids): 25 | return self.codes_strided.lookup(passage_ids)#.as_packed_tensor() 26 | -------------------------------------------------------------------------------- /third_party/colbert/indexing/collection_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from third_party.colbert.infra.run import Run 4 | from third_party.colbert.utils.utils import print_message, batch 5 | 6 | 7 | class CollectionEncoder: 8 | def __init__(self, config, checkpoint): 9 | self.config = config 10 | self.checkpoint = checkpoint 11 | self.use_gpu = self.config.total_visible_gpus > 0 12 | 13 | def encode_passages(self, passages): 14 | Run().print(f"#> Encoding {len(passages)} passages..") 15 | 16 | if len(passages) == 0: 17 | return None, None 18 | 19 | with torch.inference_mode(): 20 | embs, doclens = [], [] 21 | 22 | # Batch here to avoid OOM from storing intermediate embeddings on GPU. 23 | # Storing on the GPU helps with speed of masking, etc. 24 | # But ideally this batching happens internally inside docFromText. 
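                # Minimal sketch of the chunking `batch` is assumed to perform here
                # (fixed-size slices of the passage list):
                #
                #     def batch(items, size):
                #         for offset in range(0, len(items), size):
                #             yield items[offset:offset + size]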
25 | for passage_batch in batch(passages, self.config.index_bsize): 26 | embs_, doclens_ = self.checkpoint.docFromText(passage_batch) 27 | embs.append(embs_) 28 | doclens.extend(doclens_) 29 | 30 | embs = torch.cat(embs) 31 | 32 | return embs, doclens 33 | -------------------------------------------------------------------------------- /third_party/colbert/indexing/index_manager.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from bitarray import bitarray 5 | 6 | 7 | class IndexManager(): 8 | def __init__(self, dim): 9 | self.dim = dim 10 | 11 | def save(self, tensor, path_prefix): 12 | torch.save(tensor, path_prefix) 13 | 14 | def save_bitarray(self, bitarray, path_prefix): 15 | with open(path_prefix, "wb") as f: 16 | bitarray.tofile(f) 17 | 18 | 19 | def load_index_part(filename, verbose=True): 20 | part = torch.load(filename) 21 | 22 | if type(part) == list: # for backward compatibility 23 | part = torch.cat(part) 24 | 25 | return part 26 | 27 | 28 | def load_compressed_index_part(filename, dim, bits): 29 | a = bitarray() 30 | 31 | with open(filename, "rb") as f: 32 | a.fromfile(f) 33 | 34 | n = len(a) // dim // bits 35 | part = torch.tensor(np.frombuffer(a.tobytes(), dtype=np.uint8)) # TODO: isn't from_numpy(.) faster? 36 | part = part.reshape((n, int(np.ceil(dim * bits / 8)))) 37 | 38 | return part 39 | -------------------------------------------------------------------------------- /third_party/colbert/indexing/index_saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import queue 3 | import ujson 4 | import threading 5 | 6 | from contextlib import contextmanager 7 | 8 | from third_party.colbert.indexing.codecs.residual import ResidualCodec 9 | 10 | from third_party.colbert.utils.utils import print_message 11 | 12 | 13 | class IndexSaver(): 14 | def __init__(self, config): 15 | self.config = config 16 | 17 | def save_codec(self, codec): 18 | codec.save(index_path=self.config.index_path_) 19 | 20 | def load_codec(self): 21 | return ResidualCodec.load(index_path=self.config.index_path_) 22 | 23 | def try_load_codec(self): 24 | try: 25 | ResidualCodec.load(index_path=self.config.index_path_) 26 | return True 27 | except Exception as e: 28 | return False 29 | 30 | def check_chunk_exists(self, chunk_idx): 31 | # TODO: Verify that the chunk has the right amount of data? 
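        # For reference, a complete chunk `chunk_idx` consists of the four files checked
        # below (and written by _write_chunk_to_disk):
        #
        #     doclens.{chunk_idx}.json     per-passage token counts
        #     {chunk_idx}.metadata.json    passage_offset / num_passages / num_embeddings
        #     {chunk_idx}.codes.pt         centroid id per embedding
        #     {chunk_idx}.residuals.pt     packed residuals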
32 | 33 | doclens_path = os.path.join(self.config.index_path_, f'doclens.{chunk_idx}.json') 34 | if not os.path.exists(doclens_path): 35 | return False 36 | 37 | metadata_path = os.path.join(self.config.index_path_, f'{chunk_idx}.metadata.json') 38 | if not os.path.exists(metadata_path): 39 | return False 40 | 41 | path_prefix = os.path.join(self.config.index_path_, str(chunk_idx)) 42 | codes_path = f'{path_prefix}.codes.pt' 43 | if not os.path.exists(codes_path): 44 | return False 45 | 46 | residuals_path = f'{path_prefix}.residuals.pt' # f'{path_prefix}.residuals.bn' 47 | if not os.path.exists(residuals_path): 48 | return False 49 | 50 | return True 51 | 52 | @contextmanager 53 | def thread(self): 54 | self.codec = self.load_codec() 55 | 56 | self.saver_queue = queue.Queue(maxsize=3) 57 | thread = threading.Thread(target=self._saver_thread) 58 | thread.start() 59 | 60 | try: 61 | yield 62 | 63 | finally: 64 | self.saver_queue.put(None) 65 | thread.join() 66 | 67 | del self.saver_queue 68 | del self.codec 69 | 70 | def save_chunk(self, chunk_idx, offset, embs, doclens): 71 | compressed_embs = self.codec.compress(embs) 72 | 73 | self.saver_queue.put((chunk_idx, offset, compressed_embs, doclens)) 74 | 75 | def _saver_thread(self): 76 | for args in iter(self.saver_queue.get, None): 77 | self._write_chunk_to_disk(*args) 78 | 79 | def _write_chunk_to_disk(self, chunk_idx, offset, compressed_embs, doclens): 80 | path_prefix = os.path.join(self.config.index_path_, str(chunk_idx)) 81 | compressed_embs.save(path_prefix) 82 | 83 | doclens_path = os.path.join(self.config.index_path_, f'doclens.{chunk_idx}.json') 84 | with open(doclens_path, 'w') as output_doclens: 85 | ujson.dump(doclens, output_doclens) 86 | 87 | metadata_path = os.path.join(self.config.index_path_, f'{chunk_idx}.metadata.json') 88 | with open(metadata_path, 'w') as output_metadata: 89 | metadata = {'passage_offset': offset, 'num_passages': len(doclens), 'num_embeddings': len(compressed_embs)} 90 | ujson.dump(metadata, output_metadata) 91 | -------------------------------------------------------------------------------- /third_party/colbert/indexing/loaders.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import ujson 4 | 5 | 6 | def get_parts(directory): 7 | extension = '.pt' 8 | 9 | parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory) 10 | if filename.endswith(extension)]) 11 | 12 | assert list(range(len(parts))) == parts, parts 13 | 14 | # Integer-sortedness matters. 
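    # e.g., a directory containing 0.pt, 1.pt, 2.pt (plus their .sample files) yields
    #
    #     parts         = [0, 1, 2]
    #     parts_paths   = ['<dir>/0.pt', '<dir>/1.pt', '<dir>/2.pt']
    #     samples_paths = ['<dir>/0.sample', '<dir>/1.sample', '<dir>/2.sample']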
15 | parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts] 16 | samples_paths = [os.path.join(directory, '{}.sample'.format(filename)) for filename in parts] 17 | 18 | return parts, parts_paths, samples_paths 19 | 20 | 21 | def load_doclens(directory, flatten=True): 22 | doclens_filenames = {} 23 | 24 | for filename in os.listdir(directory): 25 | match = re.match("doclens.(\d+).json", filename) 26 | 27 | if match is not None: 28 | doclens_filenames[int(match.group(1))] = filename 29 | 30 | doclens_filenames = [os.path.join(directory, doclens_filenames[i]) for i in sorted(doclens_filenames.keys())] 31 | 32 | all_doclens = [ujson.load(open(filename)) for filename in doclens_filenames] 33 | 34 | if flatten: 35 | all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens] 36 | 37 | if len(all_doclens) == 0: 38 | raise ValueError("Could not load doclens") 39 | 40 | return all_doclens 41 | 42 | 43 | def get_deltas(directory): 44 | extension = '.residuals.pt' 45 | 46 | parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory) 47 | if filename.endswith(extension)]) 48 | 49 | assert list(range(len(parts))) == parts, parts 50 | 51 | # Integer-sortedness matters. 52 | parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts] 53 | 54 | return parts, parts_paths 55 | 56 | 57 | # def load_compression_data(level, path): 58 | # with open(path, "r") as f: 59 | # for line in f: 60 | # line = line.split(',') 61 | # bits = int(line[0]) 62 | 63 | # if bits == level: 64 | # return [float(v) for v in line[1:]] 65 | 66 | # raise ValueError(f"No data found for {level}-bit compression") 67 | -------------------------------------------------------------------------------- /third_party/colbert/indexing/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import tqdm 4 | 5 | from third_party.colbert.indexing.loaders import load_doclens 6 | from third_party.colbert.utils.utils import print_message, flatten 7 | 8 | def optimize_ivf(orig_ivf, orig_ivf_lengths, index_path, verbose:int=3): 9 | if verbose > 1: 10 | print_message("#> Optimizing IVF to store map from centroids to list of pids..") 11 | 12 | print_message("#> Building the emb2pid mapping..") 13 | all_doclens = load_doclens(index_path, flatten=False) 14 | 15 | # assert self.num_embeddings == sum(flatten(all_doclens)) 16 | 17 | all_doclens = flatten(all_doclens) 18 | total_num_embeddings = sum(all_doclens) 19 | 20 | emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int) 21 | 22 | """ 23 | EVENTUALLY: Use two tensors. emb2pid_offsets will have every 256th element. 
24 | emb2pid_delta will have the delta from the corresponding offset, 25 | """ 26 | 27 | offset_doclens = 0 28 | for pid, dlength in enumerate(all_doclens): 29 | emb2pid[offset_doclens: offset_doclens + dlength] = pid 30 | offset_doclens += dlength 31 | 32 | if verbose > 1: 33 | print_message("len(emb2pid) =", len(emb2pid)) 34 | 35 | ivf = emb2pid[orig_ivf] 36 | unique_pids_per_centroid = [] 37 | ivf_lengths = [] 38 | 39 | offset = 0 40 | for length in tqdm.tqdm(orig_ivf_lengths.tolist()): 41 | pids = torch.unique(ivf[offset:offset+length]) 42 | unique_pids_per_centroid.append(pids) 43 | ivf_lengths.append(pids.shape[0]) 44 | offset += length 45 | ivf = torch.cat(unique_pids_per_centroid) 46 | ivf_lengths = torch.tensor(ivf_lengths) 47 | 48 | max_stride = ivf_lengths.max().item() 49 | zero = torch.zeros(1, dtype=torch.long, device=ivf_lengths.device) 50 | offsets = torch.cat((zero, torch.cumsum(ivf_lengths, dim=0))) 51 | inner_dims = ivf.size()[1:] 52 | 53 | if offsets[-2] + max_stride > ivf.size(0): 54 | padding = torch.zeros(max_stride, *inner_dims, dtype=ivf.dtype, device=ivf.device) 55 | ivf = torch.cat((ivf, padding)) 56 | 57 | original_ivf_path = os.path.join(index_path, 'ivf.pt') 58 | optimized_ivf_path = os.path.join(index_path, 'ivf.pid.pt') 59 | torch.save((ivf, ivf_lengths), optimized_ivf_path) 60 | if verbose > 1: 61 | print_message(f"#> Saved optimized IVF to {optimized_ivf_path}") 62 | if os.path.exists(original_ivf_path): 63 | print_message(f"#> Original IVF at path \"{original_ivf_path}\" can now be removed") 64 | 65 | return ivf, ivf_lengths 66 | 67 | -------------------------------------------------------------------------------- /third_party/colbert/infra/__init__.py: -------------------------------------------------------------------------------- 1 | from .run import * 2 | from .config import * -------------------------------------------------------------------------------- /third_party/colbert/infra/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | from .settings import * -------------------------------------------------------------------------------- /third_party/colbert/infra/config/base_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import ujson 4 | from huggingface_hub import hf_hub_download 5 | from huggingface_hub.utils import RepositoryNotFoundError 6 | import dataclasses 7 | 8 | from typing import Any 9 | from collections import defaultdict 10 | from dataclasses import dataclass, fields 11 | from third_party.colbert.utils.utils import timestamp, torch_load_dnn 12 | 13 | from third_party.utility.utils.save_metadata import get_metadata_only 14 | from .core_config import * 15 | 16 | 17 | @dataclass 18 | class BaseConfig(CoreConfig): 19 | @classmethod 20 | def from_existing(cls, *sources): 21 | kw_args = {} 22 | 23 | for source in sources: 24 | if source is None: 25 | continue 26 | 27 | local_kw_args = dataclasses.asdict(source) 28 | local_kw_args = {k: local_kw_args[k] for k in source.assigned} 29 | kw_args = {**kw_args, **local_kw_args} 30 | 31 | obj = cls(**kw_args) 32 | 33 | return obj 34 | 35 | @classmethod 36 | def from_deprecated_args(cls, args): 37 | obj = cls() 38 | ignored = obj.configure(ignore_unrecognized=True, **args) 39 | 40 | return obj, ignored 41 | 42 | @classmethod 43 | def from_path(cls, name): 44 | with open(name) as f: 45 | args = ujson.load(f) 46 | 47 | if "config" in args: 48 | args 
= args["config"] 49 | 50 | return cls.from_deprecated_args( 51 | args 52 | ) # the new, non-deprecated version functions the same at this level. 53 | 54 | @classmethod 55 | def load_from_checkpoint(cls, checkpoint_path): 56 | if checkpoint_path.endswith(".dnn"): 57 | dnn = torch_load_dnn(checkpoint_path) 58 | config, _ = cls.from_deprecated_args(dnn.get("arguments", {})) 59 | 60 | # TODO: FIXME: Decide if the line below will have any unintended consequences. We don't want to overwrite those! 61 | config.set("checkpoint", checkpoint_path) 62 | 63 | return config 64 | 65 | try: 66 | checkpoint_path = hf_hub_download( 67 | repo_id=checkpoint_path, filename="artifact.metadata" 68 | ).split("artifact")[0] 69 | except Exception: 70 | pass 71 | loaded_config_path = os.path.join(checkpoint_path, "artifact.metadata") 72 | if os.path.exists(loaded_config_path): 73 | loaded_config, _ = cls.from_path(loaded_config_path) 74 | loaded_config.set("checkpoint", checkpoint_path) 75 | 76 | return loaded_config 77 | 78 | return ( 79 | None # can happen if checkpoint_path is something like 'bert-base-uncased' 80 | ) 81 | 82 | @classmethod 83 | def load_from_index(cls, index_path): 84 | # FIXME: We should start here with initial_config = ColBERTConfig(config, Run().config). 85 | # This should allow us to say initial_config.index_root. Then, below, set config = Config(..., initial_c) 86 | 87 | # default_index_root = os.path.join(Run().root, Run().experiment, 'indexes/') 88 | # index_path = os.path.join(default_index_root, index_path) 89 | 90 | # CONSIDER: No more plan/metadata.json. Only metadata.json to avoid weird issues when loading an index. 91 | 92 | try: 93 | metadata_path = os.path.join(index_path, "metadata.json") 94 | loaded_config, _ = cls.from_path(metadata_path) 95 | except: 96 | metadata_path = os.path.join(index_path, "plan.json") 97 | loaded_config, _ = cls.from_path(metadata_path) 98 | 99 | return loaded_config 100 | 101 | def save(self, path, overwrite=False): 102 | assert overwrite or not os.path.exists(path), path 103 | 104 | with open(path, "w") as f: 105 | args = self.export() # dict(self.__config) 106 | args["meta"] = get_metadata_only() 107 | args["meta"]["version"] = "colbert-v0.4" 108 | # TODO: Add git_status details.. It can't be too large! It should be a path that Runs() saves on exit, maybe! 109 | 110 | f.write(ujson.dumps(args, indent=4) + "\n") 111 | 112 | def save_for_checkpoint(self, checkpoint_path): 113 | assert not checkpoint_path.endswith( 114 | ".dnn" 115 | ), f"{checkpoint_path}: We reserve *.dnn names for the deprecated checkpoint format." 
116 | 117 | output_config_path = os.path.join(checkpoint_path, "artifact.metadata") 118 | self.save(output_config_path, overwrite=True) 119 | -------------------------------------------------------------------------------- /third_party/colbert/infra/config/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from .base_config import BaseConfig 3 | from .settings import * 4 | from .core_config import DefaultVal 5 | 6 | 7 | @dataclass 8 | class RunConfig(BaseConfig, RunSettings): 9 | pass 10 | 11 | 12 | @dataclass 13 | class ColBERTConfig(RunSettings, ResourceSettings, DocSettings, QuerySettings, TrainingSettings, 14 | IndexingSettings, SearchSettings, BaseConfig, TokenizerSettings): 15 | checkpoint_path: str = DefaultVal(None) 16 | -------------------------------------------------------------------------------- /third_party/colbert/infra/config/core_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import ujson 4 | import dataclasses 5 | 6 | from typing import Any 7 | from collections import defaultdict 8 | from dataclasses import dataclass, fields 9 | from third_party.colbert.utils.utils import timestamp, torch_load_dnn 10 | 11 | from third_party.utility.utils.save_metadata import get_metadata_only 12 | 13 | 14 | @dataclass 15 | class DefaultVal: 16 | val: Any 17 | 18 | def __hash__(self): 19 | return hash(repr(self.val)) 20 | 21 | def __eq__(self, other): 22 | self.val == other.val 23 | 24 | @dataclass 25 | class CoreConfig: 26 | def __post_init__(self): 27 | """ 28 | Source: https://stackoverflow.com/a/58081120/1493011 29 | """ 30 | 31 | self.assigned = {} 32 | 33 | for field in fields(self): 34 | field_val = getattr(self, field.name) 35 | 36 | if isinstance(field_val, DefaultVal) or field_val is None: 37 | setattr(self, field.name, field.default.val) 38 | 39 | if not isinstance(field_val, DefaultVal): 40 | self.assigned[field.name] = True 41 | 42 | def assign_defaults(self): 43 | for field in fields(self): 44 | setattr(self, field.name, field.default.val) 45 | self.assigned[field.name] = True 46 | 47 | def configure(self, ignore_unrecognized=True, **kw_args): 48 | ignored = set() 49 | 50 | for key, value in kw_args.items(): 51 | self.set(key, value, ignore_unrecognized) or ignored.update({key}) 52 | 53 | return ignored 54 | 55 | """ 56 | # TODO: Take a config object, not kw_args. 
57 | 58 | for key in config.assigned: 59 | value = getattr(config, key) 60 | """ 61 | 62 | def set(self, key, value, ignore_unrecognized=False): 63 | if hasattr(self, key): 64 | setattr(self, key, value) 65 | self.assigned[key] = True 66 | return True 67 | 68 | if not ignore_unrecognized: 69 | raise Exception(f"Unrecognized key `{key}` for {type(self)}") 70 | 71 | def help(self): 72 | print(ujson.dumps(self.export(), indent=4)) 73 | 74 | def __export_value(self, v): 75 | v = v.provenance() if hasattr(v, 'provenance') else v 76 | 77 | if isinstance(v, list) and len(v) > 100: 78 | v = (f"list with {len(v)} elements starting with...", v[:3]) 79 | 80 | if isinstance(v, dict) and len(v) > 100: 81 | v = (f"dict with {len(v)} keys starting with...", list(v.keys())[:3]) 82 | 83 | return v 84 | 85 | def export(self): 86 | d = dataclasses.asdict(self) 87 | 88 | for k, v in d.items(): 89 | d[k] = self.__export_value(v) 90 | 91 | return d 92 | -------------------------------------------------------------------------------- /third_party/colbert/infra/provenance.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import traceback 3 | import inspect 4 | 5 | 6 | class Provenance: 7 | def __init__(self) -> None: 8 | self.initial_stacktrace = self.stacktrace() 9 | 10 | def stacktrace(self): 11 | trace = inspect.stack() 12 | output = [] 13 | 14 | for frame in trace[2:-1]: 15 | try: 16 | frame = f'{frame.filename}:{frame.lineno}:{frame.function}: {frame.code_context[0].strip()}' 17 | output.append(frame) 18 | except: 19 | output.append(None) 20 | 21 | return output 22 | 23 | def toDict(self): # for ujson 24 | self.serialization_stacktrace = self.stacktrace() 25 | return dict(self.__dict__) 26 | 27 | 28 | if __name__ == '__main__': 29 | p = Provenance() 30 | print(p.toDict().keys()) 31 | 32 | import ujson 33 | print(ujson.dumps(p, indent=4)) 34 | 35 | 36 | class X: 37 | def __init__(self) -> None: 38 | pass 39 | 40 | def toDict(self): 41 | return {'key': 1} 42 | 43 | print(ujson.dumps(X())) -------------------------------------------------------------------------------- /third_party/colbert/infra/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import atexit 3 | 4 | from third_party.colbert.utils.utils import create_directory, print_message, timestamp 5 | from contextlib import contextmanager 6 | 7 | from third_party.colbert.infra.config import RunConfig 8 | 9 | 10 | class Run(object): 11 | _instance = None 12 | 13 | os.environ["TOKENIZERS_PARALLELISM"] = "true" # NOTE: If a deadlock arises, switch to false!! 14 | 15 | def __new__(cls): 16 | """ 17 | Singleton Pattern. See https://python-patterns.guide/gang-of-four/singleton/ 18 | """ 19 | if cls._instance is None: 20 | cls._instance = super().__new__(cls) 21 | cls._instance.stack = [] 22 | 23 | # TODO: Save a timestamp here! And re-use it! But allow the user to override it on calling Run().context a second time. 
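            # Typical use of the singleton from calling code (a sketch; the experiment
            # name and nranks are illustrative):
            #
            #     with Run().context(RunConfig(nranks=1, experiment='m2kr')):
            #         Run().print_main('active config:', Run().config)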
24 | run_config = RunConfig() 25 | run_config.assign_defaults() 26 | 27 | cls._instance.__append(run_config) 28 | 29 | # TODO: atexit.register(all_done) 30 | 31 | return cls._instance 32 | 33 | @property 34 | def config(self): 35 | return self.stack[-1] 36 | 37 | def __getattr__(self, name): 38 | if hasattr(self.config, name): 39 | return getattr(self.config, name) 40 | 41 | super().__getattr__(name) 42 | 43 | def __append(self, runconfig: RunConfig): 44 | # runconfig.disallow_writes(readonly=True) 45 | self.stack.append(runconfig) 46 | 47 | def __pop(self): 48 | self.stack.pop() 49 | 50 | @contextmanager 51 | def context(self, runconfig: RunConfig, inherit_config=True): 52 | if inherit_config: 53 | runconfig = RunConfig.from_existing(self.config, runconfig) 54 | 55 | self.__append(runconfig) 56 | 57 | try: 58 | yield 59 | finally: 60 | self.__pop() 61 | 62 | def open(self, path, mode='r'): 63 | path = os.path.join(self.path_, path) 64 | 65 | if not os.path.exists(self.path_): 66 | create_directory(self.path_) 67 | 68 | if ('w' in mode or 'a' in mode) and not self.overwrite: 69 | assert not os.path.exists(path), (self.overwrite, path) 70 | 71 | # create directory if it doesn't exist 72 | os.makedirs(os.path.dirname(path), exist_ok=True) 73 | 74 | return open(path, mode=mode) 75 | 76 | def print(self, *args): 77 | print_message("[" + str(self.rank) + "]", "\t\t", *args) 78 | 79 | def print_main(self, *args): 80 | if self.rank == 0: 81 | self.print(*args) 82 | 83 | 84 | if __name__ == '__main__': 85 | print(Run().root, '!') 86 | 87 | with Run().context(RunConfig(rank=0, nranks=1)): 88 | with Run().context(RunConfig(experiment='newproject')): 89 | print(Run().nranks, '!') 90 | 91 | print(Run().config, '!') 92 | print(Run().rank) 93 | 94 | 95 | # TODO: Handle logging all prints to a file. There should be a way to determine the level of logs that go to stdout. -------------------------------------------------------------------------------- /third_party/colbert/infra/utilities/annotate_em.py: -------------------------------------------------------------------------------- 1 | 2 | from third_party.colbert.infra.run import Run 3 | from third_party.colbert.data.collection import Collection 4 | import os 5 | import sys 6 | import git 7 | import tqdm 8 | import ujson 9 | import random 10 | 11 | from argparse import ArgumentParser 12 | from multiprocessing import Pool 13 | 14 | from third_party.colbert.utils.utils import groupby_first_item, print_message 15 | from third_party.utility.utils.qa_loaders import load_qas_, load_collection_ 16 | from third_party.utility.utils.save_metadata import format_metadata, get_metadata 17 | from third_party.utility.evaluate.annotate_EM_helpers import * 18 | 19 | from third_party.colbert.data.ranking import Ranking 20 | 21 | 22 | class AnnotateEM: 23 | def __init__(self, collection, qas): 24 | # TODO: These should just be Queries! But Queries needs to support looking up answers as qid2answers below. 
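        # `qas` is assumed to be a path to an NQ-style qas file with one JSON object per
        # line, each carrying at least qid, question, and answers, e.g.
        #
        #     {"qid": 0, "question": "who wrote hamlet", "answers": ["William Shakespeare"]}
        #
        # which load_qas_ below turns into (qid, question, answers) tuples.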
25 | qas = load_qas_(qas) 26 | collection = Collection.cast(collection) # .tolist() #load_collection_(collection, retain_titles=True) 27 | 28 | self.parallel_pool = Pool(30) 29 | 30 | print_message('#> Tokenize the answers in the Q&As in parallel...') 31 | qas = list(self.parallel_pool.map(tokenize_all_answers, qas)) 32 | 33 | qid2answers = {qid: tok_answers for qid, _, tok_answers in qas} 34 | assert len(qas) == len(qid2answers), (len(qas), len(qid2answers)) 35 | 36 | self.qas, self.collection = qas, collection 37 | self.qid2answers = qid2answers 38 | 39 | def annotate(self, ranking): 40 | rankings = Ranking.cast(ranking) 41 | 42 | # print(len(rankings), rankings[0]) 43 | 44 | print_message('#> Lookup passages from PIDs...') 45 | expanded_rankings = [(qid, pid, rank, self.collection[pid], self.qid2answers[qid]) 46 | for qid, pid, rank, *_ in rankings.tolist()] 47 | 48 | print_message('#> Assign labels in parallel...') 49 | labeled_rankings = list(self.parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings))) 50 | 51 | # Dump output. 52 | self.qid2rankings = groupby_first_item(labeled_rankings) 53 | 54 | self.num_judged_queries, self.num_ranked_queries = check_sizes(self.qid2answers, self.qid2rankings) 55 | 56 | # Evaluation metrics and depths. 57 | self.success, self.counts = self._compute_labels(self.qid2answers, self.qid2rankings) 58 | 59 | print(rankings.provenance(), self.success) 60 | 61 | return Ranking(data=self.qid2rankings, provenance=("AnnotateEM", rankings.provenance())) 62 | 63 | def _compute_labels(self, qid2answers, qid2rankings): 64 | cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all'] 65 | success = {cutoff: 0.0 for cutoff in cutoffs} 66 | counts = {cutoff: 0.0 for cutoff in cutoffs} 67 | 68 | for qid in qid2answers: 69 | if qid not in qid2rankings: 70 | continue 71 | 72 | prev_rank = 0 # ranks should start at one (i.e., and not zero) 73 | labels = [] 74 | 75 | for pid, rank, label in qid2rankings[qid]: 76 | assert rank == prev_rank+1, (qid, pid, (prev_rank, rank)) 77 | prev_rank = rank 78 | 79 | labels.append(label) 80 | 81 | for cutoff in cutoffs: 82 | if cutoff != 'all': 83 | success[cutoff] += sum(labels[:cutoff]) > 0 84 | counts[cutoff] += sum(labels[:cutoff]) 85 | else: 86 | success[cutoff] += sum(labels) > 0 87 | counts[cutoff] += sum(labels) 88 | 89 | return success, counts 90 | 91 | def save(self, new_path): 92 | print_message("#> Dumping output to", new_path, "...") 93 | 94 | Ranking(data=self.qid2rankings).save(new_path) 95 | 96 | # Dump metrics. 97 | with Run().open(f'{new_path}.metrics', 'w') as f: 98 | d = {'num_ranked_queries': self.num_ranked_queries, 'num_judged_queries': self.num_judged_queries} 99 | 100 | extra = '__WARNING' if self.num_judged_queries != self.num_ranked_queries else '' 101 | d[f'success{extra}'] = {k: v / self.num_judged_queries for k, v in self.success.items()} 102 | d[f'counts{extra}'] = {k: v / self.num_judged_queries for k, v in self.counts.items()} 103 | # d['arguments'] = get_metadata(args) # TODO: Need arguments... 
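            # The dumped metrics are keyed by cutoff: success[k] is the fraction of judged
            # queries with at least one exact-match passage in the top-k, and counts[k] is
            # the mean number of matches in the top-k (see _compute_labels above).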
104 | 105 | f.write(format_metadata(d) + '\n') 106 | 107 | 108 | if __name__ == '__main__': 109 | r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.50.02/ranking.tsv' 110 | r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.59.37/ranking.tsv' 111 | r = sys.argv[1] 112 | 113 | a = AnnotateEM(collection='/future/u/okhattab/root/unit/data/NQ-mini/collection.tsv', 114 | qas='/future/u/okhattab/root/unit/data/NQ-mini/dev/qas.json') 115 | a.annotate(ranking=r) 116 | -------------------------------------------------------------------------------- /third_party/colbert/infra/utilities/create_triples.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from third_party.colbert.utils.utils import print_message 4 | from third_party.utility.utils.save_metadata import save_metadata 5 | from third_party.utility.supervision.triples import sample_for_query 6 | 7 | from third_party.colbert.data.ranking import Ranking 8 | from third_party.colbert.data.examples import Examples 9 | 10 | MAX_NUM_TRIPLES = 40_000_000 11 | 12 | 13 | class Triples: 14 | def __init__(self, ranking, seed=12345): 15 | random.seed(seed) # TODO: Use internal RNG instead.. 16 | self.qid2rankings = Ranking.cast(ranking).todict() 17 | 18 | def create(self, positives, depth): 19 | assert all(len(x) == 2 for x in positives) 20 | assert all(maxBest <= maxDepth for maxBest, maxDepth in positives), positives 21 | 22 | Triples = [] 23 | NonEmptyQIDs = 0 24 | 25 | for processing_idx, qid in enumerate(self.qid2rankings): 26 | l = sample_for_query(qid, self.qid2rankings[qid], positives, depth, False, None) 27 | NonEmptyQIDs += (len(l) > 0) 28 | Triples.extend(l) 29 | 30 | if processing_idx % (10_000) == 0: 31 | print_message(f"#> Done with {processing_idx+1} questions!\t\t " 32 | f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unqiue QIDs.") 33 | 34 | print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..") 35 | print_message(f"#> len(Triples) = {len(Triples)}") 36 | 37 | if len(Triples) > MAX_NUM_TRIPLES: 38 | Triples = random.sample(Triples, MAX_NUM_TRIPLES) 39 | 40 | ### Prepare the triples ### 41 | print_message("#> Shuffling the triples...") 42 | random.shuffle(Triples) 43 | 44 | self.Triples = Examples(data=Triples) 45 | 46 | return Triples 47 | 48 | def save(self, new_path): 49 | Examples(data=self.Triples).save(new_path) 50 | 51 | # save_metadata(f'{output}.meta', args) # TODO: What args to save?? 
{seed, positives, depth, rankings if path or else whatever provenance the rankings object shares} 52 | 53 | -------------------------------------------------------------------------------- /third_party/colbert/infra/utilities/minicorpus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from third_party.colbert.utils.utils import create_directory 5 | 6 | from third_party.colbert.data import Collection, Queries, Ranking 7 | 8 | 9 | def sample_minicorpus(name, factor, topk=30, maxdev=3000): 10 | """ 11 | Factor: 12 | * nano=1 13 | * micro=10 14 | * mini=100 15 | * small=100 with topk=100 16 | * medium=150 with topk=300 17 | """ 18 | 19 | random.seed(12345) 20 | 21 | # Load collection 22 | collection = Collection(path='/dfs/scratch0/okhattab/OpenQA/collection.tsv') 23 | 24 | # Load train and dev queries 25 | qas_train = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/qas.json').qas() 26 | qas_dev = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/qas.json').qas() 27 | 28 | # Load train and dev C3 rankings 29 | ranking_train = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/rankings/C3.tsv.annotated').todict() 30 | ranking_dev = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/rankings/C3.tsv.annotated').todict() 31 | 32 | # Sample NT and ND queries from each, keep only the top-k passages for those 33 | sample_train = random.sample(list(qas_train.keys()), min(len(qas_train.keys()), 300*factor)) 34 | sample_dev = random.sample(list(qas_dev.keys()), min(len(qas_dev.keys()), maxdev, 30*factor)) 35 | 36 | train_pids = [pid for qid in sample_train for qpids in ranking_train[qid][:topk] for pid in qpids] 37 | dev_pids = [pid for qid in sample_dev for qpids in ranking_dev[qid][:topk] for pid in qpids] 38 | 39 | sample_pids = sorted(list(set(train_pids + dev_pids))) 40 | print(f'len(sample_pids) = {len(sample_pids)}') 41 | 42 | # Save the new query sets: train and dev 43 | ROOT = f'/future/u/okhattab/root/unit/data/NQ-{name}' 44 | 45 | create_directory(os.path.join(ROOT, 'train')) 46 | create_directory(os.path.join(ROOT, 'dev')) 47 | 48 | new_train = Queries(data={qid: qas_train[qid] for qid in sample_train}) 49 | new_train.save(os.path.join(ROOT, 'train/questions.tsv')) 50 | new_train.save_qas(os.path.join(ROOT, 'train/qas.json')) 51 | 52 | new_dev = Queries(data={qid: qas_dev[qid] for qid in sample_dev}) 53 | new_dev.save(os.path.join(ROOT, 'dev/questions.tsv')) 54 | new_dev.save_qas(os.path.join(ROOT, 'dev/qas.json')) 55 | 56 | # Save the new collection 57 | print(f"Saving to {os.path.join(ROOT, 'collection.tsv')}") 58 | Collection(data=[collection[pid] for pid in sample_pids]).save(os.path.join(ROOT, 'collection.tsv')) 59 | 60 | print('#> Done!') 61 | 62 | 63 | if __name__ == '__main__': 64 | sample_minicorpus('medium', 150, topk=300) 65 | -------------------------------------------------------------------------------- /third_party/colbert/modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/modeling/__init__.py -------------------------------------------------------------------------------- /third_party/colbert/modeling/base_colbert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | 5 | from third_party.colbert.utils.utils import torch_load_dnn 6 | 7 | from transformers 
import AutoTokenizer 8 | from third_party.colbert.modeling.hf_colbert import class_factory 9 | from third_party.colbert.infra.config import ColBERTConfig 10 | from third_party.colbert.parameters import DEVICE 11 | 12 | 13 | class BaseColBERT(torch.nn.Module): 14 | """ 15 | Shallow module that wraps the ColBERT parameters, custom configuration, and underlying tokenizer. 16 | This class provides direct instantiation and saving of the model/colbert_config/tokenizer package. 17 | 18 | Like HF, evaluation mode is the default. 19 | """ 20 | 21 | def __init__(self, name_or_path, colbert_config=None): 22 | super().__init__() 23 | 24 | self.colbert_config = ColBERTConfig.from_existing(ColBERTConfig.load_from_checkpoint(name_or_path), colbert_config) 25 | self.name = self.colbert_config.model_name or name_or_path 26 | 27 | try: 28 | HF_ColBERT = class_factory(self.name) 29 | except: 30 | self.name = 'bert-base-uncased' # TODO: Double check that this is appropriate here in all cases 31 | HF_ColBERT = class_factory(self.name) 32 | 33 | # assert self.name is not None 34 | # HF_ColBERT = class_factory(self.name) 35 | 36 | self.model = HF_ColBERT.from_pretrained(name_or_path, colbert_config=self.colbert_config) 37 | self.model.to(DEVICE) 38 | self.raw_tokenizer = AutoTokenizer.from_pretrained(name_or_path) 39 | 40 | self.eval() 41 | 42 | @property 43 | def device(self): 44 | return self.model.device 45 | 46 | @property 47 | def bert(self): 48 | return self.model.LM 49 | 50 | @property 51 | def linear(self): 52 | return self.model.linear 53 | 54 | @property 55 | def score_scaler(self): 56 | return self.model.score_scaler 57 | 58 | def save(self, path): 59 | assert not path.endswith('.dnn'), f"{path}: We reserve *.dnn names for the deprecated checkpoint format." 60 | 61 | self.model.save_pretrained(path) 62 | self.raw_tokenizer.save_pretrained(path) 63 | 64 | self.colbert_config.save_for_checkpoint(path) 65 | 66 | 67 | if __name__ == '__main__': 68 | import random 69 | import numpy as np 70 | 71 | from third_party.colbert.infra.run import Run 72 | from third_party.colbert.infra.config import RunConfig 73 | 74 | random.seed(12345) 75 | np.random.seed(12345) 76 | torch.manual_seed(12345) 77 | 78 | with Run().context(RunConfig(gpus=2)): 79 | m = BaseColBERT('bert-base-uncased', colbert_config=ColBERTConfig(Run().config, doc_maxlen=300, similarity='l2')) 80 | m.colbert_config.help() 81 | print(m.linear.weight) 82 | m.save('/future/u/okhattab/tmp/2021/08/model.deleteme2/') 83 | 84 | m2 = BaseColBERT('/future/u/okhattab/tmp/2021/08/model.deleteme2/') 85 | m2.colbert_config.help() 86 | print(m2.linear.weight) 87 | 88 | exit() 89 | 90 | m = BaseColBERT('/future/u/okhattab/tmp/2021/08/model.deleteme/') 91 | print('BaseColBERT', m.linear.weight) 92 | print('BaseColBERT', m.colbert_config) 93 | 94 | exit() 95 | 96 | # m = HF_ColBERT.from_pretrained('nreimers/MiniLMv2-L6-H768-distilled-from-BERT-Large') 97 | m = HF_ColBERT.from_pretrained('/future/u/okhattab/tmp/2021/08/model.deleteme/') 98 | print('HF_ColBERT', m.linear.weight) 99 | 100 | m.save_pretrained('/future/u/okhattab/tmp/2021/08/model.deleteme/') 101 | 102 | # old = OldColBERT.from_pretrained('bert-base-uncased') 103 | # print(old.bert.encoder.layer[10].attention.self.value.weight) 104 | 105 | # random.seed(12345) 106 | # np.random.seed(12345) 107 | # torch.manual_seed(12345) 108 | 109 | dnn = torch_load_dnn( 110 | "/future/u/okhattab/root/TACL21/experiments/Feb26.NQ/train.py/ColBERT.C3/checkpoints/colbert-60000.dnn") 111 | # base = dnn.get('arguments', 
{}).get('model', 'bert-base-uncased') 112 | 113 | # new = BaseColBERT.from_pretrained('bert-base-uncased', state_dict=dnn['model_state_dict']) 114 | 115 | # print(new.bert.encoder.layer[10].attention.self.value.weight) 116 | 117 | print(dnn['model_state_dict']['linear.weight']) 118 | # print(dnn['model_state_dict']['bert.encoder.layer.10.attention.self.value.weight']) 119 | 120 | # # base_model_prefix 121 | -------------------------------------------------------------------------------- /third_party/colbert/modeling/reranker/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/modeling/reranker/__init__.py -------------------------------------------------------------------------------- /third_party/colbert/modeling/reranker/electra.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from transformers import ElectraPreTrainedModel, ElectraModel, AutoTokenizer 4 | 5 | class ElectraReranker(ElectraPreTrainedModel): 6 | """ 7 | Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level. 8 | 9 | This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly. 10 | """ 11 | _keys_to_ignore_on_load_unexpected = [r"cls"] 12 | 13 | def __init__(self, config): 14 | super().__init__(config) 15 | 16 | self.electra = ElectraModel(config) 17 | self.linear = nn.Linear(config.hidden_size, 1) 18 | self.raw_tokenizer = AutoTokenizer.from_pretrained('google/electra-large-discriminator') 19 | 20 | self.init_weights() 21 | 22 | def forward(self, encoding): 23 | outputs = self.electra(encoding.input_ids, 24 | attention_mask=encoding.attention_mask, 25 | token_type_ids=encoding.token_type_ids)[0] 26 | 27 | scores = self.linear(outputs[:, 0]).squeeze(-1) 28 | 29 | return scores 30 | 31 | def save(self, path): 32 | assert not path.endswith('.dnn'), f"{path}: We reserve *.dnn names for the deprecated checkpoint format." 
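        # For context, a usage sketch of this reranker (names are illustrative):
        #
        #     tok = RerankerTokenizer(total_maxlen=512, base='google/electra-large-discriminator')
        #     encoding = tok.tensorize(questions, passages)
        #     scores = reranker(encoding)   # one relevance score per (question, passage) pair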
33 | 34 | self.save_pretrained(path) 35 | self.raw_tokenizer.save_pretrained(path) -------------------------------------------------------------------------------- /third_party/colbert/modeling/reranker/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer 2 | 3 | class RerankerTokenizer(): 4 | def __init__(self, total_maxlen, base): 5 | self.total_maxlen = total_maxlen 6 | self.tok = AutoTokenizer.from_pretrained(base) 7 | 8 | def tensorize(self, questions, passages): 9 | assert type(questions) in [list, tuple], type(questions) 10 | assert type(passages) in [list, tuple], type(passages) 11 | 12 | encoding = self.tok(questions, passages, padding='longest', truncation='longest_first', 13 | return_tensors='pt', max_length=self.total_maxlen, add_special_tokens=True) 14 | 15 | return encoding 16 | -------------------------------------------------------------------------------- /third_party/colbert/modeling/segmented_maxsim.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | typedef struct { 8 | int tid; 9 | int nthreads; 10 | 11 | int ndocs; 12 | int ndoc_vectors; 13 | int nquery_vectors; 14 | 15 | int64_t* lengths; 16 | float* scores; 17 | int64_t* offsets; 18 | 19 | float* max_scores; 20 | } max_args_t; 21 | 22 | void* max(void* args) { 23 | max_args_t* max_args = (max_args_t*)args; 24 | 25 | int ndocs_per_thread = 26 | std::ceil(((float)max_args->ndocs) / max_args->nthreads); 27 | int start = max_args->tid * ndocs_per_thread; 28 | int end = std::min((max_args->tid + 1) * ndocs_per_thread, max_args->ndocs); 29 | 30 | auto max_scores_offset = 31 | max_args->max_scores + (start * max_args->nquery_vectors); 32 | auto scores_offset = 33 | max_args->scores + (max_args->offsets[start] * max_args->nquery_vectors); 34 | 35 | for (int i = start; i < end; i++) { 36 | for (int j = 0; j < max_args->lengths[i]; j++) { 37 | std::transform(max_scores_offset, 38 | max_scores_offset + max_args->nquery_vectors, 39 | scores_offset, max_scores_offset, 40 | [](float a, float b) { return std::max(a, b); }); 41 | scores_offset += max_args->nquery_vectors; 42 | } 43 | max_scores_offset += max_args->nquery_vectors; 44 | } 45 | 46 | return NULL; 47 | } 48 | 49 | torch::Tensor segmented_maxsim(const torch::Tensor scores, 50 | const torch::Tensor lengths) { 51 | auto lengths_a = lengths.data_ptr(); 52 | auto scores_a = scores.data_ptr(); 53 | auto ndocs = lengths.size(0); 54 | auto ndoc_vectors = scores.size(0); 55 | auto nquery_vectors = scores.size(1); 56 | auto nthreads = at::get_num_threads(); 57 | 58 | torch::Tensor max_scores = 59 | torch::zeros({ndocs, nquery_vectors}, scores.options()); 60 | 61 | int64_t offsets[ndocs + 1]; 62 | offsets[0] = 0; 63 | std::partial_sum(lengths_a, lengths_a + ndocs, offsets + 1); 64 | 65 | pthread_t threads[nthreads]; 66 | max_args_t args[nthreads]; 67 | 68 | for (int i = 0; i < nthreads; i++) { 69 | args[i].tid = i; 70 | args[i].nthreads = nthreads; 71 | 72 | args[i].ndocs = ndocs; 73 | args[i].ndoc_vectors = ndoc_vectors; 74 | args[i].nquery_vectors = nquery_vectors; 75 | 76 | args[i].lengths = lengths_a; 77 | args[i].scores = scores_a; 78 | args[i].offsets = offsets; 79 | 80 | args[i].max_scores = max_scores.data_ptr(); 81 | 82 | int rc = pthread_create(&threads[i], NULL, max, (void*)&args[i]); 83 | if (rc) { 84 | fprintf(stderr, "Unable to create thread %d: %d\n", i, rc); 85 | } 86 | } 87 | 88 | for 
(int i = 0; i < nthreads; i++) { 89 | pthread_join(threads[i], NULL); 90 | } 91 | 92 | return max_scores.sum(1); 93 | } 94 | 95 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 96 | m.def("segmented_maxsim_cpp", &segmented_maxsim, "Segmented MaxSim"); 97 | } 98 | -------------------------------------------------------------------------------- /third_party/colbert/modeling/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | from third_party.colbert.modeling.tokenization.query_tokenization import * 2 | from third_party.colbert.modeling.tokenization.doc_tokenization import * 3 | from third_party.colbert.modeling.tokenization.utils import tensorize_triples 4 | -------------------------------------------------------------------------------- /third_party/colbert/modeling/tokenization/doc_tokenization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # from transformers import BertTokenizerFast 4 | 5 | from third_party.colbert.modeling.hf_colbert import class_factory 6 | from third_party.colbert.infra import ColBERTConfig 7 | from third_party.colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length, _insert_prefix_token 8 | from third_party.colbert.parameters import DEVICE 9 | 10 | class DocTokenizer(): 11 | def __init__(self, config: ColBERTConfig): 12 | HF_ColBERT = class_factory(config.checkpoint) 13 | self.tok = HF_ColBERT.raw_tokenizer_from_pretrained(config.checkpoint) 14 | 15 | self.config = config 16 | self.doc_maxlen = config.doc_maxlen 17 | 18 | self.D_marker_token, self.D_marker_token_id = self.config.doc_token, self.tok.convert_tokens_to_ids(self.config.doc_token_id) 19 | self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id 20 | self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id 21 | 22 | def tokenize(self, batch_text, add_special_tokens=False): 23 | assert type(batch_text) in [list, tuple], (type(batch_text)) 24 | 25 | tokens = [self.tok.tokenize(x, add_special_tokens=False).to(DEVICE) for x in batch_text] 26 | 27 | if not add_special_tokens: 28 | return tokens 29 | 30 | prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token] 31 | tokens = [prefix + lst + suffix for lst in tokens] 32 | 33 | return tokens 34 | 35 | def encode(self, batch_text, add_special_tokens=False): 36 | assert type(batch_text) in [list, tuple], (type(batch_text)) 37 | 38 | ids = self.tok(batch_text, add_special_tokens=False).to(DEVICE)['input_ids'] 39 | 40 | if not add_special_tokens: 41 | return ids 42 | 43 | prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id] 44 | ids = [prefix + lst + suffix for lst in ids] 45 | 46 | return ids 47 | 48 | def tensorize(self, batch_text, bsize=None): 49 | assert type(batch_text) in [list, tuple], (type(batch_text)) 50 | 51 | obj = self.tok(batch_text, padding='longest', truncation='longest_first', 52 | return_tensors='pt', max_length=(self.doc_maxlen - 1)).to(DEVICE) 53 | 54 | ids = _insert_prefix_token(obj['input_ids'], self.D_marker_token_id) 55 | mask = _insert_prefix_token(obj['attention_mask'], 1) 56 | 57 | if bsize: 58 | ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize) 59 | batches = _split_into_batches(ids, mask, bsize) 60 | return batches, reverse_indices 61 | 62 | return ids, mask 63 | -------------------------------------------------------------------------------- /third_party/colbert/modeling/tokenization/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def tensorize_triples(query_tokenizer, doc_tokenizer, queries, passages, scores, bsize, nway): 5 | # assert len(passages) == len(scores) == bsize * nway 6 | # assert bsize is None or len(queries) % bsize == 0 7 | 8 | # N = len(queries) 9 | Q_ids, Q_mask = query_tokenizer.tensorize(queries) 10 | D_ids, D_mask = doc_tokenizer.tensorize(passages) 11 | # D_ids, D_mask = D_ids.view(2, N, -1), D_mask.view(2, N, -1) 12 | 13 | # # Compute max among {length of i^th positive, length of i^th negative} for i \in N 14 | # maxlens = D_mask.sum(-1).max(0).values 15 | 16 | # # Sort by maxlens 17 | # indices = maxlens.sort().indices 18 | # Q_ids, Q_mask = Q_ids[indices], Q_mask[indices] 19 | # D_ids, D_mask = D_ids[:, indices], D_mask[:, indices] 20 | 21 | # (positive_ids, negative_ids), (positive_mask, negative_mask) = D_ids, D_mask 22 | 23 | query_batches = _split_into_batches(Q_ids, Q_mask, bsize) 24 | doc_batches = _split_into_batches(D_ids, D_mask, bsize * nway) 25 | # positive_batches = _split_into_batches(positive_ids, positive_mask, bsize) 26 | # negative_batches = _split_into_batches(negative_ids, negative_mask, bsize) 27 | 28 | if len(scores): 29 | score_batches = _split_into_batches2(scores, bsize * nway) 30 | else: 31 | score_batches = [[] for _ in doc_batches] 32 | 33 | batches = [] 34 | for Q, D, S in zip(query_batches, doc_batches, score_batches): 35 | batches.append((Q, D, S)) 36 | 37 | return batches 38 | 39 | 40 | def _sort_by_length(ids, mask, bsize): 41 | if ids.size(0) <= bsize: 42 | return ids, mask, torch.arange(ids.size(0)) 43 | 44 | indices = mask.sum(-1).sort().indices 45 | reverse_indices = indices.sort().indices 46 | 47 | return ids[indices], mask[indices], reverse_indices 48 | 49 | 50 | def _split_into_batches(ids, mask, bsize): 51 | batches = [] 52 | for offset in range(0, ids.size(0), bsize): 53 | batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize])) 54 | 55 | return batches 56 | 57 | 58 | def _split_into_batches2(scores, bsize): 59 | batches = [] 60 | for offset in range(0, len(scores), bsize): 61 | batches.append(scores[offset:offset+bsize]) 62 | 63 | return batches 64 | 65 | def _insert_prefix_token(tensor: torch.Tensor, prefix_id: int): 66 | prefix_tensor = torch.full( 67 | (tensor.size(0), 1), 68 | prefix_id, 69 | dtype=tensor.dtype, 70 | device=tensor.device, 71 | ) 72 | return torch.cat([tensor[:, :1], prefix_tensor, tensor[:, 1:]], dim=1) 73 | -------------------------------------------------------------------------------- /third_party/colbert/parameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 4 | 5 | SAVED_CHECKPOINTS = [32*1000, 100*1000, 150*1000, 200*1000, 250*1000, 300*1000, 400*1000] 6 | SAVED_CHECKPOINTS += [10*1000, 20*1000, 30*1000, 40*1000, 50*1000, 60*1000, 70*1000, 80*1000, 90*1000] 7 | SAVED_CHECKPOINTS += [25*1000, 50*1000, 75*1000] 8 | 9 | SAVED_CHECKPOINTS = set(SAVED_CHECKPOINTS) 10 | 11 | 12 | # TODO: final_ckpt 2k, 5k, 10k 20k, 50k, 100k 150k 200k, 500k, 1M 2M, 5M, 10M -------------------------------------------------------------------------------- /third_party/colbert/ranking/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/ranking/__init__.py -------------------------------------------------------------------------------- /third_party/colbert/search/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/search/__init__.py -------------------------------------------------------------------------------- /third_party/colbert/search/candidate_generation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from third_party.colbert.search.strided_tensor import StridedTensor 4 | from .strided_tensor_core import _create_mask, _create_view 5 | 6 | 7 | class CandidateGeneration: 8 | 9 | def __init__(self, use_gpu=True): 10 | self.use_gpu = use_gpu 11 | 12 | def get_cells(self, Q, ncells): 13 | scores = (self.codec.centroids @ Q.T) 14 | if ncells == 1: 15 | cells = scores.argmax(dim=0, keepdim=True).permute(1, 0) 16 | else: 17 | cells = scores.topk(ncells, dim=0, sorted=False).indices.permute(1, 0) # (32, ncells) 18 | cells = cells.flatten().contiguous() # (32 * ncells,) 19 | cells = cells.unique(sorted=False) 20 | return cells, scores 21 | 22 | def generate_candidate_eids(self, Q, ncells): 23 | cells, scores = self.get_cells(Q, ncells) 24 | 25 | eids, cell_lengths = self.ivf.lookup(cells) # eids = (packedlen,) lengths = (32 * ncells,) 26 | eids = eids.long() 27 | if self.use_gpu: 28 | eids = eids.cuda() 29 | return eids, scores 30 | 31 | def generate_candidate_pids(self, Q, ncells): 32 | cells, scores = self.get_cells(Q, ncells) 33 | 34 | pids, cell_lengths = self.ivf.lookup(cells) 35 | if self.use_gpu: 36 | pids = pids.cuda() 37 | return pids, scores 38 | 39 | def generate_candidate_scores(self, Q, eids): 40 | E = self.lookup_eids(eids) 41 | if self.use_gpu: 42 | E = E.cuda() 43 | return (Q.unsqueeze(0) @ E.unsqueeze(2)).squeeze(-1).T 44 | 45 | def generate_candidates(self, config, Q): 46 | ncells = config.ncells 47 | 48 | assert isinstance(self.ivf, StridedTensor) 49 | 50 | Q = Q.squeeze(0) 51 | if self.use_gpu: 52 | Q = Q.cuda().half() 53 | assert Q.dim() == 2 54 | 55 | pids, centroid_scores = self.generate_candidate_pids(Q, ncells) 56 | 57 | sorter = pids.sort() 58 | pids = sorter.values 59 | 60 | pids, pids_counts = torch.unique_consecutive(pids, return_counts=True) 61 | if self.use_gpu: 62 | pids, pids_counts = pids.cuda(), pids_counts.cuda() 63 | 64 | return pids, centroid_scores 65 | -------------------------------------------------------------------------------- /third_party/colbert/search/index_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | import torch 4 | import numpy as np 5 | import tqdm 6 | 7 | from third_party.colbert.utils.utils import lengths2offsets, print_message, dotdict, flatten 8 | from third_party.colbert.indexing.codecs.residual import ResidualCodec 9 | from third_party.colbert.indexing.utils import optimize_ivf 10 | from third_party.colbert.search.strided_tensor import StridedTensor 11 | 12 | 13 | class IndexLoader: 14 | def __init__(self, index_path, use_gpu=True, load_index_with_mmap=False): 15 | self.index_path = index_path 16 | self.use_gpu = use_gpu 17 | self.load_index_with_mmap = load_index_with_mmap 18 | 19 | self._load_codec() 20 | self._load_ivf() 21 | 22 | self._load_doclens() 23 | 
self._load_embeddings() 24 | 25 | def _load_codec(self): 26 | print_message(f"#> Loading codec...") 27 | self.codec = ResidualCodec.load(self.index_path) 28 | 29 | def _load_ivf(self): 30 | print_message(f"#> Loading IVF...") 31 | 32 | if os.path.exists(os.path.join(self.index_path, "ivf.pid.pt")): 33 | ivf, ivf_lengths = torch.load(os.path.join(self.index_path, "ivf.pid.pt"), map_location='cpu') 34 | else: 35 | assert os.path.exists(os.path.join(self.index_path, "ivf.pt")) 36 | ivf, ivf_lengths = torch.load(os.path.join(self.index_path, "ivf.pt"), map_location='cpu') 37 | ivf, ivf_lengths = optimize_ivf(ivf, ivf_lengths, self.index_path) 38 | 39 | if False: 40 | ivf = ivf.tolist() 41 | ivf = [ivf[offset:endpos] for offset, endpos in lengths2offsets(ivf_lengths)] 42 | else: 43 | # ivf, ivf_lengths = ivf.cuda(), torch.LongTensor(ivf_lengths).cuda() # FIXME: REMOVE THIS LINE! 44 | ivf = StridedTensor(ivf, ivf_lengths, use_gpu=self.use_gpu) 45 | 46 | self.ivf = ivf 47 | 48 | def _load_doclens(self): 49 | doclens = [] 50 | 51 | print_message("#> Loading doclens...") 52 | 53 | for chunk_idx in tqdm.tqdm(range(self.num_chunks)): 54 | with open(os.path.join(self.index_path, f'doclens.{chunk_idx}.json')) as f: 55 | chunk_doclens = ujson.load(f) 56 | doclens.extend(chunk_doclens) 57 | 58 | self.doclens = torch.tensor(doclens) 59 | 60 | def _load_embeddings(self): 61 | self.embeddings = ResidualCodec.Embeddings.load_chunks( 62 | self.index_path, 63 | range(self.num_chunks), 64 | self.num_embeddings, 65 | self.load_index_with_mmap, 66 | ) 67 | 68 | @property 69 | def metadata(self): 70 | try: 71 | self._metadata 72 | except: 73 | with open(os.path.join(self.index_path, 'metadata.json')) as f: 74 | self._metadata = ujson.load(f) 75 | 76 | return self._metadata 77 | 78 | @property 79 | def config(self): 80 | raise NotImplementedError() # load from dict at metadata['config'] 81 | 82 | @property 83 | def num_chunks(self): 84 | # EVENTUALLY: If num_chunks doesn't exist (i.e., old index), fall back to counting doclens.*.json files. 85 | return self.metadata['num_chunks'] 86 | 87 | @property 88 | def num_embeddings(self): 89 | # EVENTUALLY: If num_embeddings doesn't exist (i.e., old index), sum the values in doclens.*.json files. 
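# Total number of token-level embeddings recorded in metadata.json at indexing time; passed to _load_embeddings() when loading the residual chunks.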
90 | return self.metadata['num_embeddings'] 91 | 92 | -------------------------------------------------------------------------------- /third_party/colbert/search/segmented_lookup.cpp: -------------------------------------------------------------------------------- 1 | #include <pthread.h> 2 | #include <torch/extension.h> 3 | 4 | #include <queue> 5 | #include <numeric> 6 | 7 | typedef struct { 8 | int tid; 9 | pthread_mutex_t* mutex; 10 | std::queue<int>* queue; 11 | 12 | int64_t ndocs; 13 | int64_t noutputs; 14 | int64_t dim; 15 | 16 | void* input; 17 | int64_t* lengths; 18 | int64_t* offsets; 19 | int64_t* cumulative_lengths; 20 | 21 | void* output; 22 | } lookup_args_t; 23 | 24 | template <typename T> 25 | void* lookup(void* args) { 26 | lookup_args_t* lookup_args = (lookup_args_t*)args; 27 | 28 | int64_t* lengths = lookup_args->lengths; 29 | int64_t* cumulative_lengths = lookup_args->cumulative_lengths; 30 | int64_t* offsets = lookup_args->offsets; 31 | int64_t dim = lookup_args->dim; 32 | 33 | T* input = static_cast<T*>(lookup_args->input); 34 | T* output = static_cast<T*>(lookup_args->output); 35 | 36 | while (1) { 37 | pthread_mutex_lock(lookup_args->mutex); 38 | if (lookup_args->queue->empty()) { 39 | pthread_mutex_unlock(lookup_args->mutex); 40 | return NULL; 41 | } 42 | int i = lookup_args->queue->front(); 43 | lookup_args->queue->pop(); 44 | pthread_mutex_unlock(lookup_args->mutex); 45 | 46 | std::memcpy(output + (cumulative_lengths[i] * dim), 47 | input + (offsets[i] * dim), lengths[i] * dim * sizeof(T)); 48 | } 49 | } 50 | 51 | template <typename T> 52 | torch::Tensor segmented_lookup_impl(const torch::Tensor input, 53 | const torch::Tensor pids, 54 | const torch::Tensor lengths, 55 | const torch::Tensor offsets) { 56 | auto lengths_a = lengths.data_ptr<int64_t>(); 57 | auto offsets_a = offsets.data_ptr<int64_t>(); 58 | 59 | int64_t ndocs = pids.size(0); 60 | int64_t noutputs = std::accumulate(lengths_a, lengths_a + ndocs, 0); 61 | 62 | int nthreads = at::get_num_threads(); 63 | 64 | int64_t dim; 65 | torch::Tensor output; 66 | 67 | if (input.dim() == 1) { 68 | dim = 1; 69 | output = torch::zeros({noutputs}, input.options()); 70 | } else { 71 | assert(input.dim() == 2); 72 | dim = input.size(1); 73 | output = torch::zeros({noutputs, dim}, input.options()); 74 | } 75 | 76 | int64_t cumulative_lengths[ndocs + 1]; 77 | cumulative_lengths[0] = 0; 78 | std::partial_sum(lengths_a, lengths_a + ndocs, cumulative_lengths + 1); 79 | 80 | pthread_mutex_t mutex; 81 | int rc = pthread_mutex_init(&mutex, NULL); 82 | if (rc) { 83 | fprintf(stderr, "Unable to init mutex: %d\n", rc); 84 | } 85 | 86 | std::queue<int> queue; 87 | for (int i = 0; i < ndocs; i++) { 88 | queue.push(i); 89 | } 90 | 91 | pthread_t threads[nthreads]; 92 | lookup_args_t args[nthreads]; 93 | for (int i = 0; i < nthreads; i++) { 94 | args[i].tid = i; 95 | args[i].mutex = &mutex; 96 | args[i].queue = &queue; 97 | 98 | args[i].ndocs = ndocs; 99 | args[i].noutputs = noutputs; 100 | args[i].dim = dim; 101 | 102 | args[i].input = (void*)input.data_ptr(); 103 | args[i].lengths = lengths_a; 104 | args[i].offsets = offsets_a; 105 | args[i].cumulative_lengths = cumulative_lengths; 106 | 107 | args[i].output = (void*)output.data_ptr(); 108 | 109 | rc = pthread_create(&threads[i], NULL, lookup<T>, (void*)&args[i]); 110 | if (rc) { 111 | fprintf(stderr, "Unable to create thread %d: %d\n", i, rc); 112 | } 113 | } 114 | 115 | for (int i = 0; i < nthreads; i++) { 116 | pthread_join(threads[i], NULL); 117 | } 118 | 119 | rc = pthread_mutex_destroy(&mutex); 120 | if (rc) { 121 | fprintf(stderr, "Unable to destroy mutex: %d\n", rc); 122 | } 123 | 124
| return output; 125 | } 126 | 127 | torch::Tensor segmented_lookup(const torch::Tensor input, 128 | const torch::Tensor pids, 129 | const torch::Tensor lengths, 130 | const torch::Tensor offsets) { 131 | if (input.dtype() == torch::kUInt8) { 132 | return segmented_lookup_impl<uint8_t>(input, pids, lengths, offsets); 133 | } else if (input.dtype() == torch::kInt32) { 134 | return segmented_lookup_impl<int32_t>(input, pids, lengths, offsets); 135 | } else if (input.dtype() == torch::kInt64) { 136 | return segmented_lookup_impl<int64_t>(input, pids, lengths, offsets); 137 | } else if (input.dtype() == torch::kFloat32) { 138 | return segmented_lookup_impl<float>(input, pids, lengths, offsets); 139 | } else if (input.dtype() == torch::kFloat16) { 140 | return segmented_lookup_impl<at::Half>(input, pids, lengths, offsets); 141 | } else { 142 | assert(false); 143 | } 144 | } 145 | 146 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 147 | m.def("segmented_lookup_cpp", &segmented_lookup, "Segmented lookup"); 148 | } 149 | -------------------------------------------------------------------------------- /third_party/colbert/search/strided_tensor_core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | import numpy as np 5 | 6 | from third_party.colbert.utils.utils import flatten 7 | 8 | 9 | """ 10 | import line_profiler 11 | import atexit 12 | profile = line_profiler.LineProfiler() 13 | atexit.register(profile.print_stats) 14 | """ 15 | 16 | 17 | class StridedTensorCore: 18 | # # @profile 19 | def __init__(self, packed_tensor, lengths, dim=None, use_gpu=True): 20 | self.dim = dim 21 | self.tensor = packed_tensor 22 | self.inner_dims = self.tensor.size()[1:] 23 | self.use_gpu = use_gpu 24 | 25 | self.lengths = lengths.long() if torch.is_tensor(lengths) else torch.LongTensor(lengths) 26 | 27 | self.strides = _select_strides(self.lengths, [.5, .75, .9, .95]) + [self.lengths.max().item()] 28 | self.max_stride = self.strides[-1] 29 | 30 | zero = torch.zeros(1, dtype=torch.long, device=self.lengths.device) 31 | self.offsets = torch.cat((zero, torch.cumsum(self.lengths, dim=0))) 32 | 33 | if self.offsets[-2] + self.max_stride > self.tensor.size(0): 34 | # if self.tensor.size(0) > 10_000_000: 35 | # print("#> WARNING: StridedTensor has to add padding, internally, to a large tensor.") 36 | # print("#> WARNING: Consider doing this padding in advance to save memory!") 37 | 38 | padding = torch.zeros(self.max_stride, *self.inner_dims, dtype=self.tensor.dtype, device=self.tensor.device) 39 | self.tensor = torch.cat((self.tensor, padding)) 40 | 41 | self.views = {stride: _create_view(self.tensor, stride, self.inner_dims) for stride in self.strides} 42 | 43 | @classmethod 44 | def from_packed_tensor(cls, tensor, lengths): 45 | return cls(tensor, lengths) 46 | 47 | @classmethod 48 | def from_padded_tensor(cls, tensor, mask): 49 | pass 50 | 51 | @classmethod 52 | def from_nested_list(cls, lst): 53 | flat_lst = flatten(lst) 54 | 55 | tensor = torch.Tensor(flat_lst) 56 | lengths = [len(sublst) for sublst in lst] 57 | 58 | return cls(tensor, lengths, dim=0) 59 | 60 | @classmethod 61 | def from_tensors_list(cls, tensors): 62 | # torch.cat(tensors) 63 | # lengths.
64 | # cls(tensor, lengths) 65 | raise NotImplementedError() 66 | 67 | def as_packed_tensor(self, return_offsets=False): 68 | unpadded_packed_tensor = self.tensor # [:self.offsets[-1]] 69 | 70 | return_vals = [unpadded_packed_tensor, self.lengths] 71 | 72 | if return_offsets: 73 | return_vals.append(self.offsets) 74 | 75 | return tuple(return_vals) 76 | 77 | # # @profile 78 | def as_padded_tensor(self): 79 | if self.use_gpu: 80 | view = _create_view(self.tensor.cuda(), self.max_stride, self.inner_dims)[self.offsets[:-1]] 81 | mask = _create_mask(self.lengths.cuda(), self.max_stride, like=view, use_gpu=self.use_gpu) 82 | else: 83 | #import pdb 84 | #pdb.set_trace() 85 | view = _create_view(self.tensor, self.max_stride, self.inner_dims) 86 | view = view[self.offsets[:-1]] 87 | mask = _create_mask(self.lengths, self.max_stride, like=view, use_gpu=self.use_gpu) 88 | 89 | return view, mask 90 | 91 | def as_tensors_list(self): 92 | raise NotImplementedError() 93 | 94 | 95 | 96 | def _select_strides(lengths, quantiles): 97 | if lengths.size(0) < 5_000: 98 | return _get_quantiles(lengths, quantiles) 99 | 100 | sample = torch.randint(0, lengths.size(0), size=(2_000,)) 101 | 102 | return _get_quantiles(lengths[sample], quantiles) 103 | 104 | def _get_quantiles(lengths, quantiles): 105 | return torch.quantile(lengths.float(), torch.tensor(quantiles, device=lengths.device)).int().tolist() 106 | 107 | 108 | def _create_view(tensor, stride, inner_dims): 109 | outdim = tensor.size(0) - stride + 1 110 | size = (outdim, stride, *inner_dims) 111 | 112 | inner_dim_prod = int(np.prod(inner_dims)) 113 | multidim_stride = [inner_dim_prod, inner_dim_prod] + [1] * len(inner_dims) 114 | 115 | return torch.as_strided(tensor, size=size, stride=multidim_stride) 116 | 117 | 118 | def _create_mask(lengths, stride, like=None, use_gpu=True): 119 | if use_gpu: 120 | mask = torch.arange(stride).cuda() + 1 121 | mask = mask.unsqueeze(0) <= lengths.cuda().unsqueeze(-1) 122 | else: 123 | mask = torch.arange(stride) + 1 124 | mask = mask.unsqueeze(0) <= lengths.unsqueeze(-1) 125 | 126 | if like is not None: 127 | for _ in range(like.dim() - mask.dim()): 128 | mask = mask.unsqueeze(-1) 129 | 130 | return mask 131 | -------------------------------------------------------------------------------- /third_party/colbert/tests/e2e_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from collections import namedtuple 4 | from datasets import load_dataset 5 | from third_party.utility.utils.dpr import has_answer, DPR_normalize 6 | import tqdm 7 | 8 | from third_party.colbert import Indexer, Searcher 9 | from third_party.colbert.infra import ColBERTConfig, RunConfig, Run 10 | 11 | SquadExample = namedtuple("SquadExample", "id title context question answers") 12 | 13 | 14 | def build_index_and_init_searcher(checkpoint, collection, experiment_dir): 15 | nbits = 1 # encode each dimension with 1 bits 16 | doc_maxlen = 180 # truncate passages at 180 tokens 17 | experiment = f"e2etest.nbits={nbits}" 18 | 19 | with Run().context(RunConfig(nranks=1)): 20 | config = ColBERTConfig( 21 | doc_maxlen=doc_maxlen, 22 | nbits=nbits, 23 | root=experiment_dir, 24 | experiment=experiment, 25 | ) 26 | indexer = Indexer(checkpoint, config=config) 27 | indexer.index(name=experiment, collection=collection, overwrite=True) 28 | 29 | config = ColBERTConfig( 30 | root=experiment_dir, 31 | experiment=experiment, 32 | ) 33 | searcher = Searcher( 34 | index=experiment, 35 | config=config, 36 
| ) 37 | 38 | return searcher 39 | 40 | 41 | def success_at_k(searcher, examples, k): 42 | scores = [] 43 | for ex in tqdm.tqdm(examples): 44 | scores.append(evaluate_retrieval_example(searcher, ex, k)) 45 | return sum(scores) / len(scores) 46 | 47 | 48 | def evaluate_retrieval_example(searcher, ex, k): 49 | results = searcher.search(ex.question, k=k) 50 | for passage_id, passage_rank, passage_score in zip(*results): 51 | passage = searcher.collection[passage_id] 52 | score = has_answer([DPR_normalize(ans) for ans in ex.answers], passage) 53 | if score: 54 | return 1 55 | return 0 56 | 57 | 58 | def get_squad_split(squad, split="validation"): 59 | fields = squad[split].features 60 | data = zip(*[squad[split][field] for field in fields]) 61 | return [ 62 | SquadExample(eid, title, context, question, answers["text"]) 63 | for eid, title, context, question, answers in data 64 | ] 65 | 66 | 67 | def main(args): 68 | checkpoint = args.checkpoint 69 | collection = args.collection 70 | experiment_dir = args.expdir 71 | 72 | # Start the test 73 | k = 5 74 | searcher = build_index_and_init_searcher(checkpoint, collection, experiment_dir) 75 | 76 | squad = load_dataset("squad") 77 | squad_dev = get_squad_split(squad) 78 | success_rate = success_at_k(searcher, squad_dev[:1000], k) 79 | assert success_rate > 0.93, f"success rate at {success_rate} is lower than expected" 80 | print(f"test passed with success rate {success_rate}") 81 | 82 | 83 | if __name__ == "__main__": 84 | parser = argparse.ArgumentParser(description="Start end-to-end test.") 85 | parser.add_argument( 86 | "--checkpoint", type=str, required=True, help="Model checkpoint" 87 | ) 88 | parser.add_argument( 89 | "--collection", type=str, required=True, help="Path to collection" 90 | ) 91 | parser.add_argument( 92 | "--expdir", type=str, required=True, help="Experiment directory" 93 | ) 94 | args = parser.parse_args() 95 | main(args) 96 | -------------------------------------------------------------------------------- /third_party/colbert/tests/index_coalesce_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import ujson 4 | from tqdm import tqdm 5 | 6 | import argparse 7 | 8 | def main(args): 9 | # TODO: compare residual, codes, and doclens 10 | # Get the number of chunks in the multi-file index 11 | single_path = args.single 12 | multi_path = args.multi 13 | 14 | # Get num_chunks and num_embeddings from metadata 15 | filepath = os.path.join(multi_path, 'metadata.json') 16 | with open(filepath, 'r') as f: 17 | metadata = ujson.load(f) 18 | num_chunks = metadata['num_chunks'] 19 | print(f"Num_chunks = {num_chunks}") 20 | 21 | num_embeddings = metadata['num_embeddings'] 22 | print(f"Num_embeddings = {num_embeddings}") 23 | 24 | dim = metadata['config']['dim'] 25 | nbits = metadata['config']['nbits'] 26 | 27 | ## Load and compare doclens ## 28 | # load multi-file doclens 29 | print("Loading doclens from multi-file index") 30 | multi_doclens = [] 31 | for chunk_idx in tqdm(range(num_chunks)): 32 | with open(os.path.join(multi_path, f"doclens.{chunk_idx}.json"), 'r') as f: 33 | chunk = ujson.load(f) 34 | multi_doclens.extend(chunk) 35 | 36 | # load single-file doclens 37 | print("Loading doclens from single-file index") 38 | single_doclens = [] 39 | for _ in tqdm(range(1)): 40 | with open(os.path.join(single_path, "doclens.0.json"), 'r') as f: 41 | single_doclens = ujson.load(f) 42 | 43 | # compare doclens 44 | if (multi_doclens != single_doclens): 45 | print("Doclens 
do not match!") 46 | print("Multi-file doclens size = {}".format(len(multi_doclens))) 47 | print("Single-file doclens size = {}".format(len(single_doclens))) 48 | else: 49 | print("Doclens match") 50 | 51 | ## Load and compare codes ## 52 | # load multi-file codes 53 | print("Loading codes from multi-file index") 54 | multi_codes = torch.empty(num_embeddings, dtype=torch.int32) 55 | offset = 0 56 | for chunk_idx in tqdm(range(num_chunks)): 57 | chunk = torch.load(os.path.join(multi_path, f"{chunk_idx}.codes.pt")) 58 | endpos = offset + chunk.size(0) 59 | multi_codes[offset:endpos] = chunk 60 | offset = endpos 61 | 62 | # load single-file codes 63 | print("Loading codes from single-file index") 64 | single_codes = [] 65 | for _ in tqdm(range(1)): 66 | single_codes = torch.load(os.path.join(single_path, "0.codes.pt")) 67 | 68 | if (single_codes.size(0) != num_embeddings): 69 | print("Codes are the wrong size!") 70 | 71 | # compare codes 72 | if torch.equal(multi_codes, single_codes): 73 | print("Codes match") 74 | else: 75 | print("Codes do not match!") 76 | 77 | ## Load and compare residuals ## 78 | # load multi-file residuals 79 | print("Loading residuals from multi-file index") 80 | multi_residuals = torch.empty(num_embeddings, dim // 8 * nbits, dtype=torch.uint8) 81 | offset = 0 82 | for chunk_idx in tqdm(range(num_chunks)): 83 | chunk = torch.load(os.path.join(multi_path, f"{chunk_idx}.residuals.pt")) 84 | endpos = offset + chunk.size(0) 85 | multi_residuals[offset:endpos] = chunk 86 | offset = endpos 87 | 88 | # load single-file residuals 89 | print("Loading residuals from single-file index") 90 | single_residuals = [] 91 | for _ in tqdm(range(1)): 92 | single_residuals = torch.load(os.path.join(single_path, "0.residuals.pt")) 93 | 94 | # compare residuals 95 | if torch.equal(multi_residuals, single_residuals): 96 | print("Residuals match") 97 | else: 98 | print("Residuals do not match!") 99 | 100 | if __name__ == "__main__": 101 | parser = argparse.ArgumentParser(description="Compare single-file and multi-file indexes.") 102 | parser.add_argument( 103 | "--single", type=str, required=True, help="Path to single-file index." 104 | ) 105 | parser.add_argument( 106 | "--multi", type=str, required=True, help="Path to multi-file index." 107 | ) 108 | 109 | args = parser.parse_args() 110 | main(args) 111 | 112 | print("Exiting test") 113 | -------------------------------------------------------------------------------- /third_party/colbert/trainer.py: -------------------------------------------------------------------------------- 1 | from third_party.colbert.infra.run import Run 2 | from third_party.colbert.infra.launcher import Launcher 3 | from third_party.colbert.infra.config import ColBERTConfig, RunConfig 4 | 5 | from third_party.colbert.training.training import train 6 | 7 | 8 | class Trainer: 9 | def __init__(self, triples, queries, collection, config=None): 10 | self.config = ColBERTConfig.from_existing(config, Run().config) 11 | 12 | self.triples = triples 13 | self.queries = queries 14 | self.collection = collection 15 | 16 | def configure(self, **kw_args): 17 | self.config.configure(**kw_args) 18 | 19 | def train(self, checkpoint='bert-base-uncased'): 20 | """ 21 | Note that config.checkpoint is ignored. Only the supplied checkpoint here is used. 22 | """ 23 | 24 | # Resources don't come from the config object. They come from the input parameters. 25 | # TODO: After the API stabilizes, make this "self.config.assign()" to emphasize this distinction. 
26 | self.configure(triples=self.triples, queries=self.queries, collection=self.collection) 27 | self.configure(checkpoint=checkpoint) 28 | 29 | launcher = Launcher(train) 30 | 31 | self._best_checkpoint_path = launcher.launch(self.config, self.triples, self.queries, self.collection) 32 | 33 | 34 | def best_checkpoint_path(self): 35 | return self._best_checkpoint_path 36 | 37 | -------------------------------------------------------------------------------- /third_party/colbert/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/training/__init__.py -------------------------------------------------------------------------------- /third_party/colbert/training/eager_batcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | 4 | from functools import partial 5 | from third_party.colbert.utils.utils import print_message 6 | from third_party.colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples 7 | 8 | from third_party.colbert.utils.runs import Run 9 | 10 | 11 | class EagerBatcher(): 12 | def __init__(self, args, rank=0, nranks=1): 13 | self.rank, self.nranks = rank, nranks 14 | self.bsize, self.accumsteps = args.bsize, args.accumsteps 15 | 16 | self.query_tokenizer = QueryTokenizer(args.query_maxlen) 17 | self.doc_tokenizer = DocTokenizer(args.doc_maxlen) 18 | self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer) 19 | 20 | self.triples_path = args.triples 21 | self._reset_triples() 22 | 23 | def _reset_triples(self): 24 | self.reader = open(self.triples_path, mode='r', encoding="utf-8") 25 | self.position = 0 26 | 27 | def __iter__(self): 28 | return self 29 | 30 | def __next__(self): 31 | queries, positives, negatives = [], [], [] 32 | 33 | for line_idx, line in zip(range(self.bsize * self.nranks), self.reader): 34 | if (self.position + line_idx) % self.nranks != self.rank: 35 | continue 36 | 37 | query, pos, neg = line.strip().split('\t') 38 | 39 | queries.append(query) 40 | positives.append(pos) 41 | negatives.append(neg) 42 | 43 | self.position += line_idx + 1 44 | 45 | if len(queries) < self.bsize: 46 | raise StopIteration 47 | 48 | return self.collate(queries, positives, negatives) 49 | 50 | def collate(self, queries, positives, negatives): 51 | assert len(queries) == len(positives) == len(negatives) == self.bsize 52 | 53 | return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps) 54 | 55 | def skip_to_batch(self, batch_idx, intended_batch_size): 56 | self._reset_triples() 57 | 58 | Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.') 59 | 60 | _ = [self.reader.readline() for _ in range(batch_idx * intended_batch_size)] 61 | 62 | return None 63 | -------------------------------------------------------------------------------- /third_party/colbert/training/lazy_batcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | 4 | from functools import partial 5 | from third_party.colbert.infra.config.config import ColBERTConfig 6 | from third_party.colbert.utils.utils import print_message, zipstar 7 | from third_party.colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples 8 | from 
third_party.colbert.evaluation.loaders import load_collection 9 | 10 | from third_party.colbert.data.collection import Collection 11 | from third_party.colbert.data.queries import Queries 12 | from third_party.colbert.data.examples import Examples 13 | 14 | # from third_party.colbert.utils.runs import Run 15 | 16 | 17 | class LazyBatcher(): 18 | def __init__(self, config: ColBERTConfig, triples, queries, collection, rank=0, nranks=1): 19 | self.bsize, self.accumsteps = config.bsize, config.accumsteps 20 | self.nway = config.nway 21 | 22 | self.query_tokenizer = QueryTokenizer(config) 23 | self.doc_tokenizer = DocTokenizer(config) 24 | self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer) 25 | self.position = 0 26 | 27 | self.triples = Examples.cast(triples, nway=self.nway).tolist(rank, nranks) 28 | self.queries = Queries.cast(queries) 29 | self.collection = Collection.cast(collection) 30 | assert len(self.triples) > 0, "Received no triples on which to train." 31 | assert len(self.queries) > 0, "Received no queries on which to train." 32 | assert len(self.collection) > 0, "Received no collection on which to train." 33 | 34 | def __iter__(self): 35 | return self 36 | 37 | def __len__(self): 38 | return len(self.triples) 39 | 40 | def __next__(self): 41 | offset, endpos = self.position, min(self.position + self.bsize, len(self.triples)) 42 | self.position = endpos 43 | 44 | if offset + self.bsize > len(self.triples): 45 | raise StopIteration 46 | 47 | all_queries, all_passages, all_scores = [], [], [] 48 | 49 | for position in range(offset, endpos): 50 | query, *pids = self.triples[position] 51 | pids = pids[:self.nway] 52 | 53 | query = self.queries[query] 54 | 55 | try: 56 | pids, scores = zipstar(pids) 57 | except: 58 | scores = [] 59 | 60 | passages = [self.collection[pid] for pid in pids] 61 | 62 | all_queries.append(query) 63 | all_passages.extend(passages) 64 | all_scores.extend(scores) 65 | 66 | assert len(all_scores) in [0, len(all_passages)], len(all_scores) 67 | 68 | return self.collate(all_queries, all_passages, all_scores) 69 | 70 | def collate(self, queries, passages, scores): 71 | assert len(queries) == self.bsize 72 | assert len(passages) == self.nway * self.bsize 73 | 74 | return self.tensorize_triples(queries, passages, scores, self.bsize // self.accumsteps, self.nway) 75 | 76 | # def skip_to_batch(self, batch_idx, intended_batch_size): 77 | # Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.') 78 | # self.position = intended_batch_size * batch_idx 79 | -------------------------------------------------------------------------------- /third_party/colbert/training/rerank_batcher.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | 4 | from functools import partial 5 | from third_party.colbert.infra.config.config import ColBERTConfig 6 | from third_party.colbert.utils.utils import flatten, print_message, zipstar 7 | from third_party.colbert.modeling.reranker.tokenizer import RerankerTokenizer 8 | 9 | from third_party.colbert.data.collection import Collection 10 | from third_party.colbert.data.queries import Queries 11 | from third_party.colbert.data.examples import Examples 12 | 13 | # from third_party.colbert.utils.runs import Run 14 | 15 | 16 | class RerankBatcher(): 17 | def __init__(self, config: ColBERTConfig, triples, queries, collection, rank=0, nranks=1): 18 | self.bsize, self.accumsteps = config.bsize, 
config.accumsteps 19 | self.nway = config.nway 20 | 21 | assert self.accumsteps == 1, "The tensorizer doesn't support larger accumsteps yet --- but it's easy to add." 22 | 23 | self.tokenizer = RerankerTokenizer(total_maxlen=config.doc_maxlen, base=config.checkpoint) 24 | self.position = 0 25 | 26 | self.triples = Examples.cast(triples, nway=self.nway).tolist(rank, nranks) 27 | self.queries = Queries.cast(queries) 28 | self.collection = Collection.cast(collection) 29 | 30 | assert len(self.triples) > 0, "Received no triples on which to train." 31 | assert len(self.queries) > 0, "Received no queries on which to train." 32 | assert len(self.collection) > 0, "Received no collection on which to train." 33 | 34 | def __iter__(self): 35 | return self 36 | 37 | def __len__(self): 38 | return len(self.triples) 39 | 40 | def __next__(self): 41 | offset, endpos = self.position, min(self.position + self.bsize, len(self.triples)) 42 | self.position = endpos 43 | 44 | if offset + self.bsize > len(self.triples): 45 | raise StopIteration 46 | 47 | all_queries, all_passages, all_scores = [], [], [] 48 | 49 | for position in range(offset, endpos): 50 | query, *pids = self.triples[position] 51 | pids = pids[:self.nway] 52 | 53 | query = self.queries[query] 54 | 55 | try: 56 | pids, scores = zipstar(pids) 57 | except: 58 | scores = [] 59 | 60 | passages = [self.collection[pid] for pid in pids] 61 | 62 | all_queries.append(query) 63 | all_passages.extend(passages) 64 | all_scores.extend(scores) 65 | 66 | assert len(all_scores) in [0, len(all_passages)], len(all_scores) 67 | 68 | return self.collate(all_queries, all_passages, all_scores) 69 | 70 | def collate(self, queries, passages, scores): 71 | assert len(queries) == self.bsize 72 | assert len(passages) == self.nway * self.bsize 73 | 74 | queries = flatten([[query] * self.nway for query in queries]) 75 | return [(self.tokenizer.tensorize(queries, passages), scores)] 76 | 77 | # def skip_to_batch(self, batch_idx, intended_batch_size): 78 | # Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.') 79 | # self.position = intended_batch_size * batch_idx 80 | -------------------------------------------------------------------------------- /third_party/colbert/training/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | # from third_party.colbert.utils.runs import Run 5 | from third_party.colbert.utils.utils import print_message, save_checkpoint 6 | from third_party.colbert.parameters import SAVED_CHECKPOINTS 7 | from third_party.colbert.infra.run import Run 8 | 9 | 10 | def print_progress(scores): 11 | positive_avg, negative_avg = round(scores[:, 0].mean().item(), 2), round(scores[:, 1].mean().item(), 2) 12 | print("#>>> ", positive_avg, negative_avg, '\t\t|\t\t', positive_avg - negative_avg) 13 | 14 | 15 | def manage_checkpoints(args, colbert, optimizer, batch_idx, savepath=None, consumed_all_triples=False): 16 | # arguments = dict(args) 17 | 18 | # TODO: Call provenance() on the values that support it?? 
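# Checkpoints are written under the run's checkpoints/ directory (or an explicit savepath): every 2000 batches, at each milestone in SAVED_CHECKPOINTS, and once all triples have been consumed.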
19 | 20 | checkpoints_path = savepath or os.path.join(Run().path_, 'checkpoints') 21 | name = None 22 | 23 | try: 24 | save = colbert.save 25 | except: 26 | save = colbert.module.save 27 | 28 | if not os.path.exists(checkpoints_path): 29 | os.makedirs(checkpoints_path) 30 | 31 | path_save = None 32 | 33 | if consumed_all_triples or (batch_idx % 2000 == 0): 34 | # name = os.path.join(path, "colbert.dnn") 35 | # save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments) 36 | path_save = os.path.join(checkpoints_path, "colbert") 37 | 38 | if batch_idx in SAVED_CHECKPOINTS: 39 | # name = os.path.join(path, "colbert-{}.dnn".format(batch_idx)) 40 | # save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments) 41 | path_save = os.path.join(checkpoints_path, f"colbert-{batch_idx}") 42 | 43 | if path_save: 44 | print(f"#> Saving a checkpoint to {path_save} ..") 45 | 46 | checkpoint = {} 47 | checkpoint['batch'] = batch_idx 48 | # checkpoint['epoch'] = 0 49 | # checkpoint['model_state_dict'] = model.state_dict() 50 | # checkpoint['optimizer_state_dict'] = optimizer.state_dict() 51 | # checkpoint['arguments'] = arguments 52 | 53 | save(path_save) 54 | 55 | return path_save 56 | -------------------------------------------------------------------------------- /third_party/colbert/utilities/annotate_em.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import git 5 | import tqdm 6 | import ujson 7 | import random 8 | 9 | from argparse import ArgumentParser 10 | from multiprocessing import Pool 11 | 12 | from third_party.colbert.utils.utils import groupby_first_item, print_message 13 | from third_party.utility.utils.qa_loaders import load_qas_, load_collection_ 14 | from third_party.utility.utils.save_metadata import format_metadata, get_metadata 15 | from third_party.utility.evaluate.annotate_EM_helpers import * 16 | 17 | from third_party.colbert.infra.run import Run 18 | from third_party.colbert.data.collection import Collection 19 | from third_party.colbert.data.ranking import Ranking 20 | 21 | 22 | class AnnotateEM: 23 | def __init__(self, collection, qas): 24 | # TODO: These should just be Queries! But Queries needs to support looking up answers as qid2answers below. 25 | qas = load_qas_(qas) 26 | collection = Collection.cast(collection) # .tolist() #load_collection_(collection, retain_titles=True) 27 | 28 | self.parallel_pool = Pool(30) 29 | 30 | print_message('#> Tokenize the answers in the Q&As in parallel...') 31 | qas = list(self.parallel_pool.map(tokenize_all_answers, qas)) 32 | 33 | qid2answers = {qid: tok_answers for qid, _, tok_answers in qas} 34 | assert len(qas) == len(qid2answers), (len(qas), len(qid2answers)) 35 | 36 | self.qas, self.collection = qas, collection 37 | self.qid2answers = qid2answers 38 | 39 | def annotate(self, ranking): 40 | rankings = Ranking.cast(ranking) 41 | 42 | # print(len(rankings), rankings[0]) 43 | 44 | print_message('#> Lookup passages from PIDs...') 45 | expanded_rankings = [(qid, pid, rank, self.collection[pid], self.qid2answers[qid]) 46 | for qid, pid, rank, *_ in rankings.tolist()] 47 | 48 | print_message('#> Assign labels in parallel...') 49 | labeled_rankings = list(self.parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings))) 50 | 51 | # Dump output. 
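# Group the labeled (qid, pid, rank, label) tuples by qid so that per-query success and counts can be computed below.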
52 | self.qid2rankings = groupby_first_item(labeled_rankings) 53 | 54 | self.num_judged_queries, self.num_ranked_queries = check_sizes(self.qid2answers, self.qid2rankings) 55 | 56 | # Evaluation metrics and depths. 57 | self.success, self.counts = self._compute_labels(self.qid2answers, self.qid2rankings) 58 | 59 | print(rankings.provenance(), self.success) 60 | 61 | return Ranking(data=self.qid2rankings, provenance=("AnnotateEM", rankings.provenance())) 62 | 63 | def _compute_labels(self, qid2answers, qid2rankings): 64 | cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all'] 65 | success = {cutoff: 0.0 for cutoff in cutoffs} 66 | counts = {cutoff: 0.0 for cutoff in cutoffs} 67 | 68 | for qid in qid2answers: 69 | if qid not in qid2rankings: 70 | continue 71 | 72 | prev_rank = 0 # ranks should start at one (i.e., and not zero) 73 | labels = [] 74 | 75 | for pid, rank, label in qid2rankings[qid]: 76 | assert rank == prev_rank+1, (qid, pid, (prev_rank, rank)) 77 | prev_rank = rank 78 | 79 | labels.append(label) 80 | 81 | for cutoff in cutoffs: 82 | if cutoff != 'all': 83 | success[cutoff] += sum(labels[:cutoff]) > 0 84 | counts[cutoff] += sum(labels[:cutoff]) 85 | else: 86 | success[cutoff] += sum(labels) > 0 87 | counts[cutoff] += sum(labels) 88 | 89 | return success, counts 90 | 91 | def save(self, new_path): 92 | print_message("#> Dumping output to", new_path, "...") 93 | 94 | Ranking(data=self.qid2rankings).save(new_path) 95 | 96 | # Dump metrics. 97 | with Run().open(f'{new_path}.metrics', 'w') as f: 98 | d = {'num_ranked_queries': self.num_ranked_queries, 'num_judged_queries': self.num_judged_queries} 99 | 100 | extra = '__WARNING' if self.num_judged_queries != self.num_ranked_queries else '' 101 | d[f'success{extra}'] = {k: v / self.num_judged_queries for k, v in self.success.items()} 102 | d[f'counts{extra}'] = {k: v / self.num_judged_queries for k, v in self.counts.items()} 103 | # d['arguments'] = get_metadata(args) # TODO: Need arguments... 104 | 105 | f.write(format_metadata(d) + '\n') 106 | 107 | 108 | if __name__ == '__main__': 109 | r = sys.argv[2] 110 | 111 | a = AnnotateEM(collection='/dfs/scratch0/okhattab/OpenQA/collection.tsv', 112 | qas=sys.argv[1]) 113 | a.annotate(ranking=r) 114 | -------------------------------------------------------------------------------- /third_party/colbert/utilities/create_triples.py: -------------------------------------------------------------------------------- 1 | import random 2 | from third_party.colbert.infra.provenance import Provenance 3 | 4 | from third_party.utility.utils.save_metadata import save_metadata 5 | from third_party.utility.supervision.triples import sample_for_query 6 | 7 | from third_party.colbert.utils.utils import print_message 8 | 9 | from third_party.colbert.data.ranking import Ranking 10 | from third_party.colbert.data.examples import Examples 11 | 12 | MAX_NUM_TRIPLES = 40_000_000 13 | 14 | 15 | class Triples: 16 | def __init__(self, ranking, seed=12345): 17 | random.seed(seed) # TODO: Use internal RNG instead.. 
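# Keep the seed for provenance and convert the input ranking into a qid -> ranked-passages dict, from which create() samples positive/negative triples.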
18 | self.seed = seed 19 | 20 | ranking = Ranking.cast(ranking) 21 | self.ranking_provenance = ranking.provenance() 22 | self.qid2rankings = ranking.todict() 23 | 24 | def create(self, positives, depth): 25 | assert all(len(x) == 2 for x in positives) 26 | assert all(maxBest <= maxDepth for maxBest, maxDepth in positives), positives 27 | 28 | self.positives = positives 29 | self.depth = depth 30 | 31 | Triples = [] 32 | NonEmptyQIDs = 0 33 | 34 | for processing_idx, qid in enumerate(self.qid2rankings): 35 | l = sample_for_query(qid, self.qid2rankings[qid], positives, depth, False, None) 36 | NonEmptyQIDs += (len(l) > 0) 37 | Triples.extend(l) 38 | 39 | if processing_idx % (10_000) == 0: 40 | print_message(f"#> Done with {processing_idx+1} questions!\t\t " 41 | f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unqiue QIDs.") 42 | 43 | print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..") 44 | print_message(f"#> len(Triples) = {len(Triples)}") 45 | 46 | if len(Triples) > MAX_NUM_TRIPLES: 47 | Triples = random.sample(Triples, MAX_NUM_TRIPLES) 48 | 49 | ### Prepare the triples ### 50 | print_message("#> Shuffling the triples...") 51 | random.shuffle(Triples) 52 | 53 | self.Triples = Examples(data=Triples) 54 | 55 | return Triples 56 | 57 | def save(self, new_path): 58 | provenance = Provenance() 59 | provenance.source = 'Triples::create' 60 | provenance.seed = self.seed 61 | provenance.positives = self.positives 62 | provenance.depth = self.depth 63 | provenance.ranking = self.ranking_provenance 64 | 65 | Examples(data=self.Triples, provenance=provenance).save(new_path) 66 | -------------------------------------------------------------------------------- /third_party/colbert/utilities/minicorpus.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from third_party.colbert.utils.utils import create_directory 5 | 6 | from third_party.colbert.data.collection import Collection 7 | from third_party.colbert.data.queries import Queries 8 | from third_party.colbert.data.ranking import Ranking 9 | 10 | 11 | def sample_minicorpus(name, factor, topk=30, maxdev=3000): 12 | """ 13 | Factor: 14 | * nano=1 15 | * micro=10 16 | * mini=100 17 | * small=100 with topk=100 18 | * medium=150 with topk=300 19 | """ 20 | 21 | random.seed(12345) 22 | 23 | # Load collection 24 | collection = Collection(path='/dfs/scratch0/okhattab/OpenQA/collection.tsv') 25 | 26 | # Load train and dev queries 27 | qas_train = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/qas.json').qas() 28 | qas_dev = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/qas.json').qas() 29 | 30 | # Load train and dev C3 rankings 31 | ranking_train = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/rankings/C3.tsv.annotated').todict() 32 | ranking_dev = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/rankings/C3.tsv.annotated').todict() 33 | 34 | # Sample NT and ND queries from each, keep only the top-k passages for those 35 | sample_train = random.sample(list(qas_train.keys()), min(len(qas_train.keys()), 300*factor)) 36 | sample_dev = random.sample(list(qas_dev.keys()), min(len(qas_dev.keys()), maxdev, 30*factor)) 37 | 38 | train_pids = [pid for qid in sample_train for qpids in ranking_train[qid][:topk] for pid in qpids] 39 | dev_pids = [pid for qid in sample_dev for qpids in ranking_dev[qid][:topk] for pid in qpids] 40 | 41 | sample_pids = sorted(list(set(train_pids + dev_pids))) 42 | print(f'len(sample_pids) = {len(sample_pids)}') 43 | 44 | # 
Save the new query sets: train and dev 45 | ROOT = f'/future/u/okhattab/root/unit/data/NQ-{name}' 46 | 47 | create_directory(os.path.join(ROOT, 'train')) 48 | create_directory(os.path.join(ROOT, 'dev')) 49 | 50 | new_train = Queries(data={qid: qas_train[qid] for qid in sample_train}) 51 | new_train.save(os.path.join(ROOT, 'train/questions.tsv')) 52 | new_train.save_qas(os.path.join(ROOT, 'train/qas.json')) 53 | 54 | new_dev = Queries(data={qid: qas_dev[qid] for qid in sample_dev}) 55 | new_dev.save(os.path.join(ROOT, 'dev/questions.tsv')) 56 | new_dev.save_qas(os.path.join(ROOT, 'dev/qas.json')) 57 | 58 | # Save the new collection 59 | print(f"Saving to {os.path.join(ROOT, 'collection.tsv')}") 60 | Collection(data=[collection[pid] for pid in sample_pids]).save(os.path.join(ROOT, 'collection.tsv')) 61 | 62 | print('#> Done!') 63 | 64 | 65 | if __name__ == '__main__': 66 | sample_minicorpus('medium', 150, topk=300) 67 | -------------------------------------------------------------------------------- /third_party/colbert/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/utils/__init__.py -------------------------------------------------------------------------------- /third_party/colbert/utils/amp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from contextlib import contextmanager 4 | from third_party.colbert.utils.utils import NullContextManager 5 | 6 | 7 | class MixedPrecisionManager(): 8 | def __init__(self, activated): 9 | self.activated = activated 10 | 11 | if self.activated: 12 | self.scaler = torch.cuda.amp.GradScaler() 13 | 14 | def context(self): 15 | return torch.cuda.amp.autocast() if self.activated else NullContextManager() 16 | 17 | def backward(self, loss): 18 | if self.activated: 19 | self.scaler.scale(loss).backward() 20 | else: 21 | loss.backward() 22 | 23 | def step(self, colbert, optimizer, scheduler=None): 24 | if self.activated: 25 | self.scaler.unscale_(optimizer) 26 | torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False) 27 | 28 | self.scaler.step(optimizer) 29 | self.scaler.update() 30 | else: 31 | torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0) 32 | optimizer.step() 33 | 34 | if scheduler is not None: 35 | scheduler.step() 36 | 37 | optimizer.zero_grad() 38 | -------------------------------------------------------------------------------- /third_party/colbert/utils/coalesce.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from tqdm import tqdm 5 | import ujson 6 | import shutil 7 | 8 | 9 | def main(args): 10 | in_file = args.input 11 | out_file = args.output 12 | 13 | # Get num_chunks from metadata 14 | filepath = os.path.join(in_file, 'metadata.json') 15 | with open(filepath, 'r') as f: 16 | metadata = ujson.load(f) 17 | num_chunks = metadata['num_chunks'] 18 | print(f"Num_chunks = {num_chunks}") 19 | 20 | # Create output dir if not already created 21 | if not os.path.exists(out_file): 22 | os.makedirs(out_file) 23 | 24 | ## Coalesce doclens ## 25 | print("Coalescing doclens files...") 26 | 27 | temp = [] 28 | # read into one large list 29 | for i in tqdm(range(num_chunks)): 30 | filepath = os.path.join(in_file, f'doclens.{i}.json') 31 | with open(filepath, 'r') as f: 32 | chunk = ujson.load(f) 33 | temp.extend(chunk) 
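# All per-chunk doclens lists are now concatenated in order into a single list.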
34 | 35 | # write to output json 36 | filepath = os.path.join(out_file, 'doclens.0.json') 37 | with open(filepath, 'w') as f: 38 | ujson.dump(temp, f) 39 | 40 | ## Coalesce codes ## 41 | print("Coalescing codes files...") 42 | 43 | temp = torch.empty(0, dtype=torch.int32) 44 | # read into one large tensor 45 | for i in tqdm(range(num_chunks)): 46 | filepath = os.path.join(in_file, f'{i}.codes.pt') 47 | chunk = torch.load(filepath) 48 | temp = torch.cat((temp, chunk)) 49 | 50 | # save length of index 51 | index_len = temp.size()[0] 52 | 53 | # write to output tensor 54 | filepath = os.path.join(out_file, '0.codes.pt') 55 | torch.save(temp, filepath) 56 | 57 | ## Coalesce residuals ## 58 | print("Coalescing residuals files...") 59 | 60 | # Allocate all the memory needed in the beginning. Starting from torch.empty() and concatenating repeatedly results in excessive memory use and is much much slower. 61 | temp = torch.zeros(((metadata['num_embeddings'], int(metadata['config']['dim'] * metadata['config']['nbits'] // 8))), dtype=torch.uint8) 62 | cur_offset = 0 63 | # read into one large tensor 64 | for i in tqdm(range(num_chunks)): 65 | filepath = os.path.join(in_file, f'{i}.residuals.pt') 66 | chunk = torch.load(filepath) 67 | temp[cur_offset : cur_offset + len(chunk):] = chunk 68 | cur_offset += len(chunk) 69 | 70 | print("Saving residuals to output directory (this may take a few minutes)...") 71 | 72 | # write to output tensor 73 | filepath = os.path.join(out_file, '0.residuals.pt') 74 | torch.save(temp, filepath) 75 | 76 | # save metadata.json 77 | metadata['num_chunks'] = 1 78 | filepath = os.path.join(out_file, 'metadata.json') 79 | with open(filepath, 'w') as f: 80 | ujson.dump(metadata, f, indent=4) 81 | 82 | metadata_0 = {} 83 | metadata_0["num_embeddings"] = metadata["num_embeddings"] 84 | metadata_0["passage_offset"] = 0 85 | metadata_0["embedding_offset"] = 0 86 | 87 | filepath = os.path.join(in_file, str(num_chunks-1) + '.metadata.json') 88 | with open(filepath, 'r') as f: 89 | metadata_last = ujson.load(f) 90 | metadata_0["num_passages"] = int(metadata_last["num_passages"]) + int(metadata_last["passage_offset"]) 91 | 92 | filepath = os.path.join(out_file, '0.metadata.json') 93 | with open(filepath, 'w') as f: 94 | ujson.dump(metadata_0, f, indent=4) 95 | 96 | filepath = os.path.join(in_file, 'plan.json') 97 | with open(filepath, 'r') as f: 98 | plan = ujson.load(f) 99 | plan['num_chunks'] = 1 100 | filepath = os.path.join(out_file, 'plan.json') 101 | with open(filepath, 'w') as f: 102 | ujson.dump(plan, f, indent=4) 103 | 104 | other_files = ['avg_residual.pt', 'buckets.pt', 'centroids.pt', 'ivf.pt', 'ivf.pid.pt'] 105 | for filename in other_files: 106 | filepath = os.path.join(in_file, filename) 107 | if os.path.isfile(filepath): 108 | shutil.copy(filepath, out_file) 109 | 110 | print("Saved index to output directory {}.".format(out_file)) 111 | print("Number of embeddings = {}".format(index_len)) 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser(description="Coalesce multi-file index into a single file.") 116 | parser.add_argument( 117 | "--input", type=str, required=True, help="Path to input index directory" 118 | ) 119 | parser.add_argument( 120 | "--output", type=str, required=True, help="Path to output index directory" 121 | ) 122 | 123 | args = parser.parse_args() 124 | main(args) 125 | -------------------------------------------------------------------------------- /third_party/colbert/utils/distributed.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import numpy as np 5 | import datetime 6 | 7 | ALREADY_INITALIZED = False 8 | 9 | # TODO: Consider torch.distributed.is_initialized() instead 10 | 11 | 12 | def init(rank): 13 | nranks = 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) 14 | nranks = max(1, nranks) 15 | is_distributed = (nranks > 1) or ('WORLD_SIZE' in os.environ) 16 | 17 | global ALREADY_INITALIZED 18 | if ALREADY_INITALIZED: 19 | return nranks, is_distributed 20 | 21 | ALREADY_INITALIZED = True 22 | 23 | if is_distributed and torch.cuda.is_available(): 24 | num_gpus = torch.cuda.device_count() 25 | print(f'nranks = {nranks} \t num_gpus = {num_gpus} \t device={rank % num_gpus}') 26 | 27 | torch.cuda.set_device(rank % num_gpus) 28 | 29 | # increase the timeout for indexing large datasets with unbalanced split sizes across devices 30 | torch.distributed.init_process_group(backend='nccl', init_method='env://', timeout=datetime.timedelta(hours=1)) 31 | 32 | return nranks, is_distributed 33 | 34 | 35 | def barrier(rank): 36 | nranks = 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) 37 | nranks = max(1, nranks) 38 | 39 | if rank >= 0 and nranks > 1: 40 | torch.distributed.barrier(device_ids=[rank % torch.cuda.device_count()]) 41 | -------------------------------------------------------------------------------- /third_party/colbert/utils/logging.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import ujson 4 | # import mlflow 5 | import traceback 6 | 7 | # from torch.utils.tensorboard import SummaryWriter 8 | from third_party.colbert.utils.utils import print_message, create_directory 9 | 10 | 11 | class Logger(): 12 | def __init__(self, rank, run): 13 | self.rank = rank 14 | self.is_main = self.rank in [-1, 0] 15 | self.run = run 16 | self.logs_path = os.path.join(self.run.path, "logs/") 17 | 18 | if self.is_main: 19 | # self._init_mlflow() 20 | # self.initialized_tensorboard = False 21 | create_directory(self.logs_path) 22 | 23 | # def _init_mlflow(self): 24 | # mlflow.set_tracking_uri('file://' + os.path.join(self.run.experiments_root, "logs/mlruns/")) 25 | # mlflow.set_experiment('/'.join([self.run.experiment, self.run.script])) 26 | 27 | # mlflow.set_tag('experiment', self.run.experiment) 28 | # mlflow.set_tag('name', self.run.name) 29 | # mlflow.set_tag('path', self.run.path) 30 | 31 | # def _init_tensorboard(self): 32 | # root = os.path.join(self.run.experiments_root, "logs/tensorboard/") 33 | # logdir = '__'.join([self.run.experiment, self.run.script, self.run.name]) 34 | # logdir = os.path.join(root, logdir) 35 | 36 | # self.writer = SummaryWriter(log_dir=logdir) 37 | # self.initialized_tensorboard = True 38 | 39 | def _log_exception(self, etype, value, tb): 40 | if not self.is_main: 41 | return 42 | 43 | output_path = os.path.join(self.logs_path, 'exception.txt') 44 | trace = ''.join(traceback.format_exception(etype, value, tb)) + '\n' 45 | print_message(trace, '\n\n') 46 | 47 | self.log_new_artifact(output_path, trace) 48 | 49 | def _log_all_artifacts(self): 50 | if not self.is_main: 51 | return 52 | 53 | # mlflow.log_artifacts(self.logs_path) 54 | 55 | def _log_args(self, args): 56 | if not self.is_main: 57 | return 58 | 59 | # for key in vars(args): 60 | # value = getattr(args, key) 61 | # if type(value) in [int, float, str, bool]: 62 | # mlflow.log_param(key, value) 63 | 64 | # with 
open(os.path.join(self.logs_path, 'args.json'), 'w') as output_metadata: 65 | # # TODO: Call provenance() on the values that support it 66 | # ujson.dump(args.input_arguments.__dict__, output_metadata, indent=4) 67 | # output_metadata.write('\n') 68 | 69 | with open(os.path.join(self.logs_path, 'args.txt'), 'w') as output_metadata: 70 | output_metadata.write(' '.join(sys.argv) + '\n') 71 | 72 | def log_metric(self, name, value, step, log_to_mlflow=True): 73 | if not self.is_main: 74 | return 75 | 76 | # if not self.initialized_tensorboard: 77 | # self._init_tensorboard() 78 | 79 | # if log_to_mlflow: 80 | # mlflow.log_metric(name, value, step=step) 81 | # self.writer.add_scalar(name, value, step) 82 | 83 | def log_new_artifact(self, path, content): 84 | with open(path, 'w') as f: 85 | f.write(content) 86 | 87 | # mlflow.log_artifact(path) 88 | 89 | def warn(self, *args): 90 | msg = print_message('[WARNING]', '\t', *args) 91 | 92 | with open(os.path.join(self.logs_path, 'warnings.txt'), 'a') as output_metadata: 93 | output_metadata.write(msg + '\n\n\n') 94 | 95 | def info_all(self, *args): 96 | print_message('[' + str(self.rank) + ']', '\t', *args) 97 | 98 | def info(self, *args): 99 | if self.is_main: 100 | print_message(*args) 101 | -------------------------------------------------------------------------------- /third_party/colbert/utils/runs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import __main__ 5 | import traceback 6 | # import mlflow 7 | 8 | import third_party.colbert.utils.distributed as distributed 9 | 10 | from contextlib import contextmanager 11 | from third_party.colbert.utils.logging import Logger 12 | from third_party.colbert.utils.utils import timestamp, create_directory, print_message 13 | 14 | 15 | class _RunManager(): 16 | def __init__(self): 17 | self.experiments_root = None 18 | self.experiment = None 19 | self.path = None 20 | self.script = self._get_script_name() 21 | self.name = self._generate_default_run_name() 22 | self.original_name = self.name 23 | self.exit_status = 'FINISHED' 24 | 25 | self._logger = None 26 | self.start_time = time.time() 27 | 28 | def init(self, rank, root, experiment, name): 29 | assert '/' not in experiment, experiment 30 | assert '/' not in name, name 31 | 32 | self.experiments_root = os.path.abspath(root) 33 | self.experiment = experiment 34 | self.name = name 35 | self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name) 36 | 37 | if rank < 1: 38 | if os.path.exists(self.path): 39 | print('\n\n') 40 | print_message("It seems that ", self.path, " already exists.") 41 | print_message("Do you want to overwrite it? \t yes/no \n") 42 | 43 | # TODO: This should timeout and exit (i.e., fail) given no response for 60 seconds. 
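# Anything other than an explicit 'yes' makes the assert below fail, aborting instead of overwriting the existing run directory.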
44 | 45 | response = input() 46 | if response.strip() != 'yes': 47 | assert not os.path.exists(self.path), self.path 48 | else: 49 | create_directory(self.path) 50 | 51 | distributed.barrier(rank) 52 | 53 | self._logger = Logger(rank, self) 54 | self._log_args = self._logger._log_args 55 | self.warn = self._logger.warn 56 | self.info = self._logger.info 57 | self.info_all = self._logger.info_all 58 | self.log_metric = self._logger.log_metric 59 | self.log_new_artifact = self._logger.log_new_artifact 60 | 61 | def _generate_default_run_name(self): 62 | return timestamp() 63 | 64 | def _get_script_name(self): 65 | return os.path.basename(__main__.__file__) if '__file__' in dir(__main__) else 'none' 66 | 67 | @contextmanager 68 | def context(self, consider_failed_if_interrupted=True): 69 | try: 70 | yield 71 | 72 | except KeyboardInterrupt as ex: 73 | print('\n\nInterrupted\n\n') 74 | self._logger._log_exception(ex.__class__, ex, ex.__traceback__) 75 | self._logger._log_all_artifacts() 76 | 77 | if consider_failed_if_interrupted: 78 | self.exit_status = 'KILLED' # mlflow.entities.RunStatus.KILLED 79 | 80 | sys.exit(128 + 2) 81 | 82 | except Exception as ex: 83 | self._logger._log_exception(ex.__class__, ex, ex.__traceback__) 84 | self._logger._log_all_artifacts() 85 | 86 | self.exit_status = 'FAILED' # mlflow.entities.RunStatus.FAILED 87 | 88 | raise ex 89 | 90 | finally: 91 | total_seconds = str(time.time() - self.start_time) + '\n' 92 | original_name = str(self.original_name) 93 | name = str(self.name) 94 | 95 | self.log_new_artifact(os.path.join(self._logger.logs_path, 'elapsed.txt'), total_seconds) 96 | self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.original.txt'), original_name) 97 | self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.txt'), name) 98 | 99 | self._logger._log_all_artifacts() 100 | 101 | # mlflow.end_run(status=self.exit_status) 102 | 103 | 104 | Run = _RunManager() 105 | -------------------------------------------------------------------------------- /third_party/colbert_ai.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | MANIFEST.in 3 | README.md 4 | setup.py 5 | colbert/__init__.py 6 | colbert/index.py 7 | colbert/index_updater.py 8 | colbert/indexer.py 9 | colbert/parameters.py 10 | colbert/searcher.py 11 | colbert/trainer.py 12 | colbert/data/__init__.py 13 | colbert/data/collection.py 14 | colbert/data/dataset.py 15 | colbert/data/examples.py 16 | colbert/data/queries.py 17 | colbert/data/ranking.py 18 | colbert/evaluation/__init__.py 19 | colbert/evaluation/load_model.py 20 | colbert/evaluation/loaders.py 21 | colbert/evaluation/metrics.py 22 | colbert/indexing/__init__.py 23 | colbert/indexing/collection_encoder.py 24 | colbert/indexing/collection_indexer.py 25 | colbert/indexing/index_manager.py 26 | colbert/indexing/index_saver.py 27 | colbert/indexing/loaders.py 28 | colbert/indexing/utils.py 29 | colbert/indexing/codecs/__init__.py 30 | colbert/indexing/codecs/decompress_residuals.cpp 31 | colbert/indexing/codecs/decompress_residuals.cu 32 | colbert/indexing/codecs/packbits.cpp 33 | colbert/indexing/codecs/packbits.cu 34 | colbert/indexing/codecs/residual.py 35 | colbert/indexing/codecs/residual_embeddings.py 36 | colbert/indexing/codecs/residual_embeddings_strided.py 37 | colbert/infra/__init__.py 38 | colbert/infra/launcher.py 39 | colbert/infra/provenance.py 40 | colbert/infra/run.py 41 | colbert/infra/config/__init__.py 42 | 
colbert/infra/config/base_config.py 43 | colbert/infra/config/config.py 44 | colbert/infra/config/core_config.py 45 | colbert/infra/config/settings.py 46 | colbert/modeling/__init__.py 47 | colbert/modeling/base_colbert.py 48 | colbert/modeling/checkpoint.py 49 | colbert/modeling/colbert.py 50 | colbert/modeling/hf_colbert.py 51 | colbert/modeling/segmented_maxsim.cpp 52 | colbert/modeling/reranker/__init__.py 53 | colbert/modeling/reranker/electra.py 54 | colbert/modeling/reranker/tokenizer.py 55 | colbert/modeling/tokenization/__init__.py 56 | colbert/modeling/tokenization/doc_tokenization.py 57 | colbert/modeling/tokenization/query_tokenization.py 58 | colbert/modeling/tokenization/utils.py 59 | colbert/ranking/__init__.py 60 | colbert/search/__init__.py 61 | colbert/search/candidate_generation.py 62 | colbert/search/decompress_residuals.cpp 63 | colbert/search/filter_pids.cpp 64 | colbert/search/index_loader.py 65 | colbert/search/index_storage.py 66 | colbert/search/segmented_lookup.cpp 67 | colbert/search/strided_tensor.py 68 | colbert/search/strided_tensor_core.py 69 | colbert/training/__init__.py 70 | colbert/training/eager_batcher.py 71 | colbert/training/lazy_batcher.py 72 | colbert/training/rerank_batcher.py 73 | colbert/training/training.py 74 | colbert/training/utils.py 75 | colbert/utils/__init__.py 76 | colbert/utils/amp.py 77 | colbert/utils/coalesce.py 78 | colbert/utils/distributed.py 79 | colbert/utils/logging.py 80 | colbert/utils/parser.py 81 | colbert/utils/runs.py 82 | colbert/utils/utils.py 83 | colbert_ai.egg-info/PKG-INFO 84 | colbert_ai.egg-info/SOURCES.txt 85 | colbert_ai.egg-info/dependency_links.txt 86 | colbert_ai.egg-info/requires.txt 87 | colbert_ai.egg-info/top_level.txt 88 | utility/__init__.py 89 | utility/evaluate/__init__.py 90 | utility/evaluate/annotate_EM.py 91 | utility/evaluate/annotate_EM_helpers.py 92 | utility/evaluate/evaluate_lotte_rankings.py 93 | utility/evaluate/msmarco_passages.py 94 | utility/preprocess/__init__.py 95 | utility/preprocess/docs2passages.py 96 | utility/preprocess/queries_split.py 97 | utility/rankings/__init__.py 98 | utility/rankings/dev_subsample.py 99 | utility/rankings/merge.py 100 | utility/rankings/split_by_offset.py 101 | utility/rankings/split_by_queries.py 102 | utility/rankings/tune.py 103 | utility/supervision/__init__.py 104 | utility/supervision/self_training.py 105 | utility/supervision/triples.py 106 | utility/utils/__init__.py 107 | utility/utils/dpr.py 108 | utility/utils/qa_loaders.py 109 | utility/utils/save_metadata.py -------------------------------------------------------------------------------- /third_party/colbert_ai.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /third_party/colbert_ai.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | bitarray 2 | datasets 3 | flask 4 | git-python 5 | python-dotenv 6 | ninja 7 | scipy 8 | tqdm 9 | transformers 10 | ujson 11 | 12 | [faiss-cpu] 13 | faiss-cpu>=1.7.0 14 | 15 | [faiss-gpu] 16 | faiss-gpu>=1.7.0 17 | 18 | [torch] 19 | torch==1.13.1 20 | -------------------------------------------------------------------------------- /third_party/colbert_ai.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | colbert 2 | utility 3 | 
-------------------------------------------------------------------------------- /third_party/conda_env.yml: -------------------------------------------------------------------------------- 1 | name: colbert 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - python=3.8 9 | - faiss-gpu 10 | - pip=21.0 11 | - pytorch=1.13.1 12 | - torchaudio=0.13.1 13 | - torchvision=0.14.1 14 | - cudatoolkit=11.3 15 | - gcc=9.4.0 16 | - gxx=9.4.0 17 | - pip: 18 | - bitarray 19 | - datasets 20 | - gitpython 21 | - jupyter 22 | - jupyterlab 23 | - ninja 24 | - scipy 25 | - tqdm 26 | - transformers 27 | - ujson 28 | - flask 29 | - python-dotenv 30 | -------------------------------------------------------------------------------- /third_party/conda_env_cpu.yml: -------------------------------------------------------------------------------- 1 | name: colbert 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.8 8 | - faiss-cpu 9 | - pip=21.0 10 | - pytorch-cpu=1.13 11 | - pip: 12 | - bitarray 13 | - datasets 14 | - gitpython 15 | - ninja 16 | - scipy 17 | - tqdm 18 | - transformers 19 | - ujson 20 | - flask 21 | - python-dotenv 22 | -------------------------------------------------------------------------------- /third_party/server.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, render_template, request 2 | from functools import lru_cache 3 | import math 4 | import os 5 | from dotenv import load_dotenv 6 | 7 | from third_party.colbert.infra import Run, RunConfig, ColBERTConfig 8 | from third_party.colbert import Searcher 9 | 10 | load_dotenv() 11 | 12 | INDEX_NAME = os.getenv("INDEX_NAME") 13 | INDEX_ROOT = os.getenv("INDEX_ROOT") 14 | app = Flask(__name__) 15 | 16 | searcher = Searcher(index=INDEX_NAME, index_root=INDEX_ROOT) 17 | counter = {"api" : 0} 18 | 19 | @lru_cache(maxsize=1000000) 20 | def api_search_query(query, k): 21 | print(f"Query={query}") 22 | if k == None: k = 10 23 | k = min(int(k), 100) 24 | pids, ranks, scores = searcher.search(query, k=100) 25 | pids, ranks, scores = pids[:k], ranks[:k], scores[:k] 26 | passages = [searcher.collection[pid] for pid in pids] 27 | probs = [math.exp(score) for score in scores] 28 | probs = [prob / sum(probs) for prob in probs] 29 | topk = [] 30 | for pid, rank, score, prob in zip(pids, ranks, scores, probs): 31 | text = searcher.collection[pid] 32 | d = {'text': text, 'pid': pid, 'rank': rank, 'score': score, 'prob': prob} 33 | topk.append(d) 34 | topk = list(sorted(topk, key=lambda p: (-1 * p['score'], p['pid']))) 35 | return {"query" : query, "topk": topk} 36 | 37 | @app.route("/api/search", methods=["GET"]) 38 | def api_search(): 39 | if request.method == "GET": 40 | counter["api"] += 1 41 | print("API request count:", counter["api"]) 42 | return api_search_query(request.args.get("query"), request.args.get("k")) 43 | else: 44 | return ('', 405) 45 | 46 | if __name__ == "__main__": 47 | app.run("0.0.0.0", int(os.getenv("PORT"))) 48 | 49 | -------------------------------------------------------------------------------- /third_party/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as f: 4 | long_description = f.read() 5 | 6 | package_data = { 7 | "": ["*.cpp", "*.cu"], 8 | } 9 | 10 | setuptools.setup( 11 | name="colbert-ai", 12 | version="0.2.20", 13 | author="Omar Khattab", 14 | 
author_email="okhattab@stanford.edu", 15 | description="Efficient and Effective Passage Search via Contextualized Late Interaction over BERT", 16 | long_description=long_description, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/stanford-futuredata/ColBERT", 19 | packages=setuptools.find_packages(), 20 | python_requires=">=3.8", 21 | install_requires=[ 22 | "bitarray", 23 | "datasets", 24 | "flask", 25 | "git-python", 26 | "python-dotenv", 27 | "ninja", 28 | "scipy", 29 | "tqdm", 30 | "transformers", 31 | "ujson", 32 | ], 33 | extras_require={ 34 | "faiss-gpu": ["faiss-gpu>=1.7.0"], 35 | "faiss-cpu": ["faiss-cpu>=1.7.0"], 36 | "torch": ["torch==1.13.1"], 37 | }, 38 | include_package_data=True, 39 | package_data=package_data, 40 | ) 41 | -------------------------------------------------------------------------------- /third_party/utility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/__init__.py -------------------------------------------------------------------------------- /third_party/utility/evaluate/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/evaluate/__init__.py -------------------------------------------------------------------------------- /third_party/utility/evaluate/annotate_EM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import git 4 | import tqdm 5 | import ujson 6 | import random 7 | 8 | from argparse import ArgumentParser 9 | from multiprocessing import Pool 10 | 11 | from third_party.colbert.utils.utils import print_message, load_ranking, groupby_first_item 12 | from third_party.utility.utils.qa_loaders import load_qas_, load_collection_ 13 | from third_party.utility.utils.save_metadata import format_metadata, get_metadata 14 | from third_party.utility.evaluate.annotate_EM_helpers import * 15 | 16 | 17 | # TODO: Tokenize passages in advance, especially if the ranked list is long! This requires changes to the has_answer input, slightly. 18 | 19 | def main(args): 20 | qas = load_qas_(args.qas) 21 | collection = load_collection_(args.collection, retain_titles=True) 22 | rankings = load_ranking(args.ranking) 23 | parallel_pool = Pool(30) 24 | 25 | print_message('#> Tokenize the answers in the Q&As in parallel...') 26 | qas = list(parallel_pool.map(tokenize_all_answers, qas)) 27 | 28 | qid2answers = {qid: tok_answers for qid, _, tok_answers in qas} 29 | assert len(qas) == len(qid2answers), (len(qas), len(qid2answers)) 30 | 31 | print_message('#> Lookup passages from PIDs...') 32 | expanded_rankings = [(qid, pid, rank, collection[pid], qid2answers[qid]) 33 | for qid, pid, rank, *_ in rankings] 34 | 35 | print_message('#> Assign labels in parallel...') 36 | labeled_rankings = list(parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings))) 37 | 38 | # Dump output. 39 | print_message("#> Dumping output to", args.output, "...") 40 | qid2rankings = groupby_first_item(labeled_rankings) 41 | 42 | num_judged_queries, num_ranked_queries = check_sizes(qid2answers, qid2rankings) 43 | 44 | # Evaluation metrics and depths. 45 | success, counts = compute_and_write_labels(args.output, qid2answers, qid2rankings) 46 | 47 | # Dump metrics. 
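# The metrics file below stores the judged/ranked query counts, success@k and counts@k normalized by the number of judged queries (keys gain a '__WARNING' suffix when the two counts differ), and the invocation metadata returned by get_metadata(args).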
48 | with open(args.output_metrics, 'w') as f: 49 | d = {'num_ranked_queries': num_ranked_queries, 'num_judged_queries': num_judged_queries} 50 | 51 | extra = '__WARNING' if num_judged_queries != num_ranked_queries else '' 52 | d[f'success{extra}'] = {k: v / num_judged_queries for k, v in success.items()} 53 | d[f'counts{extra}'] = {k: v / num_judged_queries for k, v in counts.items()} 54 | d['arguments'] = get_metadata(args) 55 | 56 | f.write(format_metadata(d) + '\n') 57 | 58 | print('\n\n') 59 | print(args.output) 60 | print(args.output_metrics) 61 | print("#> Done\n") 62 | 63 | 64 | if __name__ == "__main__": 65 | random.seed(12345) 66 | 67 | parser = ArgumentParser(description='.') 68 | 69 | # Input / Output Arguments 70 | parser.add_argument('--qas', dest='qas', required=True, type=str) 71 | parser.add_argument('--collection', dest='collection', required=True, type=str) 72 | parser.add_argument('--ranking', dest='ranking', required=True, type=str) 73 | 74 | args = parser.parse_args() 75 | 76 | args.output = f'{args.ranking}.annotated' 77 | args.output_metrics = f'{args.ranking}.annotated.metrics' 78 | 79 | assert not os.path.exists(args.output), args.output 80 | 81 | main(args) 82 | -------------------------------------------------------------------------------- /third_party/utility/evaluate/annotate_EM_helpers.py: -------------------------------------------------------------------------------- 1 | from third_party.colbert.utils.utils import print_message 2 | from third_party.utility.utils.dpr import DPR_normalize, has_answer 3 | 4 | 5 | def tokenize_all_answers(args): 6 | qid, question, answers = args 7 | return qid, question, [DPR_normalize(ans) for ans in answers] 8 | 9 | 10 | def assign_label_to_passage(args): 11 | idx, (qid, pid, rank, passage, tokenized_answers) = args 12 | 13 | if idx % (1*1000*1000) == 0: 14 | print(idx) 15 | 16 | return qid, pid, rank, has_answer(tokenized_answers, passage) 17 | 18 | 19 | def check_sizes(qid2answers, qid2rankings): 20 | num_judged_queries = len(qid2answers) 21 | num_ranked_queries = len(qid2rankings) 22 | 23 | print_message('num_judged_queries =', num_judged_queries) 24 | print_message('num_ranked_queries =', num_ranked_queries) 25 | 26 | if num_judged_queries != num_ranked_queries: 27 | assert num_ranked_queries <= num_judged_queries 28 | 29 | print('\n\n') 30 | print_message('[WARNING] num_judged_queries != num_ranked_queries') 31 | print('\n\n') 32 | 33 | return num_judged_queries, num_ranked_queries 34 | 35 | 36 | def compute_and_write_labels(output_path, qid2answers, qid2rankings): 37 | cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all'] 38 | success = {cutoff: 0.0 for cutoff in cutoffs} 39 | counts = {cutoff: 0.0 for cutoff in cutoffs} 40 | 41 | with open(output_path, 'w') as f: 42 | for qid in qid2answers: 43 | if qid not in qid2rankings: 44 | continue 45 | 46 | prev_rank = 0 # ranks should start at one (i.e., and not zero) 47 | labels = [] 48 | 49 | for pid, rank, label in qid2rankings[qid]: 50 | assert rank == prev_rank+1, (qid, pid, (prev_rank, rank)) 51 | prev_rank = rank 52 | 53 | labels.append(label) 54 | line = '\t'.join(map(str, [qid, pid, rank, int(label)])) + '\n' 55 | f.write(line) 56 | 57 | for cutoff in cutoffs: 58 | if cutoff != 'all': 59 | success[cutoff] += sum(labels[:cutoff]) > 0 60 | counts[cutoff] += sum(labels[:cutoff]) 61 | else: 62 | success[cutoff] += sum(labels) > 0 63 | counts[cutoff] += sum(labels) 64 | 65 | return success, counts 66 | 67 | 68 | # def dump_metrics(f, nqueries, cutoffs, success, counts): 69 | 
# for cutoff in cutoffs: 70 | # success_log = "#> P@{} = {}".format(cutoff, success[cutoff] / nqueries) 71 | # counts_log = "#> D@{} = {}".format(cutoff, counts[cutoff] / nqueries) 72 | # print('\n'.join([success_log, counts_log]) + '\n') 73 | 74 | # f.write('\n'.join([success_log, counts_log]) + '\n\n') 75 | -------------------------------------------------------------------------------- /third_party/utility/evaluate/evaluate_lotte_rankings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | import jsonlines 4 | import os 5 | import sys 6 | 7 | 8 | def evaluate_dataset(query_type, dataset, split, k, data_rootdir, rankings_rootdir): 9 | data_path = os.path.join(data_rootdir, dataset, split) 10 | rankings_path = os.path.join( 11 | rankings_rootdir, split, f"{dataset}.{query_type}.ranking.tsv" 12 | ) 13 | if not os.path.exists(rankings_path): 14 | print(f"[query_type={query_type}, dataset={dataset}] Success@{k}: ???") 15 | return 16 | rankings = defaultdict(list) 17 | with open(rankings_path, "r") as f: 18 | for line in f: 19 | items = line.strip().split("\t") 20 | qid, pid, rank = items[:3] 21 | qid = int(qid) 22 | pid = int(pid) 23 | rank = int(rank) 24 | rankings[qid].append(pid) 25 | assert rank == len(rankings[qid]) 26 | 27 | success = 0 28 | qas_path = os.path.join(data_path, f"qas.{query_type}.jsonl") 29 | 30 | num_total_qids = 0 31 | with jsonlines.open(qas_path, mode="r") as f: 32 | for line in f: 33 | qid = int(line["qid"]) 34 | num_total_qids += 1 35 | if qid not in rankings: 36 | print(f"WARNING: qid {qid} not found in {rankings_path}!", file=sys.stderr) 37 | continue 38 | answer_pids = set(line["answer_pids"]) 39 | if len(set(rankings[qid][:k]).intersection(answer_pids)) > 0: 40 | success += 1 41 | print( 42 | f"[query_type={query_type}, dataset={dataset}] " 43 | f"Success@{k}: {success / num_total_qids * 100:.1f}" 44 | ) 45 | 46 | 47 | def main(args): 48 | for query_type in ["search", "forum"]: 49 | for dataset in [ 50 | "writing", 51 | "recreation", 52 | "science", 53 | "technology", 54 | "lifestyle", 55 | "pooled", 56 | ]: 57 | evaluate_dataset( 58 | query_type, 59 | dataset, 60 | args.split, 61 | args.k, 62 | args.data_dir, 63 | args.rankings_dir, 64 | ) 65 | print() 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = argparse.ArgumentParser(description="LoTTE evaluation script") 70 | parser.add_argument("--k", type=int, default=5, help="Success@k") 71 | parser.add_argument( 72 | "-s", "--split", choices=["dev", "test"], required=True, help="Split" 73 | ) 74 | parser.add_argument( 75 | "-d", "--data_dir", type=str, required=True, help="Path to LoTTE data directory" 76 | ) 77 | parser.add_argument( 78 | "-r", 79 | "--rankings_dir", 80 | type=str, 81 | required=True, 82 | help="Path to LoTTE rankings directory", 83 | ) 84 | args = parser.parse_args() 85 | main(args) 86 | -------------------------------------------------------------------------------- /third_party/utility/evaluate/msmarco_passages.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluate MS MARCO Passages ranking. 
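Reports MRR@10 and Recall@{50, 200, 1000, 5000, 10000}; with --annotate, it also re-writes the ranking with a binary relevance label appended to each line.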
3 | """ 4 | 5 | import os 6 | import math 7 | import tqdm 8 | import ujson 9 | import random 10 | 11 | from argparse import ArgumentParser 12 | from collections import defaultdict 13 | from third_party.colbert.utils.utils import print_message, file_tqdm 14 | 15 | 16 | def main(args): 17 | qid2positives = defaultdict(list) 18 | qid2ranking = defaultdict(list) 19 | qid2mrr = {} 20 | qid2recall = {depth: {} for depth in [50, 200, 1000, 5000, 10000]} 21 | 22 | with open(args.qrels) as f: 23 | print_message(f"#> Loading QRELs from {args.qrels} ..") 24 | for line in file_tqdm(f): 25 | qid, _, pid, label = map(int, line.strip().split()) 26 | assert label == 1 27 | 28 | qid2positives[qid].append(pid) 29 | 30 | with open(args.ranking) as f: 31 | print_message(f"#> Loading ranked lists from {args.ranking} ..") 32 | for line in file_tqdm(f): 33 | qid, pid, rank, *score = line.strip().split('\t') 34 | qid, pid, rank = int(qid), int(pid), int(rank) 35 | 36 | if len(score) > 0: 37 | assert len(score) == 1 38 | score = float(score[0]) 39 | else: 40 | score = None 41 | 42 | qid2ranking[qid].append((rank, pid, score)) 43 | 44 | assert set.issubset(set(qid2ranking.keys()), set(qid2positives.keys())) 45 | 46 | num_judged_queries = len(qid2positives) 47 | num_ranked_queries = len(qid2ranking) 48 | 49 | if num_judged_queries != num_ranked_queries: 50 | print() 51 | print_message("#> [WARNING] num_judged_queries != num_ranked_queries") 52 | print_message(f"#> {num_judged_queries} != {num_ranked_queries}") 53 | print() 54 | 55 | print_message(f"#> Computing MRR@10 for {num_judged_queries} queries.") 56 | 57 | for qid in tqdm.tqdm(qid2positives): 58 | ranking = qid2ranking[qid] 59 | positives = qid2positives[qid] 60 | 61 | for rank, (_, pid, _) in enumerate(ranking): 62 | rank = rank + 1 # 1-indexed 63 | 64 | if pid in positives: 65 | if rank <= 10: 66 | qid2mrr[qid] = 1.0 / rank 67 | break 68 | 69 | for rank, (_, pid, _) in enumerate(ranking): 70 | rank = rank + 1 # 1-indexed 71 | 72 | if pid in positives: 73 | for depth in qid2recall: 74 | if rank <= depth: 75 | qid2recall[depth][qid] = qid2recall[depth].get(qid, 0) + 1.0 / len(positives) 76 | 77 | assert len(qid2mrr) <= num_ranked_queries, (len(qid2mrr), num_ranked_queries) 78 | 79 | print() 80 | mrr_10_sum = sum(qid2mrr.values()) 81 | print_message(f"#> MRR@10 = {mrr_10_sum / num_judged_queries}") 82 | print_message(f"#> MRR@10 (only for ranked queries) = {mrr_10_sum / num_ranked_queries}") 83 | print() 84 | 85 | for depth in qid2recall: 86 | assert len(qid2recall[depth]) <= num_ranked_queries, (len(qid2recall[depth]), num_ranked_queries) 87 | 88 | print() 89 | metric_sum = sum(qid2recall[depth].values()) 90 | print_message(f"#> Recall@{depth} = {metric_sum / num_judged_queries}") 91 | print_message(f"#> Recall@{depth} (only for ranked queries) = {metric_sum / num_ranked_queries}") 92 | print() 93 | 94 | if args.annotate: 95 | print_message(f"#> Writing annotations to {args.output} ..") 96 | 97 | with open(args.output, 'w') as f: 98 | for qid in tqdm.tqdm(qid2positives): 99 | ranking = qid2ranking[qid] 100 | positives = qid2positives[qid] 101 | 102 | for rank, (_, pid, score) in enumerate(ranking): 103 | rank = rank + 1 # 1-indexed 104 | label = int(pid in positives) 105 | 106 | line = [qid, pid, rank, score, label] 107 | line = [x for x in line if x is not None] 108 | line = '\t'.join(map(str, line)) + '\n' 109 | f.write(line) 110 | 111 | 112 | if __name__ == "__main__": 113 | parser = ArgumentParser(description="msmarco_passages.") 114 | 115 | # Input 
Arguments. 116 | parser.add_argument('--qrels', dest='qrels', required=True, type=str) 117 | parser.add_argument('--ranking', dest='ranking', required=True, type=str) 118 | parser.add_argument('--annotate', dest='annotate', default=False, action='store_true') 119 | 120 | args = parser.parse_args() 121 | 122 | if args.annotate: 123 | args.output = f'{args.ranking}.annotated' 124 | assert not os.path.exists(args.output), args.output 125 | 126 | main(args) 127 | -------------------------------------------------------------------------------- /third_party/utility/preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/preprocess/__init__.py -------------------------------------------------------------------------------- /third_party/utility/preprocess/queries_split.py: -------------------------------------------------------------------------------- 1 | """ 2 | Divide a query set into two. 3 | """ 4 | 5 | import os 6 | import math 7 | import ujson 8 | import random 9 | 10 | from argparse import ArgumentParser 11 | from collections import OrderedDict 12 | from third_party.colbert.utils.utils import print_message 13 | 14 | 15 | def main(args): 16 | random.seed(12345) 17 | 18 | """ 19 | Load the queries 20 | """ 21 | Queries = OrderedDict() 22 | 23 | print_message(f"#> Loading queries from {args.input}..") 24 | with open(args.input) as f: 25 | for line in f: 26 | qid, query = line.strip().split('\t') 27 | 28 | assert qid not in Queries 29 | Queries[qid] = query 30 | 31 | """ 32 | Apply the splitting 33 | """ 34 | size_a = len(Queries) - args.holdout 35 | size_b = args.holdout 36 | size_a, size_b = max(size_a, size_b), min(size_a, size_b) 37 | 38 | assert size_a > 0 and size_b > 0, (len(Queries), size_a, size_b) 39 | 40 | print_message(f"#> Deterministically splitting the queries into ({size_a}, {size_b})-sized splits.") 41 | 42 | keys = list(Queries.keys()) 43 | sample_b_indices = sorted(list(random.sample(range(len(keys)), size_b))) 44 | sample_a_indices = sorted(list(set.difference(set(list(range(len(keys)))), set(sample_b_indices)))) 45 | 46 | assert len(sample_a_indices) == size_a 47 | assert len(sample_b_indices) == size_b 48 | 49 | sample_a = [keys[idx] for idx in sample_a_indices] 50 | sample_b = [keys[idx] for idx in sample_b_indices] 51 | 52 | """ 53 | Write the output 54 | """ 55 | 56 | output_path_a = f'{args.input}.a' 57 | output_path_b = f'{args.input}.b' 58 | 59 | assert not os.path.exists(output_path_a), output_path_a 60 | assert not os.path.exists(output_path_b), output_path_b 61 | 62 | print_message(f"#> Writing the splits out to {output_path_a} and {output_path_b} ...") 63 | 64 | for output_path, sample in [(output_path_a, sample_a), (output_path_b, sample_b)]: 65 | with open(output_path, 'w') as f: 66 | for qid in sample: 67 | query = Queries[qid] 68 | line = '\t'.join([qid, query]) + '\n' 69 | f.write(line) 70 | 71 | 72 | if __name__ == "__main__": 73 | parser = ArgumentParser(description="queries_split.") 74 | 75 | # Input Arguments. 
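# --holdout gives the number of queries to hold out; the larger resulting split is written to <input>.a and the smaller to <input>.b.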
76 | parser.add_argument('--input', dest='input', required=True) 77 | parser.add_argument('--holdout', dest='holdout', required=True, type=int) 78 | 79 | args = parser.parse_args() 80 | 81 | main(args) 82 | -------------------------------------------------------------------------------- /third_party/utility/rankings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/rankings/__init__.py -------------------------------------------------------------------------------- /third_party/utility/rankings/dev_subsample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | import random 4 | 5 | from argparse import ArgumentParser 6 | 7 | from third_party.colbert.utils.utils import print_message, create_directory, load_ranking, groupby_first_item 8 | from third_party.utility.utils.qa_loaders import load_qas_ 9 | 10 | 11 | def main(args): 12 | print_message("#> Loading all..") 13 | qas = load_qas_(args.qas) 14 | rankings = load_ranking(args.ranking) 15 | qid2rankings = groupby_first_item(rankings) 16 | 17 | print_message("#> Subsampling all..") 18 | qas_sample = random.sample(qas, args.sample) 19 | 20 | with open(args.output, 'w') as f: 21 | for qid, *_ in qas_sample: 22 | for items in qid2rankings[qid]: 23 | items = [qid] + items 24 | line = '\t'.join(map(str, items)) + '\n' 25 | f.write(line) 26 | 27 | print('\n\n') 28 | print(args.output) 29 | print("#> Done.") 30 | 31 | 32 | if __name__ == "__main__": 33 | random.seed(12345) 34 | 35 | parser = ArgumentParser(description='Subsample the dev set.') 36 | parser.add_argument('--qas', dest='qas', required=True, type=str) 37 | parser.add_argument('--ranking', dest='ranking', required=True) 38 | parser.add_argument('--output', dest='output', required=True) 39 | 40 | parser.add_argument('--sample', dest='sample', default=1500, type=int) 41 | 42 | args = parser.parse_args() 43 | 44 | assert not os.path.exists(args.output), args.output 45 | create_directory(os.path.dirname(args.output)) 46 | 47 | main(args) 48 | -------------------------------------------------------------------------------- /third_party/utility/rankings/merge.py: -------------------------------------------------------------------------------- 1 | """ 2 | Merge two or more ranking files, by score. 
3 | """ 4 | 5 | import os 6 | import tqdm 7 | 8 | from argparse import ArgumentParser 9 | from collections import defaultdict 10 | from third_party.colbert.utils.utils import print_message, file_tqdm 11 | 12 | 13 | def main(args): 14 | Rankings = defaultdict(list) 15 | 16 | for path in args.input: 17 | print_message(f"#> Loading the rankings in {path} ..") 18 | 19 | with open(path) as f: 20 | for line in file_tqdm(f): 21 | qid, pid, rank, score = line.strip().split('\t') 22 | qid, pid, rank = map(int, [qid, pid, rank]) 23 | score = float(score) 24 | 25 | Rankings[qid].append((score, rank, pid)) 26 | 27 | with open(args.output, 'w') as f: 28 | print_message(f"#> Writing the output rankings to {args.output} ..") 29 | 30 | for qid in tqdm.tqdm(Rankings): 31 | ranking = sorted(Rankings[qid], reverse=True) 32 | 33 | for rank, (score, original_rank, pid) in enumerate(ranking): 34 | rank = rank + 1 # 1-indexed 35 | 36 | if (args.depth > 0) and (rank > args.depth): 37 | break 38 | 39 | line = [qid, pid, rank, score] 40 | line = '\t'.join(map(str, line)) + '\n' 41 | f.write(line) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = ArgumentParser(description="merge_rankings.") 46 | 47 | # Input Arguments. 48 | parser.add_argument('--input', dest='input', required=True, nargs='+') 49 | parser.add_argument('--output', dest='output', required=True, type=str) 50 | 51 | parser.add_argument('--depth', dest='depth', required=True, type=int) 52 | 53 | args = parser.parse_args() 54 | 55 | assert not os.path.exists(args.output), args.output 56 | 57 | main(args) 58 | -------------------------------------------------------------------------------- /third_party/utility/rankings/split_by_offset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split the ranked lists after retrieval with a merged query set. 3 | """ 4 | 5 | import os 6 | import random 7 | 8 | from argparse import ArgumentParser 9 | 10 | 11 | def main(args): 12 | output_paths = ['{}.{}'.format(args.ranking, split) for split in args.names] 13 | assert all(not os.path.exists(path) for path in output_paths), output_paths 14 | 15 | output_files = [open(path, 'w') for path in output_paths] 16 | 17 | with open(args.ranking) as f: 18 | for line in f: 19 | qid, pid, rank, *other = line.strip().split('\t') 20 | qid = int(qid) 21 | split_output_path = output_files[qid // args.gap - 1] 22 | qid = qid % args.gap 23 | 24 | split_output_path.write('\t'.join([str(x) for x in [qid, pid, rank, *other]]) + '\n') 25 | 26 | print(f.name) 27 | 28 | _ = [f.close() for f in output_files] 29 | 30 | print("#> Done!") 31 | 32 | 33 | if __name__ == "__main__": 34 | random.seed(12345) 35 | 36 | parser = ArgumentParser(description='Subsample the dev set.') 37 | parser.add_argument('--ranking', dest='ranking', required=True) 38 | 39 | parser.add_argument('--names', dest='names', required=False, default=['train', 'dev', 'test'], type=str, nargs='+') # order matters! 
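# Assumes the merged query set offset the qids of the i-th split (1-indexed, in --names order) by i * gap; the original qid is recovered as qid % gap.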
40 | parser.add_argument('--gap', dest='gap', required=False, default=1_000_000_000, type=int) # larger than any individual query set 41 | 42 | args = parser.parse_args() 43 | 44 | main(args) 45 | -------------------------------------------------------------------------------- /third_party/utility/rankings/split_by_queries.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import tqdm 4 | import ujson 5 | import random 6 | 7 | from argparse import ArgumentParser 8 | from collections import OrderedDict 9 | from third_party.colbert.utils.utils import print_message, file_tqdm 10 | 11 | 12 | def main(args): 13 | qid_to_file_idx = {} 14 | 15 | for qrels_idx, qrels in enumerate(args.all_queries): 16 | with open(qrels) as f: 17 | for line in f: 18 | qid, *_ = line.strip().split('\t') 19 | qid = int(qid) 20 | 21 | assert qid_to_file_idx.get(qid, qrels_idx) == qrels_idx, (qid, qrels_idx) 22 | qid_to_file_idx[qid] = qrels_idx 23 | 24 | all_outputs_paths = [f'{args.ranking}.{idx}' for idx in range(len(args.all_queries))] 25 | 26 | assert all(not os.path.exists(path) for path in all_outputs_paths) 27 | 28 | all_outputs = [open(path, 'w') for path in all_outputs_paths] 29 | 30 | with open(args.ranking) as f: 31 | print_message(f"#> Loading ranked lists from {f.name} ..") 32 | 33 | last_file_idx = -1 34 | 35 | for line in file_tqdm(f): 36 | qid, *_ = line.strip().split('\t') 37 | 38 | file_idx = qid_to_file_idx[int(qid)] 39 | 40 | if file_idx != last_file_idx: 41 | print_message(f"#> Switched to file #{file_idx} at {all_outputs[file_idx].name}") 42 | 43 | last_file_idx = file_idx 44 | 45 | all_outputs[file_idx].write(line) 46 | 47 | print() 48 | 49 | for f in all_outputs: 50 | print(f.name) 51 | f.close() 52 | 53 | print("#> Done!") 54 | 55 | 56 | if __name__ == "__main__": 57 | random.seed(12345) 58 | 59 | parser = ArgumentParser(description='.') 60 | 61 | # Input Arguments 62 | parser.add_argument('--ranking', dest='ranking', required=True, type=str) 63 | parser.add_argument('--all-queries', dest='all_queries', required=True, type=str, nargs='+') 64 | 65 | args = parser.parse_args() 66 | 67 | main(args) 68 | -------------------------------------------------------------------------------- /third_party/utility/rankings/tune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | import random 4 | 5 | from argparse import ArgumentParser 6 | from third_party.colbert.utils.utils import print_message, create_directory 7 | from third_party.utility.utils.save_metadata import save_metadata 8 | 9 | 10 | def main(args): 11 | AllMetrics = {} 12 | Scores = {} 13 | 14 | for path in args.paths: 15 | with open(path) as f: 16 | metric = ujson.load(f) 17 | AllMetrics[path] = metric 18 | 19 | for k in args.metric: 20 | metric = metric[k] 21 | 22 | assert type(metric) is float 23 | Scores[path] = metric 24 | 25 | MaxKey = max(Scores, key=Scores.get) 26 | 27 | MaxCKPT = int(MaxKey.split('/')[-2].split('.')[-1]) 28 | MaxARGS = os.path.join(os.path.dirname(MaxKey), 'logs', 'args.json') 29 | 30 | with open(MaxARGS) as f: 31 | logs = ujson.load(f) 32 | MaxCHECKPOINT = logs['checkpoint'] 33 | 34 | assert MaxCHECKPOINT.endswith(f'colbert-{MaxCKPT}.dnn'), (MaxCHECKPOINT, MaxCKPT) 35 | 36 | with open(args.output, 'w') as f: 37 | f.write(MaxCHECKPOINT) 38 | 39 | args.Scores = Scores 40 | args.AllMetrics = AllMetrics 41 | 42 | save_metadata(f'{args.output}.meta', args) 43 | 44 | print('\n\n', args, 
'\n\n') 45 | print(args.output) 46 | print_message("#> Done.") 47 | 48 | 49 | if __name__ == "__main__": 50 | random.seed(12345) 51 | 52 | parser = ArgumentParser(description='.') 53 | 54 | # Input / Output Arguments 55 | parser.add_argument('--metric', dest='metric', required=True, type=str) # e.g., success.20 56 | parser.add_argument('--paths', dest='paths', required=True, type=str, nargs='+') 57 | parser.add_argument('--output', dest='output', required=True, type=str) 58 | 59 | args = parser.parse_args() 60 | 61 | args.metric = args.metric.split('.') 62 | 63 | assert not os.path.exists(args.output), args.output 64 | create_directory(os.path.dirname(args.output)) 65 | 66 | main(args) 67 | -------------------------------------------------------------------------------- /third_party/utility/supervision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/supervision/__init__.py -------------------------------------------------------------------------------- /third_party/utility/supervision/self_training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import git 4 | import tqdm 5 | import ujson 6 | import random 7 | 8 | from argparse import ArgumentParser 9 | from third_party.colbert.utils.utils import print_message, load_ranking, groupby_first_item 10 | 11 | 12 | MAX_NUM_TRIPLES = 40_000_000 13 | 14 | 15 | def sample_negatives(negatives, num_sampled, biased=False): 16 | num_sampled = min(len(negatives), num_sampled) 17 | 18 | if biased: 19 | assert num_sampled % 2 == 0 20 | num_sampled_top100 = num_sampled // 2 21 | num_sampled_rest = num_sampled - num_sampled_top100 22 | 23 | return random.sample(negatives[:100], num_sampled_top100) + random.sample(negatives[100:], num_sampled_rest) 24 | 25 | return random.sample(negatives, num_sampled) 26 | 27 | 28 | def sample_for_query(qid, ranking, npositives, depth_positive, depth_negative, cutoff_negative): 29 | """ 30 | Requires that the ranks are sorted per qid. 
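Passages ranked <= depth_positive are taken as positives; those ranked in (cutoff_negative, depth_negative] are taken as negatives.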
31 | """ 32 | assert npositives <= depth_positive < cutoff_negative < depth_negative 33 | 34 | positives, negatives, triples = [], [], [] 35 | 36 | for pid, rank, *_ in ranking: 37 | assert rank >= 1, f"ranks should start at 1 \t\t got rank = {rank}" 38 | 39 | if rank > depth_negative: 40 | break 41 | 42 | if rank <= depth_positive: 43 | positives.append(pid) 44 | elif rank > cutoff_negative: 45 | negatives.append(pid) 46 | 47 | num_sampled = 100 48 | 49 | for neg in sample_negatives(negatives, num_sampled): 50 | positives_ = random.sample(positives, npositives) 51 | positives_ = positives_[0] if npositives == 1 else positives_ 52 | triples.append((qid, positives_, neg)) 53 | 54 | return triples 55 | 56 | 57 | def main(args): 58 | rankings = load_ranking(args.ranking, types=[int, int, int, float, int]) 59 | 60 | print_message("#> Group by QID") 61 | qid2rankings = groupby_first_item(tqdm.tqdm(rankings)) 62 | 63 | Triples = [] 64 | NonEmptyQIDs = 0 65 | 66 | for processing_idx, qid in enumerate(qid2rankings): 67 | l = sample_for_query(qid, qid2rankings[qid], args.positives, args.depth_positive, args.depth_negative, args.cutoff_negative) 68 | NonEmptyQIDs += (len(l) > 0) 69 | Triples.extend(l) 70 | 71 | if processing_idx % (10_000) == 0: 72 | print_message(f"#> Done with {processing_idx+1} questions!\t\t " 73 | f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unique QIDs.") 74 | 75 | print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..") 76 | print_message(f"#> len(Triples) = {len(Triples)}") 77 | 78 | if len(Triples) > MAX_NUM_TRIPLES: 79 | Triples = random.sample(Triples, MAX_NUM_TRIPLES) 80 | 81 | ### Prepare the triples ### 82 | print_message("#> Shuffling the triples...") 83 | random.shuffle(Triples) 84 | 85 | print_message("#> Writing {}M examples to file.".format(len(Triples) / 1000.0 / 1000.0)) 86 | 87 | with open(args.output, 'w') as f: 88 | for example in Triples: 89 | ujson.dump(example, f) 90 | f.write('\n') 91 | 92 | with open(f'{args.output}.meta', 'w') as f: 93 | args.cmd = ' '.join(sys.argv) 94 | args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha 95 | ujson.dump(args.__dict__, f, indent=4) 96 | f.write('\n') 97 | 98 | print('\n\n', args, '\n\n') 99 | print(args.output) 100 | print_message("#> Done.") 101 | 102 | 103 | if __name__ == "__main__": 104 | random.seed(12345) 105 | 106 | parser = ArgumentParser(description='Create training triples from ranked list.') 107 | 108 | # Input / Output Arguments 109 | parser.add_argument('--ranking', dest='ranking', required=True, type=str) 110 | parser.add_argument('--output', dest='output', required=True, type=str) 111 | 112 | # Weak Supervision Arguments. 
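# --depth+ / --depth- / --cutoff- map to depth_positive, depth_negative, and cutoff_negative; sample_for_query asserts --positives <= --depth+ < --cutoff- < --depth-.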
113 | parser.add_argument('--positives', dest='positives', required=True, type=int) 114 | parser.add_argument('--depth+', dest='depth_positive', required=True, type=int) 115 | 116 | parser.add_argument('--depth-', dest='depth_negative', required=True, type=int) 117 | parser.add_argument('--cutoff-', dest='cutoff_negative', required=True, type=int) 118 | 119 | args = parser.parse_args() 120 | 121 | assert not os.path.exists(args.output), args.output 122 | 123 | main(args) 124 | -------------------------------------------------------------------------------- /third_party/utility/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/utils/__init__.py -------------------------------------------------------------------------------- /third_party/utility/utils/qa_loaders.py: -------------------------------------------------------------------------------- 1 | import os 2 | import ujson 3 | 4 | from collections import defaultdict 5 | from third_party.colbert.utils.utils import print_message, file_tqdm 6 | 7 | 8 | def load_collection_(path, retain_titles): 9 | with open(path) as f: 10 | collection = [] 11 | 12 | for line in file_tqdm(f): 13 | _, passage, title = line.strip().split('\t') 14 | 15 | if retain_titles: 16 | passage = title + ' | ' + passage 17 | 18 | collection.append(passage) 19 | 20 | return collection 21 | 22 | 23 | def load_qas_(path): 24 | print_message("#> Loading the reference QAs from", path) 25 | 26 | triples = [] 27 | 28 | with open(path) as f: 29 | for line in f: 30 | qa = ujson.loads(line) 31 | triples.append((qa['qid'], qa['question'], qa['answers'])) 32 | 33 | return triples 34 | -------------------------------------------------------------------------------- /third_party/utility/utils/save_metadata.py: -------------------------------------------------------------------------------- 1 | from third_party.colbert.utils.utils import dotdict 2 | import os 3 | import sys 4 | # import git 5 | import time 6 | import copy 7 | import ujson 8 | import socket 9 | 10 | 11 | def get_metadata_only(): 12 | args = dotdict() 13 | 14 | args.hostname = socket.gethostname() 15 | # try: 16 | # args.git_branch = git.Repo(search_parent_directories=True).active_branch.name 17 | # args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha 18 | # args.git_commit_datetime = str(git.Repo(search_parent_directories=True).head.object.committed_datetime) 19 | # except git.exc.InvalidGitRepositoryError as e: 20 | # pass 21 | args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)') 22 | args.cmd = ' '.join(sys.argv) 23 | 24 | return args 25 | 26 | 27 | def get_metadata(args): 28 | args = copy.deepcopy(args) 29 | 30 | args.hostname = socket.gethostname() 31 | # args.git_branch = git.Repo(search_parent_directories=True).active_branch.name # git fields disabled like in get_metadata_only: 'import git' is commented out above, so these lines would raise NameError 32 | # args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha 33 | # args.git_commit_datetime = str(git.Repo(search_parent_directories=True).head.object.committed_datetime) 34 | args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)') 35 | args.cmd = ' '.join(sys.argv) 36 | 37 | try: 38 | args.input_arguments = copy.deepcopy(args.input_arguments.__dict__) 39 | except: 40 | args.input_arguments = None 41 | 42 | return dict(args.__dict__) 43 | 44 | # TODO: No reason for deepcopy. But: (a) Call provenance() on objects that can, (b) Only save simple, small objects. 
No massive lists or models or weird stuff! 45 | # With that, I think we don't even need (necessarily) to restrict things to input_arguments. 46 | 47 | def format_metadata(metadata): 48 | assert type(metadata) == dict 49 | 50 | return ujson.dumps(metadata, indent=4) 51 | 52 | 53 | def save_metadata(path, args): 54 | assert not os.path.exists(path), path 55 | 56 | with open(path, 'w') as output_metadata: 57 | data = get_metadata(args) 58 | output_metadata.write(format_metadata(data) + '\n') 59 | 60 | return data 61 | --------------------------------------------------------------------------------