├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── model.png
├── inference.py
├── requirements.txt
├── scripts
│   ├── inference_m2kr_large.sh
│   └── inference_m2kr_small.sh
├── src
│   ├── models
│   │   ├── __init__.py
│   │   ├── contrastive_loss
│   │   │   ├── __init__.py
│   │   │   ├── disco_clip.py
│   │   │   ├── gather_utils.py
│   │   │   └── modeling_contrastive_loss.py
│   │   ├── ret
│   │   │   ├── __init__.py
│   │   │   ├── configuration_ret.py
│   │   │   └── modeling_ret.py
│   │   └── retriever
│   │       ├── __init__.py
│   │       ├── configuration_retriever.py
│   │       └── modeling_retriever.py
│   └── utils.py
└── third_party
    ├── LICENSE
    ├── LoTTE.md
    ├── MANIFEST.in
    ├── README.md
    ├── colbert
    │   ├── __init__.py
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── collection.py
    │   │   ├── dataset.py
    │   │   ├── examples.py
    │   │   ├── queries.py
    │   │   └── ranking.py
    │   ├── distillation
    │   │   ├── ranking_scorer.py
    │   │   └── scorer.py
    │   ├── evaluation
    │   │   ├── __init__.py
    │   │   ├── load_model.py
    │   │   ├── loaders.py
    │   │   └── metrics.py
    │   ├── index.py
    │   ├── index_updater.py
    │   ├── indexer.py
    │   ├── indexing
    │   │   ├── __init__.py
    │   │   ├── codecs
    │   │   │   ├── __init__.py
    │   │   │   ├── decompress_residuals.cpp
    │   │   │   ├── decompress_residuals.cu
    │   │   │   ├── packbits.cpp
    │   │   │   ├── packbits.cu
    │   │   │   ├── residual.py
    │   │   │   ├── residual_embeddings.py
    │   │   │   └── residual_embeddings_strided.py
    │   │   ├── collection_encoder.py
    │   │   ├── collection_indexer.py
    │   │   ├── index_manager.py
    │   │   ├── index_saver.py
    │   │   ├── loaders.py
    │   │   └── utils.py
    │   ├── infra
    │   │   ├── __init__.py
    │   │   ├── config
    │   │   │   ├── __init__.py
    │   │   │   ├── base_config.py
    │   │   │   ├── config.py
    │   │   │   ├── core_config.py
    │   │   │   └── settings.py
    │   │   ├── launcher.py
    │   │   ├── provenance.py
    │   │   ├── run.py
    │   │   └── utilities
    │   │       ├── annotate_em.py
    │   │       ├── create_triples.py
    │   │       └── minicorpus.py
    │   ├── modeling
    │   │   ├── __init__.py
    │   │   ├── base_colbert.py
    │   │   ├── checkpoint.py
    │   │   ├── colbert.py
    │   │   ├── hf_colbert.py
    │   │   ├── reranker
    │   │   │   ├── __init__.py
    │   │   │   ├── electra.py
    │   │   │   └── tokenizer.py
    │   │   ├── segmented_maxsim.cpp
    │   │   └── tokenization
    │   │       ├── __init__.py
    │   │       ├── doc_tokenization.py
    │   │       ├── query_tokenization.py
    │   │       └── utils.py
    │   ├── parameters.py
    │   ├── ranking
    │   │   └── __init__.py
    │   ├── search
    │   │   ├── __init__.py
    │   │   ├── candidate_generation.py
    │   │   ├── decompress_residuals.cpp
    │   │   ├── filter_pids.cpp
    │   │   ├── index_loader.py
    │   │   ├── index_storage.py
    │   │   ├── segmented_lookup.cpp
    │   │   ├── strided_tensor.py
    │   │   └── strided_tensor_core.py
    │   ├── searcher.py
    │   ├── tests
    │   │   ├── e2e_test.py
    │   │   ├── index_coalesce_test.py
    │   │   ├── index_updater_test.py
    │   │   └── tokenizers_test.py
    │   ├── trainer.py
    │   ├── training
    │   │   ├── __init__.py
    │   │   ├── eager_batcher.py
    │   │   ├── lazy_batcher.py
    │   │   ├── rerank_batcher.py
    │   │   ├── training.py
    │   │   └── utils.py
    │   ├── utilities
    │   │   ├── annotate_em.py
    │   │   ├── create_triples.py
    │   │   └── minicorpus.py
    │   └── utils
    │       ├── __init__.py
    │       ├── amp.py
    │       ├── coalesce.py
    │       ├── distributed.py
    │       ├── logging.py
    │       ├── parser.py
    │       ├── runs.py
    │       └── utils.py
    ├── colbert_ai.egg-info
    │   ├── PKG-INFO
    │   ├── SOURCES.txt
    │   ├── dependency_links.txt
    │   ├── requires.txt
    │   └── top_level.txt
    ├── conda_env.yml
    ├── conda_env_cpu.yml
    ├── server.py
    ├── setup.py
    └── utility
        ├── __init__.py
        ├── evaluate
        │   ├── __init__.py
        │   ├── annotate_EM.py
        │   ├── annotate_EM_helpers.py
        │   ├── evaluate_lotte_rankings.py
        │   └── msmarco_passages.py
        ├── preprocess
        │   ├── __init__.py
        │   ├── docs2passages.py
        │   └── queries_split.py
        ├── rankings
        │   ├── __init__.py
        │   ├── dev_subsample.py
        │   ├── merge.py
        │   ├── split_by_offset.py
        │   ├── split_by_queries.py
        │   └── tune.py
        ├── supervision
        │   ├── __init__.py
        │   ├── self_training.py
        │   └── triples.py
        └── utils
            ├── __init__.py
            ├── dpr.py
            ├── qa_loaders.py
            └── save_metadata.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .scratch/
2 | .vscode/
3 | __pycache__/
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/assets/model.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.5.2
2 | certifi==2025.1.31
3 | charset-normalizer==3.4.1
4 | filelock==3.13.1
5 | fsspec==2024.6.1
6 | huggingface-hub==0.29.3
7 | idna==3.10
8 | Jinja2==3.1.4
9 | MarkupSafe==2.1.5
10 | mpmath==1.3.0
11 | networkx==3.3
12 | numpy==1.26.4
13 | packaging==24.2
14 | pandas==2.1.4
15 | pillow==11.0.0
16 | python-dateutil==2.9.0.post0
17 | pytz==2025.1
18 | PyYAML==6.0.2
19 | regex==2024.11.6
20 | requests==2.32.3
21 | safetensors==0.5.3
22 | six==1.17.0
23 | sympy==1.13.1
24 | tokenizers==0.20.3
25 | tqdm==4.67.1
26 | transformers==4.45.0
27 | typing_extensions==4.12.2
28 | tzdata==2025.1
29 | ujson==5.10.0
30 | urllib3==2.3.0
--------------------------------------------------------------------------------
/scripts/inference_m2kr_large.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=inference_m2kr_large
3 | #SBATCH --output=
4 | #SBATCH --error=
5 | #SBATCH --open-mode=truncate
6 | #SBATCH --partition=
7 | #SBATCH --account=
8 | #SBATCH --nodes=1
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=4
11 | #SBATCH --mem=128G
12 | #SBATCH --cpus-per-task=8
13 | #SBATCH --array=0-3
14 | #SBATCH --time=00:30:00
15 |
16 | # tested on 4 NVIDIA A100-SXM-64GB
17 |
18 | conda activate ret
19 | cd ~/ReT
20 | export PYTHONPATH=.
21 |
22 | export TRANSFORMERS_VERBOSITY=info
23 | export TOKENIZERS_PARALLELISM=false
24 | export COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True
25 |
26 | DATASET_NAMES=(
27 | "okvqa"
28 | "infoseek"
29 | "evqa"
30 | "wit"
31 | "llava"
32 | "kvqa"
33 | "oven"
34 | "iglue"
35 | )
36 |
37 | JSONL_ROOT_PATH=
38 | DATASET_PATHS=(
39 | "${JSONL_ROOT_PATH}/okvqa_test.jsonl"
40 | "${JSONL_ROOT_PATH}/infoseek_test.jsonl"
41 | "${JSONL_ROOT_PATH}/evqa_test_m2kr.jsonl"
42 | "${JSONL_ROOT_PATH}/wit_test.jsonl"
43 | "${JSONL_ROOT_PATH}/llava_test.jsonl"
44 | "${JSONL_ROOT_PATH}/kvqa_test.jsonl"
45 | "${JSONL_ROOT_PATH}/oven_test.jsonl"
46 | "${JSONL_ROOT_PATH}/iglue_test.jsonl"
47 | )
48 |
49 | DATASET_PASSAGES_PATHS=(
50 | "${JSONL_ROOT_PATH}/okvqa_passages_test.jsonl"
51 | "${JSONL_ROOT_PATH}/infoseek_passages_test.jsonl"
52 | "${JSONL_ROOT_PATH}/evqa_passages_test.jsonl"
53 | "${JSONL_ROOT_PATH}/wit_passages_test.jsonl"
54 | "${JSONL_ROOT_PATH}/llava_passages_test.jsonl"
55 | "${JSONL_ROOT_PATH}/kvqa_passages_test.jsonl"
56 | "${JSONL_ROOT_PATH}/oven_passages_test.jsonl"
57 | "${JSONL_ROOT_PATH}/iglue_passages_test.jsonl"
58 | )
59 |
60 | IMAGE_ROOT_PATH=
61 |
62 | model_name="ReT-CLIP-ViT-L-14"
63 | checkpoint_path="aimagelab/${model_name}"
64 | root_path=
65 | dataset_path="${DATASET_PATHS[$SLURM_ARRAY_TASK_ID]}"
66 | dataset_passages_path="${DATASET_PASSAGES_PATHS[$SLURM_ARRAY_TASK_ID]}"
67 | experiment_name="${model_name}"
68 | index_name="${DATASET_NAMES[$SLURM_ARRAY_TASK_ID]}"
69 |
70 | echo "DATASET PATH: ${dataset_path}"
71 | echo "DATASET PASSAGES PATH: ${dataset_passages_path}"
72 |
73 | srun -c $SLURM_CPUS_PER_TASK --mem $SLURM_MEM_PER_NODE \
74 | python inference.py \
75 | --action index \
76 | --dataset_path $dataset_passages_path \
77 | --image_root_path $IMAGE_ROOT_PATH \
78 | --checkpoint_path $checkpoint_path \
79 | --root_path $root_path \
80 | --experiment_name $experiment_name \
81 | --index_name $index_name \
82 | --index_bsize 128
83 |
84 | srun -c $SLURM_CPUS_PER_TASK --mem $SLURM_MEM_PER_NODE \
85 | python inference.py \
86 | --action search \
87 | --dataset_path $dataset_path \
88 | --dataset_passages_path $dataset_passages_path \
89 | --image_root_path $IMAGE_ROOT_PATH \
90 | --checkpoint_path $checkpoint_path \
91 | --root_path $root_path \
92 | --experiment_name $experiment_name \
93 | --index_name $index_name \
94 | --index_bsize 128 \
95 | --num_docs_to_retrieve 500
96 |
--------------------------------------------------------------------------------
/scripts/inference_m2kr_small.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=inference_m2kr_small
3 | #SBATCH --output=
4 | #SBATCH --error=
5 | #SBATCH --open-mode=truncate
6 | #SBATCH --partition=
7 | #SBATCH --account=
8 | #SBATCH --nodes=1
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=1
11 | #SBATCH --mem=128G
12 | #SBATCH --cpus-per-task=8
13 | #SBATCH --array=4-7
14 | #SBATCH --time=00:30:00
15 |
16 | # tested on 1 NVIDIA A100-SXM-64GB
17 |
18 | conda activate ret
19 | cd ~/ReT
20 | export PYTHONPATH=.
21 |
22 | export TRANSFORMERS_VERBOSITY=info
23 | export TOKENIZERS_PARALLELISM=false
24 | export COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True
25 |
26 | DATASET_NAMES=(
27 | "okvqa"
28 | "infoseek"
29 | "evqa"
30 | "wit"
31 | "llava"
32 | "kvqa"
33 | "oven"
34 | "iglue"
35 | )
36 |
37 | JSONL_ROOT_PATH=
38 | DATASET_PATHS=(
39 | "${JSONL_ROOT_PATH}/okvqa_test.jsonl"
40 | "${JSONL_ROOT_PATH}/infoseek_test.jsonl"
41 | "${JSONL_ROOT_PATH}/evqa_test_m2kr.jsonl"
42 | "${JSONL_ROOT_PATH}/wit_test.jsonl"
43 | "${JSONL_ROOT_PATH}/llava_test.jsonl"
44 | "${JSONL_ROOT_PATH}/kvqa_test.jsonl"
45 | "${JSONL_ROOT_PATH}/oven_test.jsonl"
46 | "${JSONL_ROOT_PATH}/iglue_test.jsonl"
47 | )
48 |
49 | DATASET_PASSAGES_PATHS=(
50 | "${JSONL_ROOT_PATH}/okvqa_passages_test.jsonl"
51 | "${JSONL_ROOT_PATH}/infoseek_passages_test.jsonl"
52 | "${JSONL_ROOT_PATH}/evqa_passages_test.jsonl"
53 | "${JSONL_ROOT_PATH}/wit_passages_test.jsonl"
54 | "${JSONL_ROOT_PATH}/llava_passages_test.jsonl"
55 | "${JSONL_ROOT_PATH}/kvqa_passages_test.jsonl"
56 | "${JSONL_ROOT_PATH}/oven_passages_test.jsonl"
57 | "${JSONL_ROOT_PATH}/iglue_passages_test.jsonl"
58 | )
59 |
60 | IMAGE_ROOT_PATH=
61 |
62 | model_name="ReT-CLIP-ViT-L-14"
63 | checkpoint_path="aimagelab/${model_name}"
64 | root_path=
65 | dataset_path="${DATASET_PATHS[$SLURM_ARRAY_TASK_ID]}"
66 | dataset_passages_path="${DATASET_PASSAGES_PATHS[$SLURM_ARRAY_TASK_ID]}"
67 | experiment_name="${model_name}"
68 | index_name="${DATASET_NAMES[$SLURM_ARRAY_TASK_ID]}"
69 |
70 | echo "DATASET PATH: ${dataset_path}"
71 | echo "DATASET PASSAGES PATH: ${dataset_passages_path}"
72 |
73 | srun -c $SLURM_CPUS_PER_TASK --mem $SLURM_MEM_PER_NODE \
74 | python inference.py \
75 | --action index \
76 | --dataset_path $dataset_passages_path \
77 | --image_root_path $IMAGE_ROOT_PATH \
78 | --checkpoint_path $checkpoint_path \
79 | --root_path $root_path \
80 | --experiment_name $experiment_name \
81 | --index_name $index_name \
82 | --index_bsize 128
83 |
84 | srun -c $SLURM_CPUS_PER_TASK --mem $SLURM_MEM_PER_NODE \
85 | python inference.py \
86 | --action search \
87 | --dataset_path $dataset_path \
88 | --dataset_passages_path $dataset_passages_path \
89 | --image_root_path $IMAGE_ROOT_PATH \
90 | --checkpoint_path $checkpoint_path \
91 | --root_path $root_path \
92 | --experiment_name $experiment_name \
93 | --index_name $index_name \
94 | --index_bsize 128 \
95 | --num_docs_to_retrieve 500
96 |
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .ret import *
2 | from .retriever import *
--------------------------------------------------------------------------------
/src/models/contrastive_loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .modeling_contrastive_loss import contrastive_loss, ContrastiveLossOutput
--------------------------------------------------------------------------------
/src/models/contrastive_loss/disco_clip.py:
--------------------------------------------------------------------------------
1 | # from https://github.com/IDEA-Research/DisCo-CLIP/blob/main/disco/gather.py
2 |
3 | import os
4 | import torch
5 | import torch.distributed as dist
6 |
7 | class DisCoGather(torch.autograd.Function):
8 | """An autograd function that performs allgather on a tensor."""
9 |
10 | @staticmethod
11 | def forward(ctx, tensor):
12 | if not torch.distributed.is_initialized():
13 |             raise RuntimeError("torch.distributed is not initialized")
14 |
15 | world_size = torch.distributed.get_world_size()
16 | ctx.bs = tensor.shape[0]
17 | ctx.rank = torch.distributed.get_rank()
18 |
19 | gathered_tensors = [
20 | torch.zeros_like(tensor) for _ in range(world_size)
21 | ]
22 | torch.distributed.all_gather(gathered_tensors, tensor)
23 |
24 | gathered_tensors = torch.cat(gathered_tensors, dim=0)
25 | gathered_tensors.requires_grad_(True)
26 |
27 | return gathered_tensors
28 |
29 | @staticmethod
30 | def backward(ctx, grad_output):
31 | # do not remove this contiguous
32 | torch.distributed.all_reduce(grad_output.contiguous(), op=torch.distributed.ReduceOp.SUM)
33 | return grad_output[ctx.bs*ctx.rank:ctx.bs*(ctx.rank+1)]
34 |
35 |
36 | def Gather(tensor):
37 | return DisCoGather.apply(tensor)
--------------------------------------------------------------------------------
/src/models/contrastive_loss/gather_utils.py:
--------------------------------------------------------------------------------
1 | # from LAVIS codebase: https://github.com/salesforce/LAVIS/blob/main/lavis/models/base_model.py#L202
2 | import torch
3 | from torch.distributed.nn.functional import all_gather as all_gather_with_grad_torch_functional
4 |
5 |
6 | class GatherLayer(torch.autograd.Function):
7 | """
8 | Gather tensors from all workers with support for backward propagation:
9 | This implementation does not cut the gradients as torch.distributed.all_gather does.
10 | """
11 |
12 | @staticmethod
13 | def forward(ctx, x):
14 | output = [
15 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
16 | ]
17 | torch.distributed.all_gather(output, x)
18 | return tuple(output)
19 |
20 | @staticmethod
21 | def backward(ctx, *grads):
22 | all_gradients = torch.stack(grads)
23 | torch.distributed.all_reduce(all_gradients)
24 | return all_gradients[torch.distributed.get_rank()]
25 |
26 |
27 | def all_gather_with_grad(tensors):
28 | """
29 | Performs all_gather operation on the provided tensors.
30 | Graph remains connected for backward grad computation.
31 | """
32 | # Queue the gathered tensors
33 | world_size = torch.distributed.get_world_size()
34 | # There is no need for reduction in the single-proc case
35 | if world_size == 1:
36 | return tensors
37 |
38 | # tensor_all = GatherLayer.apply(tensors)
39 | tensor_all = GatherLayer.apply(tensors)
40 |
41 | return torch.cat(tensor_all, dim=0)
42 |
43 |
44 | def all_gather_with_grad_torch(tensors):
45 | """
46 | Official version of all_gather with grads in torch.distributed.nn.functional
47 | """
48 | tensor_all = all_gather_with_grad_torch_functional(tensors)
49 | return torch.cat(tensor_all, dim=0)
50 |
51 |
52 | def all_gather(x):
53 | output = [
54 | torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
55 | ]
56 | torch.distributed.all_gather(output, x)
57 | return torch.cat(output, dim=0)
--------------------------------------------------------------------------------
/src/models/ret/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_ret import RetConfig, RetLayerStrategy
2 | from .modeling_ret import RetModel, RetModelOutput
3 | from transformers import AutoConfig, AutoModel
4 |
5 | __all__ = [
6 | 'RetLayerStrategy',
7 | 'RetConfig',
8 | 'RetModel',
9 | 'RetModelOutput'
10 | ]
11 |
12 | AutoConfig.register(RetConfig.model_type, RetConfig)
13 | AutoModel.register(RetConfig, RetModel)
--------------------------------------------------------------------------------
/src/models/ret/configuration_ret.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Union
2 | from transformers import PretrainedConfig, AutoConfig
3 | from enum import Enum
4 |
5 | class RetLayerStrategy(str, Enum):
6 | CLIP_VIT_B = 'clip_vit_b'
7 | CLIP_VIT_L = 'clip_vit_l'
8 | OPENCLIP_VIT_H = 'openclip_vit_h'
9 | OPENCLIP_VIT_G = 'openclip_vit_g'
10 |
11 |
12 | _RET_LAYER_STRATEGY_MAPPING = {
13 | RetLayerStrategy.CLIP_VIT_B: {
14 | 'text': tuple(range(12)),
15 | 'vision': tuple(range(12))
16 | },
17 | RetLayerStrategy.CLIP_VIT_L: {
18 | 'text': list(range(12)),
19 | 'vision': list(range(24))[::2]
20 | },
21 | RetLayerStrategy.OPENCLIP_VIT_H: {
22 | 'text': (1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23),
23 | 'vision': (0, 2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 31)
24 | },
25 | RetLayerStrategy.OPENCLIP_VIT_G: {
26 | 'text': (0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 31),
27 | 'vision': (2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35, 38, 41, 44, 47)
28 | }
29 | }
30 |
31 |
32 | class RetConfig(PretrainedConfig):
33 | model_type = 'ret'
34 | is_composition = False
35 |
36 | def __init__(
37 | self,
38 | text_config: Optional[Union[PretrainedConfig, Dict]] = None,
39 | vision_config: Optional[Union[PretrainedConfig, Dict]] = None,
40 | is_text_frozen: bool = True,
41 | is_vision_frozen: bool = True,
42 | late_proj_output_size: int = 128,
43 | layer_strategy: str = RetLayerStrategy.CLIP_VIT_B,
44 | use_pooler_features: bool = False,
45 | num_queries: int = 32,
46 | hidden_size: int = 1024,
47 | dropout_p: float = 0.05,
48 | attention_dropout: float = 0.05,
49 | activation_fn: str = 'gelu',
50 | input_gate_bias_prior: float = 0.0,
51 | forget_gate_bias_prior: float = 0.0,
52 | use_tanh: bool = False,
53 | **kwargs
54 | ):
55 | super().__init__(**kwargs)
56 |
57 | if isinstance(text_config, dict):
58 | text_config = AutoConfig.for_model(text_config.pop('model_type'), **text_config)
59 | self.text_config = text_config
60 |
61 | if isinstance(vision_config, dict):
62 | vision_config = AutoConfig.for_model(vision_config.pop('model_type'), **vision_config)
63 | self.vision_config = vision_config
64 |
65 | self.is_text_frozen = is_text_frozen
66 | self.is_vision_frozen = is_vision_frozen
67 | self.late_proj_output_size = late_proj_output_size
68 | self.layer_strategy = layer_strategy
69 | self.use_pooler_features = use_pooler_features
70 |
71 | # recurrent cell
72 | self.num_queries = num_queries
73 | self.hidden_size = hidden_size
74 | self.intermediate_size = hidden_size * 4
75 | self.num_attention_heads = hidden_size // 64
76 | self.dropout_p = dropout_p
77 | self.attention_dropout = attention_dropout
78 | self.activation_fn = activation_fn
79 | self.input_gate_bias_prior = input_gate_bias_prior
80 | self.forget_gate_bias_prior = forget_gate_bias_prior
81 | self.use_tanh = use_tanh
82 |
83 | @property
84 | def n_rec_layers(self):
85 | return len(_RET_LAYER_STRATEGY_MAPPING[self.layer_strategy]['text'])
86 |
87 | @property
88 | def text_layer_idxs(self):
89 | return _RET_LAYER_STRATEGY_MAPPING[self.layer_strategy]['text']
90 |
91 | @property
92 | def vision_layer_idxs(self):
93 | return _RET_LAYER_STRATEGY_MAPPING[self.layer_strategy]['vision']
94 |
95 | @property
96 | def text_hidden_size(self):
97 | if self.text_config is None:
98 | return 0
99 | else:
100 | return self.text_config.hidden_size
101 |
102 | @property
103 | def vision_hidden_size(self):
104 | if self.vision_config is None:
105 | return 0
106 | else:
107 | return self.vision_config.hidden_size
--------------------------------------------------------------------------------
/src/models/retriever/__init__.py:
--------------------------------------------------------------------------------
1 | from .configuration_retriever import RetrieverConfig
2 | from .modeling_retriever import RetrieverModel
3 | from transformers import AutoConfig, AutoModel
4 |
5 | __all__ = [
6 | 'RetrieverConfig',
7 | 'RetrieverModel'
8 | ]
9 |
10 | AutoConfig.register(RetrieverConfig.model_type, RetrieverConfig)
11 | AutoModel.register(RetrieverConfig, RetrieverModel)
--------------------------------------------------------------------------------
/src/models/retriever/configuration_retriever.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Union
2 | from transformers import PretrainedConfig
3 | from src.models.ret import RetConfig
4 |
5 |
6 | class RetrieverConfig(PretrainedConfig):
7 | model_type = 'retriever'
8 | is_composition = False
9 |
10 | def __init__(
11 | self,
12 | query_config: Optional[Union[RetConfig, Dict]] = None,
13 | passage_config: Optional[Union[RetConfig, Dict]] = None,
14 | fg_loss: bool = True,
15 | simmetric_loss: bool = True,
16 | share_query_passage_models: bool = False,
17 | share_text_models: bool = True,
18 | share_vision_models: bool = True,
19 | **kwargs
20 | ):
21 | super().__init__(**kwargs)
22 |
23 | if isinstance(query_config, RetConfig):
24 | query_config = query_config
25 | elif isinstance(query_config, dict):
26 | query_config = RetConfig.from_dict(query_config)
27 | self.query_config = query_config
28 |
29 | if isinstance(passage_config, RetConfig):
30 | passage_config = passage_config
31 | elif isinstance(passage_config, dict):
32 | passage_config = RetConfig.from_dict(passage_config)
33 |
34 | if share_query_passage_models:
35 | self.passage_config = self.query_config
36 | else:
37 | self.passage_config = passage_config
38 |
39 | self.fg_loss = fg_loss
40 | self.simmetric_loss = simmetric_loss
41 | self.share_query_passage_models = share_query_passage_models
42 | self.share_text_models = share_text_models
43 | self.share_vision_models = share_vision_models
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import logging
3 | import os
4 |
5 | logging.set_verbosity_info()
6 | logging.enable_explicit_format()
7 |
8 |
9 | def get_logger():
10 | return logging.get_logger("transformers")
11 |
12 |
13 | def get_additive_attn_mask(binary_attn_mask, dtype):
14 | ret = torch.where(binary_attn_mask.bool(), 0, torch.finfo(dtype).min).to(dtype)
15 | return ret
16 |
17 |
18 | def is_debug():
19 | return int(os.getenv("DEBUG", 0)) == 1
--------------------------------------------------------------------------------
/third_party/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019, 2020 Stanford Future Data Systems
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/third_party/LoTTE.md:
--------------------------------------------------------------------------------
1 | ## LoTTE dataset
2 |
3 | The Long-Tail Topic-stratified Evaluation (LoTTE) benchmark includes 12 domain-specific datasets derived from StackExchange questions and answers. The datasets span topics including writing, recreation, science, technology, and lifestyle. LoTTE includes two sets of queries: the first consists of search-based queries from the GooAQ dataset, while the second consists of forum-based queries taken directly from StackExchange.
4 |
5 | The dataset can be downloaded from this link: [https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/lotte.tar.gz](https://downloads.cs.stanford.edu/nlp/data/colbert/colbertv2/lotte.tar.gz)
6 |
7 | The dataset is organized as follows:
8 | ```
9 | |-- lotte
10 | |-- writing
11 | |-- dev
12 | |-- collection.tsv
13 | |-- metadata.jsonl
14 | |-- questions.search.tsv
15 | |-- qas.search.jsonl
16 | |-- questions.forum.tsv
17 | |-- qas.forum.jsonl
18 | |-- test
19 | |-- collection.tsv
20 | |-- ...
21 | |-- recreation
22 | |-- ...
23 | |-- ...
24 | ```
25 | Here is a description of each file's contents:
26 | - `collection.tsv`: A list of passages where each line is of the form `[pid]\t[text]`
27 | - `metadata.jsonl`: A list of JSON dictionaries for each question where each line is of the form:
28 | ```
29 | {
30 | "dataset": dataset,
31 | "question_id": question_id,
32 | "post_ids": [post_id_1, post_id_2, ..., post_id_n],
33 | "scores": [score_1, score_2, ..., score_n],
34 | "post_urls": [url_1, url_2, ..., url_n],
35 | "post_authors": [author_1, author_2, ..., author_n],
36 | "post_author_urls": [url_1, url_2, ..., url_n],
37 | "question_author": question_author,
38 | "question_author_url", question_author_url
39 | }
40 | ```
41 | - `questions.search.tsv`: A list of search-based questions of the form `[qid]\t[text]`
42 | - `qas.search.jsonl`: A list of JSON dictionaries for each search-based question's answer data of the form:
43 |
44 | ```
45 | {
46 | "qid": qid,
47 | "query": query,
48 | "answer_pids": answer_pids
49 | }
50 | ```
51 | - `questions.forum.tsv`: A list of forum-based questions
52 | - `qas.forum.jsonl`: A list of JSON dictionaries for each forum-based question's answer data
53 |
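To make these formats concrete, here is a minimal loading sketch (not part of this repository; the `lotte/writing/dev` paths are only illustrative):
```
import json

def load_collection(path):
    # collection.tsv: one passage per line, "[pid]\t[text]"
    collection = {}
    with open(path) as f:
        for line in f:
            pid, text = line.rstrip("\n").split("\t", 1)
            collection[int(pid)] = text
    return collection

def load_qas(path):
    # qas.search.jsonl / qas.forum.jsonl: one JSON object per line
    with open(path) as f:
        return [json.loads(line) for line in f]

collection = load_collection("lotte/writing/dev/collection.tsv")
qas = load_qas("lotte/writing/dev/qas.search.jsonl")
print(f"{len(collection)} passages, {len(qas)} queries")
print(qas[0]["qid"], qas[0]["query"], qas[0]["answer_pids"][:3])
```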
54 | We also include a script to evaluate LoTTE rankings: `evaluate_lotte_rankings.py`. Each rankings file must be in a tsv format with each line of the form `[qid]\t[pid]\t[rank]\t[score]`. Note that `qid`s must be in sequential order starting from 0, and `rank`s must be in sequential order starting from 1. The rankings directory must have the following structure:
55 | ```
56 | |--rankings
57 | |-- dev
58 | |-- writing.search.ranking.tsv
59 | |-- writing.forum.ranking.tsv
60 | |-- recreation.search.ranking.tsv
61 | |-- recreation.forum.ranking.tsv
62 | |-- science.search.ranking.tsv
63 | |-- science.forum.ranking.tsv
64 | |-- technology.search.ranking.tsv
65 | |-- technology.forum.ranking.tsv
66 | |-- lifestyle.search.ranking.tsv
67 | |-- lifestyle.forum.ranking.tsv
68 | |-- pooled.search.ranking.tsv
69 | |-- pooled.forum.ranking.tsv
70 | |-- test
71 | |-- writing.search.ranking.tsv
72 | |-- ...
73 | ```
74 | Note that the file names must match exactly, though if some files are missing the script will print partial results. An example usage of the script is as follows:
75 | ```
76 | python evaluate_lotte_rankings.py --k 5 --split test --data_path /path/to/lotte --rankings_path /path/to/rankings
77 | ```
78 | This will produce the following output (numbers taken from the ColBERTv2 evaluation):
79 | ```
80 | [query_type=search, dataset=writing] Success@5: 80.1
81 | [query_type=search, dataset=recreation] Success@5: 72.3
82 | [query_type=search, dataset=science] Success@5: 56.7
83 | [query_type=search, dataset=technology] Success@5: 66.1
84 | [query_type=search, dataset=lifestyle] Success@5: 84.7
85 | [query_type=search, dataset=pooled] Success@5: 71.6
86 |
87 | [query_type=forum, dataset=writing] Success@5: 76.3
88 | [query_type=forum, dataset=recreation] Success@5: 70.8
89 | [query_type=forum, dataset=science] Success@5: 46.1
90 | [query_type=forum, dataset=technology] Success@5: 53.6
91 | [query_type=forum, dataset=lifestyle] Success@5: 76.9
92 | [query_type=forum, dataset=pooled] Success@5: 63.4
93 | ```
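
As a reference for producing input to the evaluation script, here is a minimal sketch that writes a rankings file in the expected `[qid]\t[pid]\t[rank]\t[score]` format (the `rankings` dict and output path are hypothetical, not part of this repository):
```
# rankings: {qid: [(pid, score), ...]} with each list sorted by descending score
def write_ranking_tsv(rankings, out_path):
    with open(out_path, "w") as f:
        for qid in sorted(rankings):  # qids must be sequential, starting from 0
            for rank, (pid, score) in enumerate(rankings[qid], start=1):  # ranks start from 1
                f.write(f"{qid}\t{pid}\t{rank}\t{score}\n")

write_ranking_tsv({0: [(17, 24.3), (42, 21.9)], 1: [(5, 30.1)]},
                  "writing.search.ranking.tsv")
```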
94 |
95 |
--------------------------------------------------------------------------------
/third_party/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include colbert/indexing/codecs/*.cpp
2 | include colbert/indexing/codecs/*.cu
--------------------------------------------------------------------------------
/third_party/colbert/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import Trainer
2 | from .indexer import Indexer
3 | from .searcher import Searcher
4 | from .index_updater import IndexUpdater
5 |
6 | from .modeling.checkpoint import Checkpoint
7 |
--------------------------------------------------------------------------------
/third_party/colbert/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .collection import *
2 | from .queries import *
3 |
4 | from .ranking import *
5 | from .examples import *
6 |
--------------------------------------------------------------------------------
/third_party/colbert/data/collection.py:
--------------------------------------------------------------------------------
1 |
2 | # Could be .tsv or .json. The latter always allows more customization via optional parameters.
3 | # I think it could be worth doing some kind of parallel reads too, if the file exceeds 1 GiBs.
4 | # Just need to use a datastructure that shares things across processes without too much pickling.
5 | # I think multiprocessing.Manager can do that!
6 |
7 | import os
8 | import itertools
9 |
10 | from third_party.colbert.evaluation.loaders import load_collection
11 | from third_party.colbert.infra.run import Run
12 |
13 |
14 | class Collection:
15 | def __init__(self, path=None, data=None):
16 | self.path = path
17 | self.data = data or self._load_file(path)
18 |
19 | def __iter__(self):
20 | # TODO: If __data isn't there, stream from disk!
21 | return self.data.__iter__()
22 |
23 | def __getitem__(self, item):
24 | # TODO: Load from disk the first time this is called. Unless self.data is already not None.
25 | return self.data[item]
26 |
27 | def __len__(self):
28 | # TODO: Load here too. Basically, let's make data a property function and, on first call, either load or get __data.
29 | return len(self.data)
30 |
31 | def _load_file(self, path):
32 | self.path = path
33 | return self._load_tsv(path) if path.endswith('.tsv') else self._load_jsonl(path)
34 |
35 | def _load_tsv(self, path):
36 | return load_collection(path)
37 |
38 | def _load_jsonl(self, path):
39 | raise NotImplementedError()
40 |
41 | def provenance(self):
42 | return self.path
43 |
44 | def toDict(self):
45 | return {'provenance': self.provenance()}
46 |
47 | def save(self, new_path):
48 | assert new_path.endswith('.tsv'), "TODO: Support .json[l] too."
49 | assert not os.path.exists(new_path), new_path
50 |
51 | with Run().open(new_path, 'w') as f:
52 | # TODO: expects content to always be a string here; no separate title!
53 | for pid, content in enumerate(self.data):
54 | content = f'{pid}\t{content}\n'
55 | f.write(content)
56 |
57 | return f.name
58 |
59 | def enumerate(self, rank):
60 | for _, offset, passages in self.enumerate_batches(rank=rank):
61 | for idx, passage in enumerate(passages):
62 | yield (offset + idx, passage)
63 |
64 | def enumerate_batches(self, rank, chunksize=None):
65 | assert rank is not None, "TODO: Add support for the rank=None case."
66 |
67 | chunksize = chunksize or self.get_chunksize()
68 |
69 | offset = 0
70 | iterator = iter(self)
71 |
72 | for chunk_idx, owner in enumerate(itertools.cycle(range(Run().nranks))):
73 | L = [line for _, line in zip(range(chunksize), iterator)]
74 |
75 | if len(L) > 0 and owner == rank:
76 | yield (chunk_idx, offset, L)
77 |
78 | offset += len(L)
79 |
80 | if len(L) < chunksize:
81 | return
82 |
83 | def get_chunksize(self):
84 | return min(25_000, 1 + len(self) // Run().nranks) # 25k is great, 10k allows things to reside on GPU??
85 |
86 | @classmethod
87 | def cast(cls, obj):
88 | if type(obj) is str:
89 | return cls(path=obj)
90 |
91 | if type(obj) is list:
92 | return cls(data=obj)
93 |
94 | if type(obj) is cls:
95 | return obj
96 |
97 | assert False, f"obj has type {type(obj)} which is not compatible with cast()"
98 |
99 |
100 | # TODO: Look up path in some global [per-thread or thread-safe] list.
101 |
--------------------------------------------------------------------------------
/third_party/colbert/data/dataset.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Not just the corpus, but also an arbitrary number of query sets, indexed by name in a dictionary/dotdict.
4 | # And also query sets with top-k PIDs.
5 | # QAs too? TripleSets too?
6 |
7 |
8 | class Dataset:
9 | def __init__(self):
10 | pass
11 |
12 | def select(self, key):
13 | # Select the {corpus, queryset, tripleset, rankingset} determined by uniqueness or by key and return a "unique" dataset (e.g., for key=train)
14 | pass
15 |
--------------------------------------------------------------------------------
/third_party/colbert/data/examples.py:
--------------------------------------------------------------------------------
1 | from third_party.colbert.infra.run import Run
2 | import os
3 | import ujson
4 |
5 | from third_party.colbert.utils.utils import print_message
6 | from third_party.colbert.infra.provenance import Provenance
7 | from third_party.utility.utils.save_metadata import get_metadata_only
8 |
9 |
10 | class Examples:
11 | def __init__(self, path=None, data=None, nway=None, provenance=None):
12 | self.__provenance = provenance or path or Provenance()
13 | self.nway = nway
14 | self.path = path
15 | self.data = data or self._load_file(path)
16 |
17 | def provenance(self):
18 | return self.__provenance
19 |
20 | def toDict(self):
21 | return self.provenance()
22 |
23 | def _load_file(self, path):
24 | nway = self.nway + 1 if self.nway else self.nway
25 | examples = []
26 |
27 | with open(path) as f:
28 | for line in f:
29 | example = ujson.loads(line)[:nway]
30 | examples.append(example)
31 |
32 | return examples
33 |
34 | def tolist(self, rank=None, nranks=None):
35 | """
36 | NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling.
37 | In particular, each subset is perfectly represented in every batch! However, since we never
38 | repeat passes over the data, we never repeat any particular triple, and the split across
39 | nodes is random (since the underlying file is pre-shuffled), there's no concern here.
40 | """
41 |
42 | if rank or nranks:
43 | assert rank in range(nranks), (rank, nranks)
44 | return [self.data[idx] for idx in range(0, len(self.data), nranks)] # if line_idx % nranks == rank
45 |
46 | return list(self.data)
47 |
48 | def save(self, new_path):
49 | assert 'json' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too."
50 |
51 | print_message(f"#> Writing {len(self.data) / 1000_000.0}M examples to {new_path}")
52 |
53 | with Run().open(new_path, 'w') as f:
54 | for example in self.data:
55 | ujson.dump(example, f)
56 | f.write('\n')
57 |
58 | output_path = f.name
59 | print_message(f"#> Saved examples with {len(self.data)} lines to {f.name}")
60 |
61 | with Run().open(f'{new_path}.meta', 'w') as f:
62 | d = {}
63 | d['metadata'] = get_metadata_only()
64 | d['provenance'] = self.provenance()
65 | line = ujson.dumps(d, indent=4)
66 | f.write(line)
67 |
68 | return output_path
69 |
70 | @classmethod
71 | def cast(cls, obj, nway=None):
72 | if type(obj) is str:
73 | return cls(path=obj, nway=nway)
74 |
75 | if isinstance(obj, list):
76 | return cls(data=obj, nway=nway)
77 |
78 | if type(obj) is cls:
79 | assert nway is None, nway
80 | return obj
81 |
82 | assert False, f"obj has type {type(obj)} which is not compatible with cast()"
83 |
--------------------------------------------------------------------------------
/third_party/colbert/data/queries.py:
--------------------------------------------------------------------------------
1 | from third_party.colbert.infra.run import Run
2 | import os
3 | import ujson
4 |
5 | from third_party.colbert.evaluation.loaders import load_queries
6 |
7 | # TODO: Look up path in some global [per-thread or thread-safe] list.
8 | # TODO: path could be a list of paths...? But then how can we tell it's not a list of queries..
9 |
10 |
11 | class Queries:
12 | def __init__(self, path=None, data=None):
13 | self.path = path
14 |
15 | if data:
16 | assert isinstance(data, dict), type(data)
17 | self._load_data(data) or self._load_file(path)
18 |
19 | def __len__(self):
20 | return len(self.data)
21 |
22 | def __iter__(self):
23 | return iter(self.data.items())
24 |
25 | def provenance(self):
26 | return self.path
27 |
28 | def toDict(self):
29 | return {'provenance': self.provenance()}
30 |
31 | def _load_data(self, data):
32 | if data is None:
33 | return None
34 |
35 | self.data = {}
36 | self._qas = {}
37 |
38 | for qid, content in data.items():
39 | if isinstance(content, dict):
40 | self.data[qid] = content['question']
41 | self._qas[qid] = content
42 | else:
43 | self.data[qid] = content
44 |
45 | if len(self._qas) == 0:
46 | del self._qas
47 |
48 | return True
49 |
50 | def _load_file(self, path):
51 | if not path.endswith('.json'):
52 | self.data = load_queries(path)
53 | return True
54 |
55 | # Load QAs
56 | self.data = {}
57 | self._qas = {}
58 |
59 | with open(path) as f:
60 | for line in f:
61 | qa = ujson.loads(line)
62 |
63 | assert qa['qid'] not in self.data
64 | self.data[qa['qid']] = qa['question']
65 | self._qas[qa['qid']] = qa
66 |
67 | return self.data
68 |
69 | def qas(self):
70 | return dict(self._qas)
71 |
72 | def __getitem__(self, key):
73 | return self.data[key]
74 |
75 | def keys(self):
76 | return self.data.keys()
77 |
78 | def values(self):
79 | return self.data.values()
80 |
81 | def items(self):
82 | return self.data.items()
83 |
84 | def save(self, new_path):
85 | assert new_path.endswith('.tsv')
86 | assert not os.path.exists(new_path), new_path
87 |
88 | with Run().open(new_path, 'w') as f:
89 | for qid, content in self.data.items():
90 | content = f'{qid}\t{content}\n'
91 | f.write(content)
92 |
93 | return f.name
94 |
95 | def save_qas(self, new_path):
96 | assert new_path.endswith('.json')
97 | assert not os.path.exists(new_path), new_path
98 |
99 | with open(new_path, 'w') as f:
100 | for qid, qa in self._qas.items():
101 | qa['qid'] = qid
102 | f.write(ujson.dumps(qa) + '\n')
103 |
104 | def _load_tsv(self, path):
105 | raise NotImplementedError
106 |
107 | def _load_jsonl(self, path):
108 | raise NotImplementedError
109 |
110 | @classmethod
111 | def cast(cls, obj):
112 | if type(obj) is str:
113 | return cls(path=obj)
114 |
115 | if isinstance(obj, dict) or isinstance(obj, list):
116 | return cls(data=obj)
117 |
118 | if type(obj) is cls:
119 | return obj
120 |
121 | assert False, f"obj has type {type(obj)} which is not compatible with cast()"
122 |
123 |
124 | # class QuerySet:
125 | # def __init__(self, *paths, renumber=False):
126 | # self.paths = paths
127 | # self.original_queries = [load_queries(path) for path in paths]
128 |
129 | # if renumber:
130 | # self.queries = flatten([q.values() for q in self.original_queries])
131 | # self.queries = {idx: text for idx, text in enumerate(self.queries)}
132 |
133 | # else:
134 | # self.queries = {}
135 |
136 | # for queries in self.original_queries:
137 | # assert len(set.intersection(set(queries.keys()), set(self.queries.keys()))) == 0, \
138 | # "renumber=False requires non-overlapping query IDs"
139 |
140 | # self.queries.update(queries)
141 |
142 | # assert len(self.queries) == sum(map(len, self.original_queries))
143 |
144 | # def todict(self):
145 | # return dict(self.queries)
146 |
147 | # def tolist(self):
148 | # return list(self.queries.values())
149 |
150 | # def query_sets(self):
151 | # return self.original_queries
152 |
153 | # def split_rankings(self, rankings):
154 | # assert type(rankings) is list
155 | # assert len(rankings) == len(self.queries)
156 |
157 | # sub_rankings = []
158 | # offset = 0
159 | # for source in self.original_queries:
160 | # sub_rankings.append(rankings[offset:offset+len(source)])
161 | # offset += len(source)
162 |
163 | # return sub_rankings
164 |
--------------------------------------------------------------------------------
/third_party/colbert/data/ranking.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tqdm
3 | import ujson
4 | from third_party.colbert.infra.provenance import Provenance
5 |
6 | from third_party.colbert.infra.run import Run
7 | from third_party.colbert.utils.utils import print_message, groupby_first_item
8 | from third_party.utility.utils.save_metadata import get_metadata_only
9 |
10 |
11 | def numericize(v):
12 | if '.' in v:
13 | return float(v)
14 |
15 | return int(v)
16 |
17 |
18 | def load_ranking(path): # works with annotated and un-annotated ranked lists
19 | print_message("#> Loading the ranked lists from", path)
20 |
21 | with open(path) as f:
22 | return [list(map(numericize, line.strip().split('\t'))) for line in f]
23 |
24 |
25 | class Ranking:
26 | def __init__(self, path=None, data=None, metrics=None, provenance=None):
27 | self.__provenance = provenance or path or Provenance()
28 | self.data = self._prepare_data(data or self._load_file(path))
29 |
30 | def provenance(self):
31 | return self.__provenance
32 |
33 | def toDict(self):
34 | return {'provenance': self.provenance()}
35 |
36 | def _prepare_data(self, data):
37 | # TODO: Handle list of lists???
38 | if isinstance(data, dict):
39 | self.flat_ranking = [(qid, *rest) for qid, subranking in data.items() for rest in subranking]
40 | return data
41 |
42 | self.flat_ranking = data
43 | return groupby_first_item(tqdm.tqdm(self.flat_ranking))
44 |
45 | def _load_file(self, path):
46 | return load_ranking(path)
47 |
48 | def todict(self):
49 | return dict(self.data)
50 |
51 | def tolist(self):
52 | return list(self.flat_ranking)
53 |
54 | def items(self):
55 | return self.data.items()
56 |
57 | def _load_tsv(self, path):
58 | raise NotImplementedError
59 |
60 | def _load_jsonl(self, path):
61 | raise NotImplementedError
62 |
63 | def save(self, new_path):
64 | assert 'tsv' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too."
65 |
66 | with Run().open(new_path, 'w') as f:
67 | for items in self.flat_ranking:
68 | line = '\t'.join(map(lambda x: str(int(x) if type(x) is bool else x), items)) + '\n'
69 | f.write(line)
70 |
71 | output_path = f.name
72 | print_message(f"#> Saved ranking of {len(self.data)} queries and {len(self.flat_ranking)} lines to {f.name}")
73 |
74 | with Run().open(f'{new_path}.meta', 'w') as f:
75 | d = {}
76 | d['metadata'] = get_metadata_only()
77 | d['provenance'] = self.provenance()
78 | line = ujson.dumps(d, indent=4)
79 | f.write(line)
80 |
81 | return output_path
82 |
83 | @classmethod
84 | def cast(cls, obj):
85 | if type(obj) is str:
86 | return cls(path=obj)
87 |
88 | if isinstance(obj, dict) or isinstance(obj, list):
89 | return cls(data=obj)
90 |
91 | if type(obj) is cls:
92 | return obj
93 |
94 | assert False, f"obj has type {type(obj)} which is not compatible with cast()"
95 |
--------------------------------------------------------------------------------
/third_party/colbert/distillation/ranking_scorer.py:
--------------------------------------------------------------------------------
1 | import tqdm
2 | import ujson
3 |
4 | from collections import defaultdict
5 |
6 | from third_party.colbert.utils.utils import print_message, zipstar
7 | from third_party.utility.utils.save_metadata import get_metadata_only
8 |
9 | from third_party.colbert.infra import Run
10 | from third_party.colbert.data import Ranking
11 | from third_party.colbert.infra.provenance import Provenance
12 | from third_party.colbert.distillation.scorer import Scorer
13 |
14 |
15 | class RankingScorer:
16 | def __init__(self, scorer: Scorer, ranking: Ranking):
17 | self.scorer = scorer
18 | self.ranking = ranking.tolist()
19 | self.__provenance = Provenance()
20 |
21 | print_message(f"#> Loaded ranking with {len(self.ranking)} qid--pid pairs!")
22 |
23 | def provenance(self):
24 | return self.__provenance
25 |
26 | def run(self):
27 | print_message(f"#> Starting..")
28 |
29 | qids, pids, *_ = zipstar(self.ranking)
30 | distillation_scores = self.scorer.launch(qids, pids)
31 |
32 | scores_by_qid = defaultdict(list)
33 |
34 | for qid, pid, score in tqdm.tqdm(zip(qids, pids, distillation_scores)):
35 | scores_by_qid[qid].append((score, pid))
36 |
37 | with Run().open('distillation_scores.json', 'w') as f:
38 | for qid in tqdm.tqdm(scores_by_qid):
39 | obj = (qid, scores_by_qid[qid])
40 | f.write(ujson.dumps(obj) + '\n')
41 |
42 | output_path = f.name
43 | print_message(f'#> Saved the distillation_scores to {output_path}')
44 |
45 | with Run().open(f'{output_path}.meta', 'w') as f:
46 | d = {}
47 | d['metadata'] = get_metadata_only()
48 | d['provenance'] = self.provenance()
49 | line = ujson.dumps(d, indent=4)
50 | f.write(line)
51 |
52 | return output_path
53 |
--------------------------------------------------------------------------------
/third_party/colbert/distillation/scorer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import tqdm
3 |
4 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
5 |
6 | from third_party.colbert.infra.launcher import Launcher
7 | from third_party.colbert.infra import Run, RunConfig
8 | from third_party.colbert.modeling.reranker.electra import ElectraReranker
9 | from third_party.colbert.utils.utils import flatten
10 |
11 |
12 | DEFAULT_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
13 |
14 |
15 | class Scorer:
16 | def __init__(self, queries, collection, model=DEFAULT_MODEL, maxlen=180, bsize=256):
17 | self.queries = queries
18 | self.collection = collection
19 | self.model = model
20 |
21 | self.maxlen = maxlen
22 | self.bsize = bsize
23 |
24 | def launch(self, qids, pids):
25 | launcher = Launcher(self._score_pairs_process, return_all=True)
26 | outputs = launcher.launch(Run().config, qids, pids)
27 |
28 | return flatten(outputs)
29 |
30 | def _score_pairs_process(self, config, qids, pids):
31 | assert len(qids) == len(pids), (len(qids), len(pids))
32 | share = 1 + len(qids) // config.nranks
33 | offset = config.rank * share
34 | endpos = (1 + config.rank) * share
35 |
36 | return self._score_pairs(qids[offset:endpos], pids[offset:endpos], show_progress=(config.rank < 1))
37 |
38 | def _score_pairs(self, qids, pids, show_progress=False):
39 | tokenizer = AutoTokenizer.from_pretrained(self.model)
40 | model = AutoModelForSequenceClassification.from_pretrained(self.model).cuda()
41 |
42 | assert len(qids) == len(pids), (len(qids), len(pids))
43 |
44 | scores = []
45 |
46 | model.eval()
47 | with torch.inference_mode():
48 | with torch.cuda.amp.autocast():
49 | for offset in tqdm.tqdm(range(0, len(qids), self.bsize), disable=(not show_progress)):
50 | endpos = offset + self.bsize
51 |
52 | queries_ = [self.queries[qid] for qid in qids[offset:endpos]]
53 | passages_ = [self.collection[pid] for pid in pids[offset:endpos]]
54 |
55 | features = tokenizer(queries_, passages_, padding='longest', truncation=True,
56 | return_tensors='pt', max_length=self.maxlen).to(model.device)
57 |
58 | scores.append(model(**features).logits.flatten())
59 |
60 | scores = torch.cat(scores)
61 | scores = scores.tolist()
62 |
63 | Run().print(f'Returning with {len(scores)} scores')
64 |
65 | return scores
66 |
67 |
68 | # LONG-TERM TODO: This can be sped up by sorting by length in advance.
69 |
--------------------------------------------------------------------------------
/third_party/colbert/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/evaluation/__init__.py
--------------------------------------------------------------------------------
/third_party/colbert/evaluation/load_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ujson
3 | import torch
4 | import random
5 |
6 | from collections import defaultdict, OrderedDict
7 |
8 | from third_party.colbert.parameters import DEVICE
9 | from third_party.colbert.modeling.colbert import ColBERT
10 | from third_party.colbert.utils.utils import print_message, load_checkpoint
11 |
12 |
13 | def load_model(args, do_print=True):
14 | colbert = ColBERT.from_pretrained('bert-base-uncased',
15 | query_maxlen=args.query_maxlen,
16 | doc_maxlen=args.doc_maxlen,
17 | dim=args.dim,
18 | similarity_metric=args.similarity,
19 | mask_punctuation=args.mask_punctuation)
20 | colbert = colbert.to(DEVICE)
21 |
22 | print_message("#> Loading model checkpoint.", condition=do_print)
23 |
24 | checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)
25 |
26 | colbert.eval()
27 |
28 | return colbert, checkpoint
29 |
--------------------------------------------------------------------------------
/third_party/colbert/evaluation/metrics.py:
--------------------------------------------------------------------------------
1 | import ujson
2 |
3 | from collections import defaultdict
4 | from third_party.colbert.utils.runs import Run
5 |
6 |
7 | class Metrics:
8 | def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None):
9 | self.results = {}
10 | self.mrr_sums = {depth: 0.0 for depth in mrr_depths}
11 | self.recall_sums = {depth: 0.0 for depth in recall_depths}
12 | self.success_sums = {depth: 0.0 for depth in success_depths}
13 | self.total_queries = total_queries
14 |
15 | self.max_query_idx = -1
16 | self.num_queries_added = 0
17 |
18 | def add(self, query_idx, query_key, ranking, gold_positives):
19 | self.num_queries_added += 1
20 |
21 | assert query_key not in self.results
22 | assert len(self.results) <= query_idx
23 | assert len(set(gold_positives)) == len(gold_positives)
24 | assert len(set([pid for _, pid, _ in ranking])) == len(ranking)
25 |
26 | self.results[query_key] = ranking
27 |
28 | positives = [i for i, (_, pid, _) in enumerate(ranking) if pid in gold_positives]
29 |
30 | if len(positives) == 0:
31 | return
32 |
33 | for depth in self.mrr_sums:
34 | first_positive = positives[0]
35 | self.mrr_sums[depth] += (1.0 / (first_positive+1.0)) if first_positive < depth else 0.0
36 |
37 | for depth in self.success_sums:
38 | first_positive = positives[0]
39 | self.success_sums[depth] += 1.0 if first_positive < depth else 0.0
40 |
41 | for depth in self.recall_sums:
42 | num_positives_up_to_depth = len([pos for pos in positives if pos < depth])
43 | self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives)
44 |
45 | def print_metrics(self, query_idx):
46 | for depth in sorted(self.mrr_sums):
47 | print("MRR@" + str(depth), "=", self.mrr_sums[depth] / (query_idx+1.0))
48 |
49 | for depth in sorted(self.success_sums):
50 | print("Success@" + str(depth), "=", self.success_sums[depth] / (query_idx+1.0))
51 |
52 | for depth in sorted(self.recall_sums):
53 | print("Recall@" + str(depth), "=", self.recall_sums[depth] / (query_idx+1.0))
54 |
55 | def log(self, query_idx):
56 | assert query_idx >= self.max_query_idx
57 | self.max_query_idx = query_idx
58 |
59 | Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
60 | Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)
61 |
62 | for depth in sorted(self.mrr_sums):
63 | score = self.mrr_sums[depth] / (query_idx+1.0)
64 | Run.log_metric("ranking/MRR." + str(depth), score, query_idx)
65 |
66 | for depth in sorted(self.success_sums):
67 | score = self.success_sums[depth] / (query_idx+1.0)
68 | Run.log_metric("ranking/Success." + str(depth), score, query_idx)
69 |
70 | for depth in sorted(self.recall_sums):
71 | score = self.recall_sums[depth] / (query_idx+1.0)
72 | Run.log_metric("ranking/Recall." + str(depth), score, query_idx)
73 |
74 | def output_final_metrics(self, path, query_idx, num_queries):
75 | assert query_idx + 1 == num_queries
76 | assert num_queries == self.total_queries
77 |
78 | if self.max_query_idx < query_idx:
79 | self.log(query_idx)
80 |
81 | self.print_metrics(query_idx)
82 |
83 | output = defaultdict(dict)
84 |
85 | for depth in sorted(self.mrr_sums):
86 | score = self.mrr_sums[depth] / (query_idx+1.0)
87 | output['mrr'][depth] = score
88 |
89 | for depth in sorted(self.success_sums):
90 | score = self.success_sums[depth] / (query_idx+1.0)
91 | output['success'][depth] = score
92 |
93 | for depth in sorted(self.recall_sums):
94 | score = self.recall_sums[depth] / (query_idx+1.0)
95 | output['recall'][depth] = score
96 |
97 | with open(path, 'w') as f:
98 | ujson.dump(output, f, indent=4)
99 | f.write('\n')
100 |
101 |
102 | def evaluate_recall(qrels, queries, topK_pids):
103 | if qrels is None:
104 | return
105 |
106 | assert set(qrels.keys()) == set(queries.keys())
107 | recall_at_k = [len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) / max(1.0, len(qrels[qid]))
108 | for qid in qrels]
109 | recall_at_k = sum(recall_at_k) / len(qrels)
110 | recall_at_k = round(recall_at_k, 3)
111 | print("Recall @ maximum depth =", recall_at_k)
112 |
113 |
114 | # TODO: If implicit qrels are used (for re-ranking), warn if a recall metric is requested + add an asterisk to output.
115 |
--------------------------------------------------------------------------------
/third_party/colbert/index.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | # TODO: This is the loaded index, underneath a searcher.
4 |
5 |
6 | """
7 | ## Operations:
8 |
9 | index = Index(index='/path/to/index')
10 | index.load_to_memory()
11 |
12 | batch_of_pids = [2324,32432,98743,23432]
13 | index.lookup(batch_of_pids, device='cuda:0') -> (N, doc_maxlen, dim)
14 |
15 | index.iterate_over_parts()
16 |
17 | """
18 |
--------------------------------------------------------------------------------
/third_party/colbert/indexer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 | import torch.multiprocessing as mp
5 |
6 | from third_party.colbert.infra.run import Run
7 | from third_party.colbert.infra.config import ColBERTConfig, RunConfig
8 | from third_party.colbert.infra.launcher import Launcher
9 |
10 | from third_party.colbert.utils.utils import create_directory, print_message
11 |
12 | from third_party.colbert.indexing.collection_indexer import encode
13 | from transformers import AutoConfig
14 | from src.utils import is_debug
15 |
16 | class Indexer:
17 | def __init__(self, checkpoint, config=None, verbose: int = 3):
18 | """
19 | Use Run().context() to choose the run's configuration. They are NOT extracted from `config`.
20 | """
21 |
22 | self.index_path = None
23 | self.verbose = verbose
24 | self.checkpoint = checkpoint
25 | self.checkpoint_config = ColBERTConfig.load_from_checkpoint(checkpoint)
26 | if config is None:
27 | self.checkpoint_config = AutoConfig.from_pretrained(checkpoint)
28 | else:
29 | self.checkpoint_config = config
30 |
31 | self.config = ColBERTConfig.from_existing(self.checkpoint_config, config, Run().config)
32 | self.configure(checkpoint=checkpoint)
33 |
34 | def configure(self, **kw_args):
35 | self.config.configure(**kw_args)
36 |
37 | def get_index(self):
38 | return self.index_path
39 |
40 | def erase(self, force_silent: bool = False):
41 | assert self.index_path is not None
42 | directory = self.index_path
43 | deleted = []
44 |
45 | for filename in sorted(os.listdir(directory)):
46 | filename = os.path.join(directory, filename)
47 |
48 | delete = filename.endswith(".json")
49 | delete = delete and ('metadata' in filename or 'doclen' in filename or 'plan' in filename)
50 | delete = delete or filename.endswith(".pt")
51 |
52 | if delete:
53 | deleted.append(filename)
54 |
55 | if len(deleted):
56 | if not force_silent and not is_debug():
57 | print_message(f"#> Will delete {len(deleted)} files already at {directory} in 20 seconds...")
58 | time.sleep(20)
59 |
60 | for filename in deleted:
61 | os.remove(filename)
62 |
63 | return deleted
64 |
65 | def index(self, name, collection, overwrite=False):
66 | assert overwrite in [True, False, 'reuse', 'resume', "force_silent_overwrite"]
67 |
68 | self.configure(collection=collection, index_name=name, resume=overwrite=='resume')
69 | # Note: The bsize value set here is ignored internally. Users are encouraged
70 | # to supply their own batch size for indexing by using the index_bsize parameter in the ColBERTConfig.
71 | self.configure(bsize=4, partitions=None)
72 |
73 | self.index_path = self.config.index_path_
74 | index_does_not_exist = (not os.path.exists(self.config.index_path_))
75 |
76 | assert (overwrite in [True, 'reuse', 'resume', "force_silent_overwrite"]) or index_does_not_exist, self.config.index_path_
77 | create_directory(self.config.index_path_)
78 |
79 | if overwrite == 'force_silent_overwrite':
80 | self.erase(force_silent=True)
81 | elif overwrite is True:
82 | self.erase()
83 |
84 | if index_does_not_exist or overwrite != 'reuse':
85 | self.__launch(collection)
86 |
87 | return self.index_path
88 |
89 | def __launch(self, collection):
90 | launcher = Launcher(encode)
91 | if self.config.nranks == 1 and self.config.avoid_fork_if_possible:
92 | shared_queues = []
93 | shared_lists = []
94 | launcher.launch_without_fork(self.config, collection, shared_lists, shared_queues, self.verbose)
95 |
96 | return
97 |
98 | manager = mp.Manager()
99 | shared_lists = [manager.list() for _ in range(self.config.nranks)]
100 | shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]
101 |
102 | # Encodes collection into index using the CollectionIndexer class
103 | self.verbose = True
104 | launcher.launch(self.config, collection, shared_lists, shared_queues, self.verbose)
105 |
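A minimal usage sketch for the Indexer above, assuming a local checkpoint directory and a TSV collection (both paths are placeholders, and the exact config fields depend on the checkpoint being loaded):

from third_party.colbert.infra.run import Run
from third_party.colbert.infra.config import RunConfig, ColBERTConfig
from third_party.colbert.indexer import Indexer

if __name__ == '__main__':
    # Run().context() supplies the run-level settings (ranks, experiment name, ...).
    with Run().context(RunConfig(nranks=1, experiment='demo')):
        config = ColBERTConfig(doc_maxlen=300, nbits=2)   # illustrative values
        indexer = Indexer(checkpoint='/path/to/checkpoint', config=config)
        indexer.index(name='demo.index', collection='/path/to/collection.tsv', overwrite=True)
        print(indexer.get_index())                        # resolved index path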
--------------------------------------------------------------------------------
/third_party/colbert/indexing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/indexing/__init__.py
--------------------------------------------------------------------------------
/third_party/colbert/indexing/codecs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/indexing/codecs/__init__.py
--------------------------------------------------------------------------------
/third_party/colbert/indexing/codecs/decompress_residuals.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | torch::Tensor decompress_residuals_cuda(
4 | const torch::Tensor binary_residuals, const torch::Tensor bucket_weights,
5 | const torch::Tensor reversed_bit_map,
6 | const torch::Tensor bucket_weight_combinations, const torch::Tensor codes,
7 | const torch::Tensor centroids, const int dim, const int nbits);
8 |
9 | torch::Tensor decompress_residuals(
10 | const torch::Tensor binary_residuals, const torch::Tensor bucket_weights,
11 | const torch::Tensor reversed_bit_map,
12 | const torch::Tensor bucket_weight_combinations, const torch::Tensor codes,
13 | const torch::Tensor centroids, const int dim, const int nbits) {
14 | // Add input verification
15 | return decompress_residuals_cuda(
16 | binary_residuals, bucket_weights, reversed_bit_map,
17 | bucket_weight_combinations, codes, centroids, dim, nbits);
18 | }
19 |
20 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
21 | m.def("decompress_residuals_cpp", &decompress_residuals,
22 | "Decompress residuals");
23 | }
24 |
--------------------------------------------------------------------------------
/third_party/colbert/indexing/codecs/decompress_residuals.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 | #include <cuda_fp16.h>
5 | #include <assert.h>
6 | #include <stdint.h>
7 |
8 | __global__ void decompress_residuals_kernel(
9 | const uint8_t* binary_residuals,
10 |     const torch::PackedTensorAccessor32<at::Half, 1, torch::RestrictPtrTraits>
11 |         bucket_weights,
12 |     const torch::PackedTensorAccessor32<uint8_t, 1, torch::RestrictPtrTraits>
13 |         reversed_bit_map,
14 |     const torch::PackedTensorAccessor32<uint8_t, 2, torch::RestrictPtrTraits>
15 |         bucket_weight_combinations,
16 |     const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> codes,
17 |     const torch::PackedTensorAccessor32<at::Half, 2, torch::RestrictPtrTraits>
18 |         centroids,
19 | const int n, const int dim, const int nbits, const int packed_size,
20 | at::Half* output) {
21 | const int packed_dim = (int)(dim * nbits / packed_size);
22 | const int i = blockIdx.x;
23 | const int j = threadIdx.x;
24 |
25 | if (i >= n) return;
26 | if (j >= dim * nbits / packed_size) return;
27 |
28 | const int code = codes[i];
29 |
30 | uint8_t x = binary_residuals[i * packed_dim + j];
31 | x = reversed_bit_map[x];
32 | int output_idx = (int)(j * packed_size / nbits);
33 | for (int k = 0; k < packed_size / nbits; k++) {
34 | assert(output_idx < dim);
35 | const int bucket_weight_idx = bucket_weight_combinations[x][k];
36 | output[i * dim + output_idx] = bucket_weights[bucket_weight_idx];
37 | output[i * dim + output_idx] += centroids[code][output_idx];
38 | output_idx++;
39 | }
40 | }
41 |
42 | torch::Tensor decompress_residuals_cuda(
43 | const torch::Tensor binary_residuals, const torch::Tensor bucket_weights,
44 | const torch::Tensor reversed_bit_map,
45 | const torch::Tensor bucket_weight_combinations, const torch::Tensor codes,
46 | const torch::Tensor centroids, const int dim, const int nbits) {
47 | auto options = torch::TensorOptions()
48 | .dtype(torch::kFloat16)
49 | .device(torch::kCUDA, 0)
50 | .requires_grad(false);
51 | torch::Tensor output =
52 | torch::zeros({(int)binary_residuals.size(0), (int)dim}, options);
53 |
54 | // TODO: Set this automatically?
55 | const int packed_size = 8;
56 |
57 | const int threads = dim / (packed_size / nbits);
58 | const int blocks =
59 | (binary_residuals.size(0) * binary_residuals.size(1)) / threads;
60 |
61 |     decompress_residuals_kernel<<<blocks, threads>>>(
62 |         binary_residuals.data<uint8_t>(),
63 |         bucket_weights
64 |             .packed_accessor32<at::Half, 1, torch::RestrictPtrTraits>(),
65 |         reversed_bit_map
66 |             .packed_accessor32<uint8_t, 1, torch::RestrictPtrTraits>(),
67 |         bucket_weight_combinations
68 |             .packed_accessor32<uint8_t, 2, torch::RestrictPtrTraits>(),
69 |         codes.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
70 |         centroids.packed_accessor32<at::Half, 2, torch::RestrictPtrTraits>(),
71 |         binary_residuals.size(0), dim, nbits, packed_size,
72 |         output.data<at::Half>());
73 |
74 | return output;
75 | }
76 |
--------------------------------------------------------------------------------
/third_party/colbert/indexing/codecs/packbits.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | torch::Tensor packbits_cuda(const torch::Tensor residuals);
4 |
5 | torch::Tensor packbits(const torch::Tensor residuals) {
6 | return packbits_cuda(residuals);
7 | }
8 |
9 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
10 | m.def("packbits_cpp", &packbits, "Pack bits");
11 | }
12 |
13 |
--------------------------------------------------------------------------------
/third_party/colbert/indexing/codecs/packbits.cu:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <cuda.h>
3 | #include <cuda_runtime.h>
4 | #include <cuda_fp16.h>
5 | #include <assert.h>
6 | #include <stdint.h>
7 |
8 | #define FULL_MASK 0xffffffff
9 |
10 | __global__ void packbits_kernel(
11 | const uint8_t* residuals,
12 | uint8_t* packed_residuals,
13 | const int residuals_size) {
14 | const int i = blockIdx.x;
15 | const int j = threadIdx.x;
16 |
17 | assert(blockDim.x == 32);
18 |
19 | const int residuals_idx = i * blockDim.x + j;
20 | if (residuals_idx >= residuals_size) {
21 | return;
22 | }
23 |
24 | const int packed_residuals_idx = residuals_idx / 8;
25 |
26 |
27 | uint32_t mask = __ballot_sync(FULL_MASK, residuals[residuals_idx]);
28 |
29 | mask = __brev(mask);
30 |
31 | if (residuals_idx % 32 == 0) {
32 | for (int k = 0; k < 4; k++) {
33 | packed_residuals[packed_residuals_idx + k] =
34 | (mask >> (8 * (4 - k - 1))) & 0xff;
35 | }
36 | }
37 | }
38 |
39 | torch::Tensor packbits_cuda(const torch::Tensor residuals) {
40 | auto options = torch::TensorOptions()
41 | .dtype(torch::kUInt8)
42 | .device(torch::kCUDA, residuals.device().index())
43 | .requires_grad(false);
44 | assert(residuals.size(0) % 32 == 0);
45 | torch::Tensor packed_residuals = torch::zeros({int(residuals.size(0) / 8)}, options);
46 |
47 | const int threads = 32;
48 | const int blocks = std::ceil(residuals.size(0) / (float) threads);
49 |
50 |     packbits_kernel<<<blocks, threads>>>(
51 |         residuals.data<uint8_t>(),
52 |         packed_residuals.data<uint8_t>(),
53 |         residuals.size(0)
54 |     );
55 |
56 | return packed_residuals;
57 | }
58 |
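For intuition, a rough CPU analogue of the kernel above (not part of the codebase): each group of 8 zero/one residual bytes is packed into one output byte, most significant bit first, which matches numpy's default big-endian bit packing.

import numpy as np

residuals = np.array([1, 0, 0, 0, 0, 0, 0, 1] * 4, dtype=np.uint8)  # 32 binary entries
packed = np.packbits(residuals)   # -> array([129, 129, 129, 129], dtype=uint8)
assert packed.size == residuals.size // 8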
--------------------------------------------------------------------------------
/third_party/colbert/indexing/codecs/residual_embeddings_strided.py:
--------------------------------------------------------------------------------
1 | # from third_party.colbert.indexing.codecs.residual import ResidualCodec
2 | import third_party.colbert.indexing.codecs.residual_embeddings as residual_embeddings
3 |
4 | from third_party.colbert.search.strided_tensor import StridedTensor
5 |
6 | class ResidualEmbeddingsStrided:
7 | def __init__(self, codec, embeddings, doclens):
8 | self.codec = codec
9 | self.codes = embeddings.codes
10 | self.residuals = embeddings.residuals
11 | self.use_gpu = self.codec.use_gpu
12 |
13 | self.codes_strided = StridedTensor(self.codes, doclens, use_gpu=self.use_gpu)
14 | self.residuals_strided = StridedTensor(self.residuals, doclens, use_gpu=self.use_gpu)
15 |
16 | def lookup_pids(self, passage_ids, out_device='cuda'):
17 | codes_packed, codes_lengths = self.codes_strided.lookup(passage_ids)#.as_packed_tensor()
18 | residuals_packed, _ = self.residuals_strided.lookup(passage_ids)#.as_packed_tensor()
19 |
20 | embeddings_packed = self.codec.decompress(residual_embeddings.ResidualEmbeddings(codes_packed, residuals_packed))
21 |
22 | return embeddings_packed, codes_lengths
23 |
24 | def lookup_codes(self, passage_ids):
25 | return self.codes_strided.lookup(passage_ids)#.as_packed_tensor()
26 |
--------------------------------------------------------------------------------
/third_party/colbert/indexing/collection_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from third_party.colbert.infra.run import Run
4 | from third_party.colbert.utils.utils import print_message, batch
5 |
6 |
7 | class CollectionEncoder:
8 | def __init__(self, config, checkpoint):
9 | self.config = config
10 | self.checkpoint = checkpoint
11 | self.use_gpu = self.config.total_visible_gpus > 0
12 |
13 | def encode_passages(self, passages):
14 | Run().print(f"#> Encoding {len(passages)} passages..")
15 |
16 | if len(passages) == 0:
17 | return None, None
18 |
19 | with torch.inference_mode():
20 | embs, doclens = [], []
21 |
22 | # Batch here to avoid OOM from storing intermediate embeddings on GPU.
23 | # Storing on the GPU helps with speed of masking, etc.
24 | # But ideally this batching happens internally inside docFromText.
25 | for passage_batch in batch(passages, self.config.index_bsize):
26 | embs_, doclens_ = self.checkpoint.docFromText(passage_batch)
27 | embs.append(embs_)
28 | doclens.extend(doclens_)
29 |
30 | embs = torch.cat(embs)
31 |
32 | return embs, doclens
33 |
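A small sketch of the batching loop above, assuming utils.batch() simply yields consecutive slices of at most index_bsize passages (the passages and batch size here are made up):

from third_party.colbert.utils.utils import batch

passages = [f"passage {i}" for i in range(10)]
for passage_batch in batch(passages, 4):
    print(len(passage_batch))   # 4, 4, 2 -- each slice is encoded, then the embeddings are concatenated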
--------------------------------------------------------------------------------
/third_party/colbert/indexing/index_manager.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | from bitarray import bitarray
5 |
6 |
7 | class IndexManager():
8 | def __init__(self, dim):
9 | self.dim = dim
10 |
11 | def save(self, tensor, path_prefix):
12 | torch.save(tensor, path_prefix)
13 |
14 | def save_bitarray(self, bitarray, path_prefix):
15 | with open(path_prefix, "wb") as f:
16 | bitarray.tofile(f)
17 |
18 |
19 | def load_index_part(filename, verbose=True):
20 | part = torch.load(filename)
21 |
22 | if type(part) == list: # for backward compatibility
23 | part = torch.cat(part)
24 |
25 | return part
26 |
27 |
28 | def load_compressed_index_part(filename, dim, bits):
29 | a = bitarray()
30 |
31 | with open(filename, "rb") as f:
32 | a.fromfile(f)
33 |
34 | n = len(a) // dim // bits
35 | part = torch.tensor(np.frombuffer(a.tobytes(), dtype=np.uint8)) # TODO: isn't from_numpy(.) faster?
36 | part = part.reshape((n, int(np.ceil(dim * bits / 8))))
37 |
38 | return part
39 |
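A worked example of the shape arithmetic in load_compressed_index_part (the numbers are illustrative):

import numpy as np

dim, bits = 128, 2
bits_in_file = 4096 * 8                      # a 4096-byte file holds 32768 bits
n = bits_in_file // dim // bits              # -> 128 embeddings
row_bytes = int(np.ceil(dim * bits / 8))     # -> 32 packed bytes per embedding
assert (n, row_bytes) == (128, 32)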
--------------------------------------------------------------------------------
/third_party/colbert/indexing/index_saver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import queue
3 | import ujson
4 | import threading
5 |
6 | from contextlib import contextmanager
7 |
8 | from third_party.colbert.indexing.codecs.residual import ResidualCodec
9 |
10 | from third_party.colbert.utils.utils import print_message
11 |
12 |
13 | class IndexSaver():
14 | def __init__(self, config):
15 | self.config = config
16 |
17 | def save_codec(self, codec):
18 | codec.save(index_path=self.config.index_path_)
19 |
20 | def load_codec(self):
21 | return ResidualCodec.load(index_path=self.config.index_path_)
22 |
23 | def try_load_codec(self):
24 | try:
25 | ResidualCodec.load(index_path=self.config.index_path_)
26 | return True
27 | except Exception as e:
28 | return False
29 |
30 | def check_chunk_exists(self, chunk_idx):
31 | # TODO: Verify that the chunk has the right amount of data?
32 |
33 | doclens_path = os.path.join(self.config.index_path_, f'doclens.{chunk_idx}.json')
34 | if not os.path.exists(doclens_path):
35 | return False
36 |
37 | metadata_path = os.path.join(self.config.index_path_, f'{chunk_idx}.metadata.json')
38 | if not os.path.exists(metadata_path):
39 | return False
40 |
41 | path_prefix = os.path.join(self.config.index_path_, str(chunk_idx))
42 | codes_path = f'{path_prefix}.codes.pt'
43 | if not os.path.exists(codes_path):
44 | return False
45 |
46 | residuals_path = f'{path_prefix}.residuals.pt' # f'{path_prefix}.residuals.bn'
47 | if not os.path.exists(residuals_path):
48 | return False
49 |
50 | return True
51 |
52 | @contextmanager
53 | def thread(self):
54 | self.codec = self.load_codec()
55 |
56 | self.saver_queue = queue.Queue(maxsize=3)
57 | thread = threading.Thread(target=self._saver_thread)
58 | thread.start()
59 |
60 | try:
61 | yield
62 |
63 | finally:
64 | self.saver_queue.put(None)
65 | thread.join()
66 |
67 | del self.saver_queue
68 | del self.codec
69 |
70 | def save_chunk(self, chunk_idx, offset, embs, doclens):
71 | compressed_embs = self.codec.compress(embs)
72 |
73 | self.saver_queue.put((chunk_idx, offset, compressed_embs, doclens))
74 |
75 | def _saver_thread(self):
76 | for args in iter(self.saver_queue.get, None):
77 | self._write_chunk_to_disk(*args)
78 |
79 | def _write_chunk_to_disk(self, chunk_idx, offset, compressed_embs, doclens):
80 | path_prefix = os.path.join(self.config.index_path_, str(chunk_idx))
81 | compressed_embs.save(path_prefix)
82 |
83 | doclens_path = os.path.join(self.config.index_path_, f'doclens.{chunk_idx}.json')
84 | with open(doclens_path, 'w') as output_doclens:
85 | ujson.dump(doclens, output_doclens)
86 |
87 | metadata_path = os.path.join(self.config.index_path_, f'{chunk_idx}.metadata.json')
88 | with open(metadata_path, 'w') as output_metadata:
89 | metadata = {'passage_offset': offset, 'num_passages': len(doclens), 'num_embeddings': len(compressed_embs)}
90 | ujson.dump(metadata, output_metadata)
91 |
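A hedged sketch of how a caller might drive IndexSaver: thread() loads the codec and starts a background writer, save_chunk() compresses and enqueues, and leaving the context enqueues the None sentinel and joins the thread. `config` and `chunks` below are placeholders.

from third_party.colbert.indexing.index_saver import IndexSaver

saver = IndexSaver(config)                   # `config` must expose index_path_

with saver.thread():
    for chunk_idx, (offset, embs, doclens) in enumerate(chunks):
        saver.save_chunk(chunk_idx, offset, embs, doclens)   # compress + enqueue
# on exit: sentinel pushed, writer thread joined, codec and queue released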
--------------------------------------------------------------------------------
/third_party/colbert/indexing/loaders.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import ujson
4 |
5 |
6 | def get_parts(directory):
7 | extension = '.pt'
8 |
9 | parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory)
10 | if filename.endswith(extension)])
11 |
12 | assert list(range(len(parts))) == parts, parts
13 |
14 | # Integer-sortedness matters.
15 | parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts]
16 | samples_paths = [os.path.join(directory, '{}.sample'.format(filename)) for filename in parts]
17 |
18 | return parts, parts_paths, samples_paths
19 |
20 |
21 | def load_doclens(directory, flatten=True):
22 | doclens_filenames = {}
23 |
24 | for filename in os.listdir(directory):
25 |         match = re.match(r"doclens.(\d+).json", filename)
26 |
27 | if match is not None:
28 | doclens_filenames[int(match.group(1))] = filename
29 |
30 | doclens_filenames = [os.path.join(directory, doclens_filenames[i]) for i in sorted(doclens_filenames.keys())]
31 |
32 | all_doclens = [ujson.load(open(filename)) for filename in doclens_filenames]
33 |
34 | if flatten:
35 | all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens]
36 |
37 | if len(all_doclens) == 0:
38 | raise ValueError("Could not load doclens")
39 |
40 | return all_doclens
41 |
42 |
43 | def get_deltas(directory):
44 | extension = '.residuals.pt'
45 |
46 | parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory)
47 | if filename.endswith(extension)])
48 |
49 | assert list(range(len(parts))) == parts, parts
50 |
51 | # Integer-sortedness matters.
52 | parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts]
53 |
54 | return parts, parts_paths
55 |
56 |
57 | # def load_compression_data(level, path):
58 | # with open(path, "r") as f:
59 | # for line in f:
60 | # line = line.split(',')
61 | # bits = int(line[0])
62 |
63 | # if bits == level:
64 | # return [float(v) for v in line[1:]]
65 |
66 | # raise ValueError(f"No data found for {level}-bit compression")
67 |
--------------------------------------------------------------------------------
/third_party/colbert/indexing/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import tqdm
4 |
5 | from third_party.colbert.indexing.loaders import load_doclens
6 | from third_party.colbert.utils.utils import print_message, flatten
7 |
8 | def optimize_ivf(orig_ivf, orig_ivf_lengths, index_path, verbose:int=3):
9 | if verbose > 1:
10 | print_message("#> Optimizing IVF to store map from centroids to list of pids..")
11 |
12 | print_message("#> Building the emb2pid mapping..")
13 | all_doclens = load_doclens(index_path, flatten=False)
14 |
15 | # assert self.num_embeddings == sum(flatten(all_doclens))
16 |
17 | all_doclens = flatten(all_doclens)
18 | total_num_embeddings = sum(all_doclens)
19 |
20 | emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)
21 |
22 | """
23 | EVENTUALLY: Use two tensors. emb2pid_offsets will have every 256th element.
24 |     emb2pid_delta will have the delta from the corresponding offset.
25 | """
26 |
27 | offset_doclens = 0
28 | for pid, dlength in enumerate(all_doclens):
29 | emb2pid[offset_doclens: offset_doclens + dlength] = pid
30 | offset_doclens += dlength
31 |
32 | if verbose > 1:
33 | print_message("len(emb2pid) =", len(emb2pid))
34 |
35 | ivf = emb2pid[orig_ivf]
36 | unique_pids_per_centroid = []
37 | ivf_lengths = []
38 |
39 | offset = 0
40 | for length in tqdm.tqdm(orig_ivf_lengths.tolist()):
41 | pids = torch.unique(ivf[offset:offset+length])
42 | unique_pids_per_centroid.append(pids)
43 | ivf_lengths.append(pids.shape[0])
44 | offset += length
45 | ivf = torch.cat(unique_pids_per_centroid)
46 | ivf_lengths = torch.tensor(ivf_lengths)
47 |
48 | max_stride = ivf_lengths.max().item()
49 | zero = torch.zeros(1, dtype=torch.long, device=ivf_lengths.device)
50 | offsets = torch.cat((zero, torch.cumsum(ivf_lengths, dim=0)))
51 | inner_dims = ivf.size()[1:]
52 |
53 | if offsets[-2] + max_stride > ivf.size(0):
54 | padding = torch.zeros(max_stride, *inner_dims, dtype=ivf.dtype, device=ivf.device)
55 | ivf = torch.cat((ivf, padding))
56 |
57 | original_ivf_path = os.path.join(index_path, 'ivf.pt')
58 | optimized_ivf_path = os.path.join(index_path, 'ivf.pid.pt')
59 | torch.save((ivf, ivf_lengths), optimized_ivf_path)
60 | if verbose > 1:
61 | print_message(f"#> Saved optimized IVF to {optimized_ivf_path}")
62 | if os.path.exists(original_ivf_path):
63 | print_message(f"#> Original IVF at path \"{original_ivf_path}\" can now be removed")
64 |
65 | return ivf, ivf_lengths
66 |
67 |
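A toy walk-through of the emb2pid expansion above (the doclens are made up): every embedding position is labeled with the pid that owns it, so an IVF list of embedding ids can be remapped to deduplicated pids.

import torch

all_doclens = [3, 2]                       # passage 0 has 3 embeddings, passage 1 has 2
emb2pid = torch.zeros(sum(all_doclens), dtype=torch.int)
offset = 0
for pid, dlength in enumerate(all_doclens):
    emb2pid[offset: offset + dlength] = pid
    offset += dlength

print(emb2pid.tolist())                    # [0, 0, 0, 1, 1]
print(torch.unique(emb2pid[torch.tensor([0, 3, 4])]).tolist())   # IVF eids [0, 3, 4] -> pids [0, 1]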
--------------------------------------------------------------------------------
/third_party/colbert/infra/__init__.py:
--------------------------------------------------------------------------------
1 | from .run import *
2 | from .config import *
--------------------------------------------------------------------------------
/third_party/colbert/infra/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import *
2 | from .settings import *
--------------------------------------------------------------------------------
/third_party/colbert/infra/config/base_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import ujson
4 | from huggingface_hub import hf_hub_download
5 | from huggingface_hub.utils import RepositoryNotFoundError
6 | import dataclasses
7 |
8 | from typing import Any
9 | from collections import defaultdict
10 | from dataclasses import dataclass, fields
11 | from third_party.colbert.utils.utils import timestamp, torch_load_dnn
12 |
13 | from third_party.utility.utils.save_metadata import get_metadata_only
14 | from .core_config import *
15 |
16 |
17 | @dataclass
18 | class BaseConfig(CoreConfig):
19 | @classmethod
20 | def from_existing(cls, *sources):
21 | kw_args = {}
22 |
23 | for source in sources:
24 | if source is None:
25 | continue
26 |
27 | local_kw_args = dataclasses.asdict(source)
28 | local_kw_args = {k: local_kw_args[k] for k in source.assigned}
29 | kw_args = {**kw_args, **local_kw_args}
30 |
31 | obj = cls(**kw_args)
32 |
33 | return obj
34 |
35 | @classmethod
36 | def from_deprecated_args(cls, args):
37 | obj = cls()
38 | ignored = obj.configure(ignore_unrecognized=True, **args)
39 |
40 | return obj, ignored
41 |
42 | @classmethod
43 | def from_path(cls, name):
44 | with open(name) as f:
45 | args = ujson.load(f)
46 |
47 | if "config" in args:
48 | args = args["config"]
49 |
50 | return cls.from_deprecated_args(
51 | args
52 | ) # the new, non-deprecated version functions the same at this level.
53 |
54 | @classmethod
55 | def load_from_checkpoint(cls, checkpoint_path):
56 | if checkpoint_path.endswith(".dnn"):
57 | dnn = torch_load_dnn(checkpoint_path)
58 | config, _ = cls.from_deprecated_args(dnn.get("arguments", {}))
59 |
60 | # TODO: FIXME: Decide if the line below will have any unintended consequences. We don't want to overwrite those!
61 | config.set("checkpoint", checkpoint_path)
62 |
63 | return config
64 |
65 | try:
66 | checkpoint_path = hf_hub_download(
67 | repo_id=checkpoint_path, filename="artifact.metadata"
68 | ).split("artifact")[0]
69 | except Exception:
70 | pass
71 | loaded_config_path = os.path.join(checkpoint_path, "artifact.metadata")
72 | if os.path.exists(loaded_config_path):
73 | loaded_config, _ = cls.from_path(loaded_config_path)
74 | loaded_config.set("checkpoint", checkpoint_path)
75 |
76 | return loaded_config
77 |
78 | return (
79 | None # can happen if checkpoint_path is something like 'bert-base-uncased'
80 | )
81 |
82 | @classmethod
83 | def load_from_index(cls, index_path):
84 | # FIXME: We should start here with initial_config = ColBERTConfig(config, Run().config).
85 | # This should allow us to say initial_config.index_root. Then, below, set config = Config(..., initial_c)
86 |
87 | # default_index_root = os.path.join(Run().root, Run().experiment, 'indexes/')
88 | # index_path = os.path.join(default_index_root, index_path)
89 |
90 | # CONSIDER: No more plan/metadata.json. Only metadata.json to avoid weird issues when loading an index.
91 |
92 | try:
93 | metadata_path = os.path.join(index_path, "metadata.json")
94 | loaded_config, _ = cls.from_path(metadata_path)
95 | except:
96 | metadata_path = os.path.join(index_path, "plan.json")
97 | loaded_config, _ = cls.from_path(metadata_path)
98 |
99 | return loaded_config
100 |
101 | def save(self, path, overwrite=False):
102 | assert overwrite or not os.path.exists(path), path
103 |
104 | with open(path, "w") as f:
105 | args = self.export() # dict(self.__config)
106 | args["meta"] = get_metadata_only()
107 | args["meta"]["version"] = "colbert-v0.4"
108 | # TODO: Add git_status details.. It can't be too large! It should be a path that Runs() saves on exit, maybe!
109 |
110 | f.write(ujson.dumps(args, indent=4) + "\n")
111 |
112 | def save_for_checkpoint(self, checkpoint_path):
113 | assert not checkpoint_path.endswith(
114 | ".dnn"
115 | ), f"{checkpoint_path}: We reserve *.dnn names for the deprecated checkpoint format."
116 |
117 | output_config_path = os.path.join(checkpoint_path, "artifact.metadata")
118 | self.save(output_config_path, overwrite=True)
119 |
--------------------------------------------------------------------------------
/third_party/colbert/infra/config/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from .base_config import BaseConfig
3 | from .settings import *
4 | from .core_config import DefaultVal
5 |
6 |
7 | @dataclass
8 | class RunConfig(BaseConfig, RunSettings):
9 | pass
10 |
11 |
12 | @dataclass
13 | class ColBERTConfig(RunSettings, ResourceSettings, DocSettings, QuerySettings, TrainingSettings,
14 | IndexingSettings, SearchSettings, BaseConfig, TokenizerSettings):
15 | checkpoint_path: str = DefaultVal(None)
16 |
--------------------------------------------------------------------------------
/third_party/colbert/infra/config/core_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import ujson
4 | import dataclasses
5 |
6 | from typing import Any
7 | from collections import defaultdict
8 | from dataclasses import dataclass, fields
9 | from third_party.colbert.utils.utils import timestamp, torch_load_dnn
10 |
11 | from third_party.utility.utils.save_metadata import get_metadata_only
12 |
13 |
14 | @dataclass
15 | class DefaultVal:
16 | val: Any
17 |
18 | def __hash__(self):
19 | return hash(repr(self.val))
20 |
21 | def __eq__(self, other):
22 |         return self.val == other.val
23 |
24 | @dataclass
25 | class CoreConfig:
26 | def __post_init__(self):
27 | """
28 | Source: https://stackoverflow.com/a/58081120/1493011
29 | """
30 |
31 | self.assigned = {}
32 |
33 | for field in fields(self):
34 | field_val = getattr(self, field.name)
35 |
36 | if isinstance(field_val, DefaultVal) or field_val is None:
37 | setattr(self, field.name, field.default.val)
38 |
39 | if not isinstance(field_val, DefaultVal):
40 | self.assigned[field.name] = True
41 |
42 | def assign_defaults(self):
43 | for field in fields(self):
44 | setattr(self, field.name, field.default.val)
45 | self.assigned[field.name] = True
46 |
47 | def configure(self, ignore_unrecognized=True, **kw_args):
48 | ignored = set()
49 |
50 | for key, value in kw_args.items():
51 | self.set(key, value, ignore_unrecognized) or ignored.update({key})
52 |
53 | return ignored
54 |
55 | """
56 | # TODO: Take a config object, not kw_args.
57 |
58 | for key in config.assigned:
59 | value = getattr(config, key)
60 | """
61 |
62 | def set(self, key, value, ignore_unrecognized=False):
63 | if hasattr(self, key):
64 | setattr(self, key, value)
65 | self.assigned[key] = True
66 | return True
67 |
68 | if not ignore_unrecognized:
69 | raise Exception(f"Unrecognized key `{key}` for {type(self)}")
70 |
71 | def help(self):
72 | print(ujson.dumps(self.export(), indent=4))
73 |
74 | def __export_value(self, v):
75 | v = v.provenance() if hasattr(v, 'provenance') else v
76 |
77 | if isinstance(v, list) and len(v) > 100:
78 | v = (f"list with {len(v)} elements starting with...", v[:3])
79 |
80 | if isinstance(v, dict) and len(v) > 100:
81 | v = (f"dict with {len(v)} keys starting with...", list(v.keys())[:3])
82 |
83 | return v
84 |
85 | def export(self):
86 | d = dataclasses.asdict(self)
87 |
88 | for k, v in d.items():
89 | d[k] = self.__export_value(v)
90 |
91 | return d
92 |
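An illustrative sketch of the DefaultVal / assigned-field mechanism above; ToyConfig is hypothetical and exists only for this example.

from dataclasses import dataclass
from third_party.colbert.infra.config.core_config import CoreConfig, DefaultVal

@dataclass
class ToyConfig(CoreConfig):
    bsize: int = DefaultVal(32)
    nbits: int = DefaultVal(2)

cfg = ToyConfig(nbits=4)
# __post_init__ replaced the DefaultVal placeholders with their .val payloads...
assert (cfg.bsize, cfg.nbits) == (32, 4)
# ...and recorded which fields were explicitly assigned by the caller.
assert cfg.assigned == {'nbits': True}

cfg.configure(bsize=8)                     # set() marks bsize as assigned too
assert cfg.assigned == {'nbits': True, 'bsize': True}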
--------------------------------------------------------------------------------
/third_party/colbert/infra/provenance.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import traceback
3 | import inspect
4 |
5 |
6 | class Provenance:
7 | def __init__(self) -> None:
8 | self.initial_stacktrace = self.stacktrace()
9 |
10 | def stacktrace(self):
11 | trace = inspect.stack()
12 | output = []
13 |
14 | for frame in trace[2:-1]:
15 | try:
16 | frame = f'{frame.filename}:{frame.lineno}:{frame.function}: {frame.code_context[0].strip()}'
17 | output.append(frame)
18 | except:
19 | output.append(None)
20 |
21 | return output
22 |
23 | def toDict(self): # for ujson
24 | self.serialization_stacktrace = self.stacktrace()
25 | return dict(self.__dict__)
26 |
27 |
28 | if __name__ == '__main__':
29 | p = Provenance()
30 | print(p.toDict().keys())
31 |
32 | import ujson
33 | print(ujson.dumps(p, indent=4))
34 |
35 |
36 | class X:
37 | def __init__(self) -> None:
38 | pass
39 |
40 | def toDict(self):
41 | return {'key': 1}
42 |
43 | print(ujson.dumps(X()))
--------------------------------------------------------------------------------
/third_party/colbert/infra/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import atexit
3 |
4 | from third_party.colbert.utils.utils import create_directory, print_message, timestamp
5 | from contextlib import contextmanager
6 |
7 | from third_party.colbert.infra.config import RunConfig
8 |
9 |
10 | class Run(object):
11 | _instance = None
12 |
13 | os.environ["TOKENIZERS_PARALLELISM"] = "true" # NOTE: If a deadlock arises, switch to false!!
14 |
15 | def __new__(cls):
16 | """
17 | Singleton Pattern. See https://python-patterns.guide/gang-of-four/singleton/
18 | """
19 | if cls._instance is None:
20 | cls._instance = super().__new__(cls)
21 | cls._instance.stack = []
22 |
23 | # TODO: Save a timestamp here! And re-use it! But allow the user to override it on calling Run().context a second time.
24 | run_config = RunConfig()
25 | run_config.assign_defaults()
26 |
27 | cls._instance.__append(run_config)
28 |
29 | # TODO: atexit.register(all_done)
30 |
31 | return cls._instance
32 |
33 | @property
34 | def config(self):
35 | return self.stack[-1]
36 |
37 | def __getattr__(self, name):
38 | if hasattr(self.config, name):
39 | return getattr(self.config, name)
40 |
41 | super().__getattr__(name)
42 |
43 | def __append(self, runconfig: RunConfig):
44 | # runconfig.disallow_writes(readonly=True)
45 | self.stack.append(runconfig)
46 |
47 | def __pop(self):
48 | self.stack.pop()
49 |
50 | @contextmanager
51 | def context(self, runconfig: RunConfig, inherit_config=True):
52 | if inherit_config:
53 | runconfig = RunConfig.from_existing(self.config, runconfig)
54 |
55 | self.__append(runconfig)
56 |
57 | try:
58 | yield
59 | finally:
60 | self.__pop()
61 |
62 | def open(self, path, mode='r'):
63 | path = os.path.join(self.path_, path)
64 |
65 | if not os.path.exists(self.path_):
66 | create_directory(self.path_)
67 |
68 | if ('w' in mode or 'a' in mode) and not self.overwrite:
69 | assert not os.path.exists(path), (self.overwrite, path)
70 |
71 | # create directory if it doesn't exist
72 | os.makedirs(os.path.dirname(path), exist_ok=True)
73 |
74 | return open(path, mode=mode)
75 |
76 | def print(self, *args):
77 | print_message("[" + str(self.rank) + "]", "\t\t", *args)
78 |
79 | def print_main(self, *args):
80 | if self.rank == 0:
81 | self.print(*args)
82 |
83 |
84 | if __name__ == '__main__':
85 | print(Run().root, '!')
86 |
87 | with Run().context(RunConfig(rank=0, nranks=1)):
88 | with Run().context(RunConfig(experiment='newproject')):
89 | print(Run().nranks, '!')
90 |
91 | print(Run().config, '!')
92 | print(Run().rank)
93 |
94 |
95 | # TODO: Handle logging all prints to a file. There should be a way to determine the level of logs that go to stdout.
--------------------------------------------------------------------------------
/third_party/colbert/infra/utilities/annotate_em.py:
--------------------------------------------------------------------------------
1 |
2 | from third_party.colbert.infra.run import Run
3 | from third_party.colbert.data.collection import Collection
4 | import os
5 | import sys
6 | import git
7 | import tqdm
8 | import ujson
9 | import random
10 |
11 | from argparse import ArgumentParser
12 | from multiprocessing import Pool
13 |
14 | from third_party.colbert.utils.utils import groupby_first_item, print_message
15 | from third_party.utility.utils.qa_loaders import load_qas_, load_collection_
16 | from third_party.utility.utils.save_metadata import format_metadata, get_metadata
17 | from third_party.utility.evaluate.annotate_EM_helpers import *
18 |
19 | from third_party.colbert.data.ranking import Ranking
20 |
21 |
22 | class AnnotateEM:
23 | def __init__(self, collection, qas):
24 | # TODO: These should just be Queries! But Queries needs to support looking up answers as qid2answers below.
25 | qas = load_qas_(qas)
26 | collection = Collection.cast(collection) # .tolist() #load_collection_(collection, retain_titles=True)
27 |
28 | self.parallel_pool = Pool(30)
29 |
30 | print_message('#> Tokenize the answers in the Q&As in parallel...')
31 | qas = list(self.parallel_pool.map(tokenize_all_answers, qas))
32 |
33 | qid2answers = {qid: tok_answers for qid, _, tok_answers in qas}
34 | assert len(qas) == len(qid2answers), (len(qas), len(qid2answers))
35 |
36 | self.qas, self.collection = qas, collection
37 | self.qid2answers = qid2answers
38 |
39 | def annotate(self, ranking):
40 | rankings = Ranking.cast(ranking)
41 |
42 | # print(len(rankings), rankings[0])
43 |
44 | print_message('#> Lookup passages from PIDs...')
45 | expanded_rankings = [(qid, pid, rank, self.collection[pid], self.qid2answers[qid])
46 | for qid, pid, rank, *_ in rankings.tolist()]
47 |
48 | print_message('#> Assign labels in parallel...')
49 | labeled_rankings = list(self.parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings)))
50 |
51 | # Dump output.
52 | self.qid2rankings = groupby_first_item(labeled_rankings)
53 |
54 | self.num_judged_queries, self.num_ranked_queries = check_sizes(self.qid2answers, self.qid2rankings)
55 |
56 | # Evaluation metrics and depths.
57 | self.success, self.counts = self._compute_labels(self.qid2answers, self.qid2rankings)
58 |
59 | print(rankings.provenance(), self.success)
60 |
61 | return Ranking(data=self.qid2rankings, provenance=("AnnotateEM", rankings.provenance()))
62 |
63 | def _compute_labels(self, qid2answers, qid2rankings):
64 | cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all']
65 | success = {cutoff: 0.0 for cutoff in cutoffs}
66 | counts = {cutoff: 0.0 for cutoff in cutoffs}
67 |
68 | for qid in qid2answers:
69 | if qid not in qid2rankings:
70 | continue
71 |
72 |             prev_rank = 0  # ranks should start at one (i.e., not zero)
73 | labels = []
74 |
75 | for pid, rank, label in qid2rankings[qid]:
76 | assert rank == prev_rank+1, (qid, pid, (prev_rank, rank))
77 | prev_rank = rank
78 |
79 | labels.append(label)
80 |
81 | for cutoff in cutoffs:
82 | if cutoff != 'all':
83 | success[cutoff] += sum(labels[:cutoff]) > 0
84 | counts[cutoff] += sum(labels[:cutoff])
85 | else:
86 | success[cutoff] += sum(labels) > 0
87 | counts[cutoff] += sum(labels)
88 |
89 | return success, counts
90 |
91 | def save(self, new_path):
92 | print_message("#> Dumping output to", new_path, "...")
93 |
94 | Ranking(data=self.qid2rankings).save(new_path)
95 |
96 | # Dump metrics.
97 | with Run().open(f'{new_path}.metrics', 'w') as f:
98 | d = {'num_ranked_queries': self.num_ranked_queries, 'num_judged_queries': self.num_judged_queries}
99 |
100 | extra = '__WARNING' if self.num_judged_queries != self.num_ranked_queries else ''
101 | d[f'success{extra}'] = {k: v / self.num_judged_queries for k, v in self.success.items()}
102 | d[f'counts{extra}'] = {k: v / self.num_judged_queries for k, v in self.counts.items()}
103 | # d['arguments'] = get_metadata(args) # TODO: Need arguments...
104 |
105 | f.write(format_metadata(d) + '\n')
106 |
107 |
108 | if __name__ == '__main__':
109 | r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.50.02/ranking.tsv'
110 | r = '/future/u/okhattab/root/unit/experiments/2021.08/retrieve.py/2021-09-04_15.59.37/ranking.tsv'
111 | r = sys.argv[1]
112 |
113 | a = AnnotateEM(collection='/future/u/okhattab/root/unit/data/NQ-mini/collection.tsv',
114 | qas='/future/u/okhattab/root/unit/data/NQ-mini/dev/qas.json')
115 | a.annotate(ranking=r)
116 |
--------------------------------------------------------------------------------
/third_party/colbert/infra/utilities/create_triples.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | from third_party.colbert.utils.utils import print_message
4 | from third_party.utility.utils.save_metadata import save_metadata
5 | from third_party.utility.supervision.triples import sample_for_query
6 |
7 | from third_party.colbert.data.ranking import Ranking
8 | from third_party.colbert.data.examples import Examples
9 |
10 | MAX_NUM_TRIPLES = 40_000_000
11 |
12 |
13 | class Triples:
14 | def __init__(self, ranking, seed=12345):
15 | random.seed(seed) # TODO: Use internal RNG instead..
16 | self.qid2rankings = Ranking.cast(ranking).todict()
17 |
18 | def create(self, positives, depth):
19 | assert all(len(x) == 2 for x in positives)
20 | assert all(maxBest <= maxDepth for maxBest, maxDepth in positives), positives
21 |
22 | Triples = []
23 | NonEmptyQIDs = 0
24 |
25 | for processing_idx, qid in enumerate(self.qid2rankings):
26 | l = sample_for_query(qid, self.qid2rankings[qid], positives, depth, False, None)
27 | NonEmptyQIDs += (len(l) > 0)
28 | Triples.extend(l)
29 |
30 | if processing_idx % (10_000) == 0:
31 | print_message(f"#> Done with {processing_idx+1} questions!\t\t "
32 |                               f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unique QIDs.")
33 |
34 | print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..")
35 | print_message(f"#> len(Triples) = {len(Triples)}")
36 |
37 | if len(Triples) > MAX_NUM_TRIPLES:
38 | Triples = random.sample(Triples, MAX_NUM_TRIPLES)
39 |
40 | ### Prepare the triples ###
41 | print_message("#> Shuffling the triples...")
42 | random.shuffle(Triples)
43 |
44 | self.Triples = Examples(data=Triples)
45 |
46 | return Triples
47 |
48 | def save(self, new_path):
49 | Examples(data=self.Triples).save(new_path)
50 |
51 | # save_metadata(f'{output}.meta', args) # TODO: What args to save?? {seed, positives, depth, rankings if path or else whatever provenance the rankings object shares}
52 |
53 |
--------------------------------------------------------------------------------
/third_party/colbert/infra/utilities/minicorpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | from third_party.colbert.utils.utils import create_directory
5 |
6 | from third_party.colbert.data import Collection, Queries, Ranking
7 |
8 |
9 | def sample_minicorpus(name, factor, topk=30, maxdev=3000):
10 | """
11 | Factor:
12 | * nano=1
13 | * micro=10
14 | * mini=100
15 | * small=100 with topk=100
16 | * medium=150 with topk=300
17 | """
18 |
19 | random.seed(12345)
20 |
21 | # Load collection
22 | collection = Collection(path='/dfs/scratch0/okhattab/OpenQA/collection.tsv')
23 |
24 | # Load train and dev queries
25 | qas_train = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/qas.json').qas()
26 | qas_dev = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/qas.json').qas()
27 |
28 | # Load train and dev C3 rankings
29 | ranking_train = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/rankings/C3.tsv.annotated').todict()
30 | ranking_dev = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/rankings/C3.tsv.annotated').todict()
31 |
32 | # Sample NT and ND queries from each, keep only the top-k passages for those
33 | sample_train = random.sample(list(qas_train.keys()), min(len(qas_train.keys()), 300*factor))
34 | sample_dev = random.sample(list(qas_dev.keys()), min(len(qas_dev.keys()), maxdev, 30*factor))
35 |
36 | train_pids = [pid for qid in sample_train for qpids in ranking_train[qid][:topk] for pid in qpids]
37 | dev_pids = [pid for qid in sample_dev for qpids in ranking_dev[qid][:topk] for pid in qpids]
38 |
39 | sample_pids = sorted(list(set(train_pids + dev_pids)))
40 | print(f'len(sample_pids) = {len(sample_pids)}')
41 |
42 | # Save the new query sets: train and dev
43 | ROOT = f'/future/u/okhattab/root/unit/data/NQ-{name}'
44 |
45 | create_directory(os.path.join(ROOT, 'train'))
46 | create_directory(os.path.join(ROOT, 'dev'))
47 |
48 | new_train = Queries(data={qid: qas_train[qid] for qid in sample_train})
49 | new_train.save(os.path.join(ROOT, 'train/questions.tsv'))
50 | new_train.save_qas(os.path.join(ROOT, 'train/qas.json'))
51 |
52 | new_dev = Queries(data={qid: qas_dev[qid] for qid in sample_dev})
53 | new_dev.save(os.path.join(ROOT, 'dev/questions.tsv'))
54 | new_dev.save_qas(os.path.join(ROOT, 'dev/qas.json'))
55 |
56 | # Save the new collection
57 | print(f"Saving to {os.path.join(ROOT, 'collection.tsv')}")
58 | Collection(data=[collection[pid] for pid in sample_pids]).save(os.path.join(ROOT, 'collection.tsv'))
59 |
60 | print('#> Done!')
61 |
62 |
63 | if __name__ == '__main__':
64 | sample_minicorpus('medium', 150, topk=300)
65 |
--------------------------------------------------------------------------------
/third_party/colbert/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/modeling/__init__.py
--------------------------------------------------------------------------------
/third_party/colbert/modeling/base_colbert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import sys
4 |
5 | from third_party.colbert.utils.utils import torch_load_dnn
6 |
7 | from transformers import AutoTokenizer
8 | from third_party.colbert.modeling.hf_colbert import class_factory
9 | from third_party.colbert.infra.config import ColBERTConfig
10 | from third_party.colbert.parameters import DEVICE
11 |
12 |
13 | class BaseColBERT(torch.nn.Module):
14 | """
15 | Shallow module that wraps the ColBERT parameters, custom configuration, and underlying tokenizer.
16 | This class provides direct instantiation and saving of the model/colbert_config/tokenizer package.
17 |
18 | Like HF, evaluation mode is the default.
19 | """
20 |
21 | def __init__(self, name_or_path, colbert_config=None):
22 | super().__init__()
23 |
24 | self.colbert_config = ColBERTConfig.from_existing(ColBERTConfig.load_from_checkpoint(name_or_path), colbert_config)
25 | self.name = self.colbert_config.model_name or name_or_path
26 |
27 | try:
28 | HF_ColBERT = class_factory(self.name)
29 | except:
30 | self.name = 'bert-base-uncased' # TODO: Double check that this is appropriate here in all cases
31 | HF_ColBERT = class_factory(self.name)
32 |
33 | # assert self.name is not None
34 | # HF_ColBERT = class_factory(self.name)
35 |
36 | self.model = HF_ColBERT.from_pretrained(name_or_path, colbert_config=self.colbert_config)
37 | self.model.to(DEVICE)
38 | self.raw_tokenizer = AutoTokenizer.from_pretrained(name_or_path)
39 |
40 | self.eval()
41 |
42 | @property
43 | def device(self):
44 | return self.model.device
45 |
46 | @property
47 | def bert(self):
48 | return self.model.LM
49 |
50 | @property
51 | def linear(self):
52 | return self.model.linear
53 |
54 | @property
55 | def score_scaler(self):
56 | return self.model.score_scaler
57 |
58 | def save(self, path):
59 | assert not path.endswith('.dnn'), f"{path}: We reserve *.dnn names for the deprecated checkpoint format."
60 |
61 | self.model.save_pretrained(path)
62 | self.raw_tokenizer.save_pretrained(path)
63 |
64 | self.colbert_config.save_for_checkpoint(path)
65 |
66 |
67 | if __name__ == '__main__':
68 | import random
69 | import numpy as np
70 |
71 | from third_party.colbert.infra.run import Run
72 | from third_party.colbert.infra.config import RunConfig
73 |
74 | random.seed(12345)
75 | np.random.seed(12345)
76 | torch.manual_seed(12345)
77 |
78 | with Run().context(RunConfig(gpus=2)):
79 | m = BaseColBERT('bert-base-uncased', colbert_config=ColBERTConfig(Run().config, doc_maxlen=300, similarity='l2'))
80 | m.colbert_config.help()
81 | print(m.linear.weight)
82 | m.save('/future/u/okhattab/tmp/2021/08/model.deleteme2/')
83 |
84 | m2 = BaseColBERT('/future/u/okhattab/tmp/2021/08/model.deleteme2/')
85 | m2.colbert_config.help()
86 | print(m2.linear.weight)
87 |
88 | exit()
89 |
90 | m = BaseColBERT('/future/u/okhattab/tmp/2021/08/model.deleteme/')
91 | print('BaseColBERT', m.linear.weight)
92 | print('BaseColBERT', m.colbert_config)
93 |
94 | exit()
95 |
96 | # m = HF_ColBERT.from_pretrained('nreimers/MiniLMv2-L6-H768-distilled-from-BERT-Large')
97 | m = HF_ColBERT.from_pretrained('/future/u/okhattab/tmp/2021/08/model.deleteme/')
98 | print('HF_ColBERT', m.linear.weight)
99 |
100 | m.save_pretrained('/future/u/okhattab/tmp/2021/08/model.deleteme/')
101 |
102 | # old = OldColBERT.from_pretrained('bert-base-uncased')
103 | # print(old.bert.encoder.layer[10].attention.self.value.weight)
104 |
105 | # random.seed(12345)
106 | # np.random.seed(12345)
107 | # torch.manual_seed(12345)
108 |
109 | dnn = torch_load_dnn(
110 | "/future/u/okhattab/root/TACL21/experiments/Feb26.NQ/train.py/ColBERT.C3/checkpoints/colbert-60000.dnn")
111 | # base = dnn.get('arguments', {}).get('model', 'bert-base-uncased')
112 |
113 | # new = BaseColBERT.from_pretrained('bert-base-uncased', state_dict=dnn['model_state_dict'])
114 |
115 | # print(new.bert.encoder.layer[10].attention.self.value.weight)
116 |
117 | print(dnn['model_state_dict']['linear.weight'])
118 | # print(dnn['model_state_dict']['bert.encoder.layer.10.attention.self.value.weight'])
119 |
120 | # # base_model_prefix
121 |
--------------------------------------------------------------------------------
/third_party/colbert/modeling/reranker/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/modeling/reranker/__init__.py
--------------------------------------------------------------------------------
/third_party/colbert/modeling/reranker/electra.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from transformers import ElectraPreTrainedModel, ElectraModel, AutoTokenizer
4 |
5 | class ElectraReranker(ElectraPreTrainedModel):
6 | """
7 | Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level.
8 |
9 | This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly.
10 | """
11 | _keys_to_ignore_on_load_unexpected = [r"cls"]
12 |
13 | def __init__(self, config):
14 | super().__init__(config)
15 |
16 | self.electra = ElectraModel(config)
17 | self.linear = nn.Linear(config.hidden_size, 1)
18 | self.raw_tokenizer = AutoTokenizer.from_pretrained('google/electra-large-discriminator')
19 |
20 | self.init_weights()
21 |
22 | def forward(self, encoding):
23 | outputs = self.electra(encoding.input_ids,
24 | attention_mask=encoding.attention_mask,
25 | token_type_ids=encoding.token_type_ids)[0]
26 |
27 | scores = self.linear(outputs[:, 0]).squeeze(-1)
28 |
29 | return scores
30 |
31 | def save(self, path):
32 | assert not path.endswith('.dnn'), f"{path}: We reserve *.dnn names for the deprecated checkpoint format."
33 |
34 | self.save_pretrained(path)
35 | self.raw_tokenizer.save_pretrained(path)
--------------------------------------------------------------------------------
/third_party/colbert/modeling/reranker/tokenizer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer
2 |
3 | class RerankerTokenizer():
4 | def __init__(self, total_maxlen, base):
5 | self.total_maxlen = total_maxlen
6 | self.tok = AutoTokenizer.from_pretrained(base)
7 |
8 | def tensorize(self, questions, passages):
9 | assert type(questions) in [list, tuple], type(questions)
10 | assert type(passages) in [list, tuple], type(passages)
11 |
12 | encoding = self.tok(questions, passages, padding='longest', truncation='longest_first',
13 | return_tensors='pt', max_length=self.total_maxlen, add_special_tokens=True)
14 |
15 | return encoding
16 |
--------------------------------------------------------------------------------
/third_party/colbert/modeling/segmented_maxsim.cpp:
--------------------------------------------------------------------------------
1 | #include <pthread.h>
2 | #include <torch/extension.h>
3 |
4 | #include <algorithm>
5 | #include <numeric>
6 |
7 | typedef struct {
8 | int tid;
9 | int nthreads;
10 |
11 | int ndocs;
12 | int ndoc_vectors;
13 | int nquery_vectors;
14 |
15 | int64_t* lengths;
16 | float* scores;
17 | int64_t* offsets;
18 |
19 | float* max_scores;
20 | } max_args_t;
21 |
22 | void* max(void* args) {
23 | max_args_t* max_args = (max_args_t*)args;
24 |
25 | int ndocs_per_thread =
26 | std::ceil(((float)max_args->ndocs) / max_args->nthreads);
27 | int start = max_args->tid * ndocs_per_thread;
28 | int end = std::min((max_args->tid + 1) * ndocs_per_thread, max_args->ndocs);
29 |
30 | auto max_scores_offset =
31 | max_args->max_scores + (start * max_args->nquery_vectors);
32 | auto scores_offset =
33 | max_args->scores + (max_args->offsets[start] * max_args->nquery_vectors);
34 |
35 | for (int i = start; i < end; i++) {
36 | for (int j = 0; j < max_args->lengths[i]; j++) {
37 | std::transform(max_scores_offset,
38 | max_scores_offset + max_args->nquery_vectors,
39 | scores_offset, max_scores_offset,
40 | [](float a, float b) { return std::max(a, b); });
41 | scores_offset += max_args->nquery_vectors;
42 | }
43 | max_scores_offset += max_args->nquery_vectors;
44 | }
45 |
46 | return NULL;
47 | }
48 |
49 | torch::Tensor segmented_maxsim(const torch::Tensor scores,
50 | const torch::Tensor lengths) {
51 |     auto lengths_a = lengths.data_ptr<int64_t>();
52 |     auto scores_a = scores.data_ptr<float>();
53 | auto ndocs = lengths.size(0);
54 | auto ndoc_vectors = scores.size(0);
55 | auto nquery_vectors = scores.size(1);
56 | auto nthreads = at::get_num_threads();
57 |
58 | torch::Tensor max_scores =
59 | torch::zeros({ndocs, nquery_vectors}, scores.options());
60 |
61 | int64_t offsets[ndocs + 1];
62 | offsets[0] = 0;
63 | std::partial_sum(lengths_a, lengths_a + ndocs, offsets + 1);
64 |
65 | pthread_t threads[nthreads];
66 | max_args_t args[nthreads];
67 |
68 | for (int i = 0; i < nthreads; i++) {
69 | args[i].tid = i;
70 | args[i].nthreads = nthreads;
71 |
72 | args[i].ndocs = ndocs;
73 | args[i].ndoc_vectors = ndoc_vectors;
74 | args[i].nquery_vectors = nquery_vectors;
75 |
76 | args[i].lengths = lengths_a;
77 | args[i].scores = scores_a;
78 | args[i].offsets = offsets;
79 |
80 |         args[i].max_scores = max_scores.data_ptr<float>();
81 |
82 | int rc = pthread_create(&threads[i], NULL, max, (void*)&args[i]);
83 | if (rc) {
84 | fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
85 | }
86 | }
87 |
88 | for (int i = 0; i < nthreads; i++) {
89 | pthread_join(threads[i], NULL);
90 | }
91 |
92 | return max_scores.sum(1);
93 | }
94 |
95 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
96 | m.def("segmented_maxsim_cpp", &segmented_maxsim, "Segmented MaxSim");
97 | }
98 |
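A hedged sketch of building and calling this extension with torch's JIT C++ loader (the module and function names follow the PYBIND11_MODULE above; the source path and shapes are illustrative):

import torch
from torch.utils.cpp_extension import load

ext = load(name="segmented_maxsim_cpp",
           sources=["third_party/colbert/modeling/segmented_maxsim.cpp"],
           extra_cflags=["-O3"])

scores = torch.rand(7, 32)                          # 7 document vectors x 32 query vectors
lengths = torch.tensor([3, 4], dtype=torch.int64)   # two documents with 3 and 4 vectors
per_doc_scores = ext.segmented_maxsim_cpp(scores, lengths)   # shape: (2,)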
--------------------------------------------------------------------------------
/third_party/colbert/modeling/tokenization/__init__.py:
--------------------------------------------------------------------------------
1 | from third_party.colbert.modeling.tokenization.query_tokenization import *
2 | from third_party.colbert.modeling.tokenization.doc_tokenization import *
3 | from third_party.colbert.modeling.tokenization.utils import tensorize_triples
4 |
--------------------------------------------------------------------------------
/third_party/colbert/modeling/tokenization/doc_tokenization.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # from transformers import BertTokenizerFast
4 |
5 | from third_party.colbert.modeling.hf_colbert import class_factory
6 | from third_party.colbert.infra import ColBERTConfig
7 | from third_party.colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length, _insert_prefix_token
8 | from third_party.colbert.parameters import DEVICE
9 |
10 | class DocTokenizer():
11 | def __init__(self, config: ColBERTConfig):
12 | HF_ColBERT = class_factory(config.checkpoint)
13 | self.tok = HF_ColBERT.raw_tokenizer_from_pretrained(config.checkpoint)
14 |
15 | self.config = config
16 | self.doc_maxlen = config.doc_maxlen
17 |
18 | self.D_marker_token, self.D_marker_token_id = self.config.doc_token, self.tok.convert_tokens_to_ids(self.config.doc_token_id)
19 | self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
20 | self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
21 |
22 | def tokenize(self, batch_text, add_special_tokens=False):
23 | assert type(batch_text) in [list, tuple], (type(batch_text))
24 |
25 |         tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]
26 |
27 | if not add_special_tokens:
28 | return tokens
29 |
30 | prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token]
31 | tokens = [prefix + lst + suffix for lst in tokens]
32 |
33 | return tokens
34 |
35 | def encode(self, batch_text, add_special_tokens=False):
36 | assert type(batch_text) in [list, tuple], (type(batch_text))
37 |
38 | ids = self.tok(batch_text, add_special_tokens=False).to(DEVICE)['input_ids']
39 |
40 | if not add_special_tokens:
41 | return ids
42 |
43 | prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
44 | ids = [prefix + lst + suffix for lst in ids]
45 |
46 | return ids
47 |
48 | def tensorize(self, batch_text, bsize=None):
49 | assert type(batch_text) in [list, tuple], (type(batch_text))
50 |
51 | obj = self.tok(batch_text, padding='longest', truncation='longest_first',
52 | return_tensors='pt', max_length=(self.doc_maxlen - 1)).to(DEVICE)
53 |
54 | ids = _insert_prefix_token(obj['input_ids'], self.D_marker_token_id)
55 | mask = _insert_prefix_token(obj['attention_mask'], 1)
56 |
57 | if bsize:
58 | ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
59 | batches = _split_into_batches(ids, mask, bsize)
60 | return batches, reverse_indices
61 |
62 | return ids, mask
63 |
--------------------------------------------------------------------------------
/third_party/colbert/modeling/tokenization/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def tensorize_triples(query_tokenizer, doc_tokenizer, queries, passages, scores, bsize, nway):
5 | # assert len(passages) == len(scores) == bsize * nway
6 | # assert bsize is None or len(queries) % bsize == 0
7 |
8 | # N = len(queries)
9 | Q_ids, Q_mask = query_tokenizer.tensorize(queries)
10 | D_ids, D_mask = doc_tokenizer.tensorize(passages)
11 | # D_ids, D_mask = D_ids.view(2, N, -1), D_mask.view(2, N, -1)
12 |
13 | # # Compute max among {length of i^th positive, length of i^th negative} for i \in N
14 | # maxlens = D_mask.sum(-1).max(0).values
15 |
16 | # # Sort by maxlens
17 | # indices = maxlens.sort().indices
18 | # Q_ids, Q_mask = Q_ids[indices], Q_mask[indices]
19 | # D_ids, D_mask = D_ids[:, indices], D_mask[:, indices]
20 |
21 | # (positive_ids, negative_ids), (positive_mask, negative_mask) = D_ids, D_mask
22 |
23 | query_batches = _split_into_batches(Q_ids, Q_mask, bsize)
24 | doc_batches = _split_into_batches(D_ids, D_mask, bsize * nway)
25 | # positive_batches = _split_into_batches(positive_ids, positive_mask, bsize)
26 | # negative_batches = _split_into_batches(negative_ids, negative_mask, bsize)
27 |
28 | if len(scores):
29 | score_batches = _split_into_batches2(scores, bsize * nway)
30 | else:
31 | score_batches = [[] for _ in doc_batches]
32 |
33 | batches = []
34 | for Q, D, S in zip(query_batches, doc_batches, score_batches):
35 | batches.append((Q, D, S))
36 |
37 | return batches
38 |
39 |
40 | def _sort_by_length(ids, mask, bsize):
41 | if ids.size(0) <= bsize:
42 | return ids, mask, torch.arange(ids.size(0))
43 |
44 | indices = mask.sum(-1).sort().indices
45 | reverse_indices = indices.sort().indices
46 |
47 | return ids[indices], mask[indices], reverse_indices
48 |
49 |
50 | def _split_into_batches(ids, mask, bsize):
51 | batches = []
52 | for offset in range(0, ids.size(0), bsize):
53 | batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))
54 |
55 | return batches
56 |
57 |
58 | def _split_into_batches2(scores, bsize):
59 | batches = []
60 | for offset in range(0, len(scores), bsize):
61 | batches.append(scores[offset:offset+bsize])
62 |
63 | return batches
64 |
65 | def _insert_prefix_token(tensor: torch.Tensor, prefix_id: int):
66 | prefix_tensor = torch.full(
67 | (tensor.size(0), 1),
68 | prefix_id,
69 | dtype=tensor.dtype,
70 | device=tensor.device,
71 | )
72 | return torch.cat([tensor[:, :1], prefix_tensor, tensor[:, 1:]], dim=1)
73 |
--------------------------------------------------------------------------------
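A small self-contained illustration of `_insert_prefix_token`: the first column (typically the [CLS] id) stays in place, the marker id becomes the new second column, and everything else shifts right by one. The token ids below are made up for the example.

import torch

ids = torch.tensor([[101, 2054, 2003, 102],
                    [101, 7592,  102,   0]])          # toy [CLS] ... [SEP] rows

marker_id = 1                                          # hypothetical D-marker id

prefix = torch.full((ids.size(0), 1), marker_id, dtype=ids.dtype, device=ids.device)
with_marker = torch.cat([ids[:, :1], prefix, ids[:, 1:]], dim=1)
# -> [[101, 1, 2054, 2003, 102],
#     [101, 1, 7592,  102,   0]]

--------------------------------------------------------------------------------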
/third_party/colbert/parameters.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4 |
5 | SAVED_CHECKPOINTS = [32*1000, 100*1000, 150*1000, 200*1000, 250*1000, 300*1000, 400*1000]
6 | SAVED_CHECKPOINTS += [10*1000, 20*1000, 30*1000, 40*1000, 50*1000, 60*1000, 70*1000, 80*1000, 90*1000]
7 | SAVED_CHECKPOINTS += [25*1000, 50*1000, 75*1000]
8 |
9 | SAVED_CHECKPOINTS = set(SAVED_CHECKPOINTS)
10 |
11 |
12 | # TODO: final_ckpt 2k, 5k, 10k 20k, 50k, 100k 150k 200k, 500k, 1M 2M, 5M, 10M
--------------------------------------------------------------------------------
/third_party/colbert/ranking/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/ranking/__init__.py
--------------------------------------------------------------------------------
/third_party/colbert/search/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/search/__init__.py
--------------------------------------------------------------------------------
/third_party/colbert/search/candidate_generation.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from third_party.colbert.search.strided_tensor import StridedTensor
4 | from .strided_tensor_core import _create_mask, _create_view
5 |
6 |
7 | class CandidateGeneration:
8 |
9 | def __init__(self, use_gpu=True):
10 | self.use_gpu = use_gpu
11 |
12 | def get_cells(self, Q, ncells):
13 | scores = (self.codec.centroids @ Q.T)
14 | if ncells == 1:
15 | cells = scores.argmax(dim=0, keepdim=True).permute(1, 0)
16 | else:
17 | cells = scores.topk(ncells, dim=0, sorted=False).indices.permute(1, 0) # (32, ncells)
18 | cells = cells.flatten().contiguous() # (32 * ncells,)
19 | cells = cells.unique(sorted=False)
20 | return cells, scores
21 |
22 | def generate_candidate_eids(self, Q, ncells):
23 | cells, scores = self.get_cells(Q, ncells)
24 |
25 | eids, cell_lengths = self.ivf.lookup(cells) # eids = (packedlen,) lengths = (32 * ncells,)
26 | eids = eids.long()
27 | if self.use_gpu:
28 | eids = eids.cuda()
29 | return eids, scores
30 |
31 | def generate_candidate_pids(self, Q, ncells):
32 | cells, scores = self.get_cells(Q, ncells)
33 |
34 | pids, cell_lengths = self.ivf.lookup(cells)
35 | if self.use_gpu:
36 | pids = pids.cuda()
37 | return pids, scores
38 |
39 | def generate_candidate_scores(self, Q, eids):
40 | E = self.lookup_eids(eids)
41 | if self.use_gpu:
42 | E = E.cuda()
43 | return (Q.unsqueeze(0) @ E.unsqueeze(2)).squeeze(-1).T
44 |
45 | def generate_candidates(self, config, Q):
46 | ncells = config.ncells
47 |
48 | assert isinstance(self.ivf, StridedTensor)
49 |
50 | Q = Q.squeeze(0)
51 | if self.use_gpu:
52 | Q = Q.cuda().half()
53 | assert Q.dim() == 2
54 |
55 | pids, centroid_scores = self.generate_candidate_pids(Q, ncells)
56 |
57 | sorter = pids.sort()
58 | pids = sorter.values
59 |
60 | pids, pids_counts = torch.unique_consecutive(pids, return_counts=True)
61 | if self.use_gpu:
62 | pids, pids_counts = pids.cuda(), pids_counts.cuda()
63 |
64 | return pids, centroid_scores
65 |
--------------------------------------------------------------------------------
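A minimal sketch, on random data, of the cell-selection step in `CandidateGeneration.get_cells`: every query token embedding scores all IVF centroids, the top `ncells` centroids per token are kept, and the union of those cell ids is what gets looked up in the inverted file. All shapes and values here are illustrative only.

import torch

num_centroids, dim, num_query_tokens, ncells = 1024, 128, 32, 4

centroids = torch.randn(num_centroids, dim)   # stand-in for self.codec.centroids
Q = torch.randn(num_query_tokens, dim)        # one query: 32 token embeddings

scores = centroids @ Q.T                                                 # (num_centroids, 32)
cells = scores.topk(ncells, dim=0, sorted=False).indices.permute(1, 0)   # (32, ncells)
cells = cells.flatten().contiguous().unique(sorted=False)                # union of candidate cells

--------------------------------------------------------------------------------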
/third_party/colbert/search/index_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ujson
3 | import torch
4 | import numpy as np
5 | import tqdm
6 |
7 | from third_party.colbert.utils.utils import lengths2offsets, print_message, dotdict, flatten
8 | from third_party.colbert.indexing.codecs.residual import ResidualCodec
9 | from third_party.colbert.indexing.utils import optimize_ivf
10 | from third_party.colbert.search.strided_tensor import StridedTensor
11 |
12 |
13 | class IndexLoader:
14 | def __init__(self, index_path, use_gpu=True, load_index_with_mmap=False):
15 | self.index_path = index_path
16 | self.use_gpu = use_gpu
17 | self.load_index_with_mmap = load_index_with_mmap
18 |
19 | self._load_codec()
20 | self._load_ivf()
21 |
22 | self._load_doclens()
23 | self._load_embeddings()
24 |
25 | def _load_codec(self):
26 | print_message(f"#> Loading codec...")
27 | self.codec = ResidualCodec.load(self.index_path)
28 |
29 | def _load_ivf(self):
30 | print_message(f"#> Loading IVF...")
31 |
32 | if os.path.exists(os.path.join(self.index_path, "ivf.pid.pt")):
33 | ivf, ivf_lengths = torch.load(os.path.join(self.index_path, "ivf.pid.pt"), map_location='cpu')
34 | else:
35 | assert os.path.exists(os.path.join(self.index_path, "ivf.pt"))
36 | ivf, ivf_lengths = torch.load(os.path.join(self.index_path, "ivf.pt"), map_location='cpu')
37 | ivf, ivf_lengths = optimize_ivf(ivf, ivf_lengths, self.index_path)
38 |
39 | if False:
40 | ivf = ivf.tolist()
41 | ivf = [ivf[offset:endpos] for offset, endpos in lengths2offsets(ivf_lengths)]
42 | else:
43 | # ivf, ivf_lengths = ivf.cuda(), torch.LongTensor(ivf_lengths).cuda() # FIXME: REMOVE THIS LINE!
44 | ivf = StridedTensor(ivf, ivf_lengths, use_gpu=self.use_gpu)
45 |
46 | self.ivf = ivf
47 |
48 | def _load_doclens(self):
49 | doclens = []
50 |
51 | print_message("#> Loading doclens...")
52 |
53 | for chunk_idx in tqdm.tqdm(range(self.num_chunks)):
54 | with open(os.path.join(self.index_path, f'doclens.{chunk_idx}.json')) as f:
55 | chunk_doclens = ujson.load(f)
56 | doclens.extend(chunk_doclens)
57 |
58 | self.doclens = torch.tensor(doclens)
59 |
60 | def _load_embeddings(self):
61 | self.embeddings = ResidualCodec.Embeddings.load_chunks(
62 | self.index_path,
63 | range(self.num_chunks),
64 | self.num_embeddings,
65 | self.load_index_with_mmap,
66 | )
67 |
68 | @property
69 | def metadata(self):
70 | try:
71 | self._metadata
72 | except:
73 | with open(os.path.join(self.index_path, 'metadata.json')) as f:
74 | self._metadata = ujson.load(f)
75 |
76 | return self._metadata
77 |
78 | @property
79 | def config(self):
80 | raise NotImplementedError() # load from dict at metadata['config']
81 |
82 | @property
83 | def num_chunks(self):
84 | # EVENTUALLY: If num_chunks doesn't exist (i.e., old index), fall back to counting doclens.*.json files.
85 | return self.metadata['num_chunks']
86 |
87 | @property
88 | def num_embeddings(self):
89 | # EVENTUALLY: If num_embeddings doesn't exist (i.e., old index), sum the values in doclens.*.json files.
90 | return self.metadata['num_embeddings']
91 |
92 |
--------------------------------------------------------------------------------
/third_party/colbert/search/segmented_lookup.cpp:
--------------------------------------------------------------------------------
1 | #include <pthread.h>
2 | #include <torch/extension.h>
3 |
4 | #include <numeric>
5 | #include <queue>
6 |
7 | typedef struct {
8 | int tid;
9 | pthread_mutex_t* mutex;
10 |     std::queue<int>* queue;
11 |
12 | int64_t ndocs;
13 | int64_t noutputs;
14 | int64_t dim;
15 |
16 | void* input;
17 | int64_t* lengths;
18 | int64_t* offsets;
19 | int64_t* cumulative_lengths;
20 |
21 | void* output;
22 | } lookup_args_t;
23 |
24 | template <typename T>
25 | void* lookup(void* args) {
26 | lookup_args_t* lookup_args = (lookup_args_t*)args;
27 |
28 | int64_t* lengths = lookup_args->lengths;
29 | int64_t* cumulative_lengths = lookup_args->cumulative_lengths;
30 | int64_t* offsets = lookup_args->offsets;
31 | int64_t dim = lookup_args->dim;
32 |
33 |     T* input = static_cast<T*>(lookup_args->input);
34 |     T* output = static_cast<T*>(lookup_args->output);
35 |
36 | while (1) {
37 | pthread_mutex_lock(lookup_args->mutex);
38 | if (lookup_args->queue->empty()) {
39 | pthread_mutex_unlock(lookup_args->mutex);
40 | return NULL;
41 | }
42 | int i = lookup_args->queue->front();
43 | lookup_args->queue->pop();
44 | pthread_mutex_unlock(lookup_args->mutex);
45 |
46 | std::memcpy(output + (cumulative_lengths[i] * dim),
47 | input + (offsets[i] * dim), lengths[i] * dim * sizeof(T));
48 | }
49 | }
50 |
51 | template <typename T>
52 | torch::Tensor segmented_lookup_impl(const torch::Tensor input,
53 | const torch::Tensor pids,
54 | const torch::Tensor lengths,
55 | const torch::Tensor offsets) {
56 |     auto lengths_a = lengths.data_ptr<int64_t>();
57 |     auto offsets_a = offsets.data_ptr<int64_t>();
58 |
59 | int64_t ndocs = pids.size(0);
60 | int64_t noutputs = std::accumulate(lengths_a, lengths_a + ndocs, 0);
61 |
62 | int nthreads = at::get_num_threads();
63 |
64 | int64_t dim;
65 | torch::Tensor output;
66 |
67 | if (input.dim() == 1) {
68 | dim = 1;
69 | output = torch::zeros({noutputs}, input.options());
70 | } else {
71 | assert(input.dim() == 2);
72 | dim = input.size(1);
73 | output = torch::zeros({noutputs, dim}, input.options());
74 | }
75 |
76 | int64_t cumulative_lengths[ndocs + 1];
77 | cumulative_lengths[0] = 0;
78 | std::partial_sum(lengths_a, lengths_a + ndocs, cumulative_lengths + 1);
79 |
80 | pthread_mutex_t mutex;
81 | int rc = pthread_mutex_init(&mutex, NULL);
82 | if (rc) {
83 | fprintf(stderr, "Unable to init mutex: %d\n", rc);
84 | }
85 |
86 |     std::queue<int> queue;
87 | for (int i = 0; i < ndocs; i++) {
88 | queue.push(i);
89 | }
90 |
91 | pthread_t threads[nthreads];
92 | lookup_args_t args[nthreads];
93 | for (int i = 0; i < nthreads; i++) {
94 | args[i].tid = i;
95 | args[i].mutex = &mutex;
96 | args[i].queue = &queue;
97 |
98 | args[i].ndocs = ndocs;
99 | args[i].noutputs = noutputs;
100 | args[i].dim = dim;
101 |
102 | args[i].input = (void*)input.data_ptr();
103 | args[i].lengths = lengths_a;
104 | args[i].offsets = offsets_a;
105 | args[i].cumulative_lengths = cumulative_lengths;
106 |
107 | args[i].output = (void*)output.data_ptr();
108 |
109 |         rc = pthread_create(&threads[i], NULL, lookup<T>, (void*)&args[i]);
110 | if (rc) {
111 | fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
112 | }
113 | }
114 |
115 | for (int i = 0; i < nthreads; i++) {
116 | pthread_join(threads[i], NULL);
117 | }
118 |
119 | rc = pthread_mutex_destroy(&mutex);
120 | if (rc) {
121 | fprintf(stderr, "Unable to destroy mutex: %d\n", rc);
122 | }
123 |
124 | return output;
125 | }
126 |
127 | torch::Tensor segmented_lookup(const torch::Tensor input,
128 | const torch::Tensor pids,
129 | const torch::Tensor lengths,
130 | const torch::Tensor offsets) {
131 | if (input.dtype() == torch::kUInt8) {
132 |         return segmented_lookup_impl<uint8_t>(input, pids, lengths, offsets);
133 |     } else if (input.dtype() == torch::kInt32) {
134 |         return segmented_lookup_impl<int32_t>(input, pids, lengths, offsets);
135 |     } else if (input.dtype() == torch::kInt64) {
136 |         return segmented_lookup_impl<int64_t>(input, pids, lengths, offsets);
137 |     } else if (input.dtype() == torch::kFloat32) {
138 |         return segmented_lookup_impl<float>(input, pids, lengths, offsets);
139 |     } else if (input.dtype() == torch::kFloat16) {
140 |         return segmented_lookup_impl<at::Half>(input, pids, lengths, offsets);
141 | } else {
142 | assert(false);
143 | }
144 | }
145 |
146 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
147 | m.def("segmented_lookup_cpp", &segmented_lookup, "Segmented lookup");
148 | }
149 |
--------------------------------------------------------------------------------
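For reference, a hedged sketch of how an extension like this can be JIT-compiled and called through `torch.utils.cpp_extension` (the repository's own loader may wire it up differently; the path and tensors below are assumptions for illustration):

import torch
from torch.utils.cpp_extension import load

ext = load(name="segmented_lookup_cpp",
           sources=["third_party/colbert/search/segmented_lookup.cpp"],
           extra_cflags=["-O3"])

packed = torch.randn(10, 4)                        # 10 packed embeddings of dim 4
pids = torch.tensor([0, 1])                        # gather two documents
lengths = torch.tensor([3, 2], dtype=torch.int64)  # doc 0 has 3 embeddings, doc 1 has 2
offsets = torch.tensor([0, 5], dtype=torch.int64)  # start row of each document in `packed`

out = ext.segmented_lookup_cpp(packed, pids, lengths, offsets)  # -> shape (5, 4)

--------------------------------------------------------------------------------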
/third_party/colbert/search/strided_tensor_core.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 |
4 | import numpy as np
5 |
6 | from third_party.colbert.utils.utils import flatten
7 |
8 |
9 | """
10 | import line_profiler
11 | import atexit
12 | profile = line_profiler.LineProfiler()
13 | atexit.register(profile.print_stats)
14 | """
15 |
16 |
17 | class StridedTensorCore:
18 | # # @profile
19 | def __init__(self, packed_tensor, lengths, dim=None, use_gpu=True):
20 | self.dim = dim
21 | self.tensor = packed_tensor
22 | self.inner_dims = self.tensor.size()[1:]
23 | self.use_gpu = use_gpu
24 |
25 | self.lengths = lengths.long() if torch.is_tensor(lengths) else torch.LongTensor(lengths)
26 |
27 | self.strides = _select_strides(self.lengths, [.5, .75, .9, .95]) + [self.lengths.max().item()]
28 | self.max_stride = self.strides[-1]
29 |
30 | zero = torch.zeros(1, dtype=torch.long, device=self.lengths.device)
31 | self.offsets = torch.cat((zero, torch.cumsum(self.lengths, dim=0)))
32 |
33 | if self.offsets[-2] + self.max_stride > self.tensor.size(0):
34 | # if self.tensor.size(0) > 10_000_000:
35 | # print("#> WARNING: StridedTensor has to add padding, internally, to a large tensor.")
36 | # print("#> WARNING: Consider doing this padding in advance to save memory!")
37 |
38 | padding = torch.zeros(self.max_stride, *self.inner_dims, dtype=self.tensor.dtype, device=self.tensor.device)
39 | self.tensor = torch.cat((self.tensor, padding))
40 |
41 | self.views = {stride: _create_view(self.tensor, stride, self.inner_dims) for stride in self.strides}
42 |
43 | @classmethod
44 | def from_packed_tensor(cls, tensor, lengths):
45 | return cls(tensor, lengths)
46 |
47 | @classmethod
48 | def from_padded_tensor(cls, tensor, mask):
49 | pass
50 |
51 | @classmethod
52 | def from_nested_list(cls, lst):
53 | flat_lst = flatten(lst)
54 |
55 | tensor = torch.Tensor(flat_lst)
56 | lengths = [len(sublst) for sublst in lst]
57 |
58 | return cls(tensor, lengths, dim=0)
59 |
60 | @classmethod
61 | def from_tensors_list(cls, tensors):
62 | # torch.cat(tensors)
63 | # lengths.
64 | # cls(tensor, lengths)
65 | raise NotImplementedError()
66 |
67 | def as_packed_tensor(self, return_offsets=False):
68 | unpadded_packed_tensor = self.tensor # [:self.offsets[-1]]
69 |
70 | return_vals = [unpadded_packed_tensor, self.lengths]
71 |
72 | if return_offsets:
73 | return_vals.append(self.offsets)
74 |
75 | return tuple(return_vals)
76 |
77 | # # @profile
78 | def as_padded_tensor(self):
79 | if self.use_gpu:
80 | view = _create_view(self.tensor.cuda(), self.max_stride, self.inner_dims)[self.offsets[:-1]]
81 | mask = _create_mask(self.lengths.cuda(), self.max_stride, like=view, use_gpu=self.use_gpu)
82 | else:
83 | #import pdb
84 | #pdb.set_trace()
85 | view = _create_view(self.tensor, self.max_stride, self.inner_dims)
86 | view = view[self.offsets[:-1]]
87 | mask = _create_mask(self.lengths, self.max_stride, like=view, use_gpu=self.use_gpu)
88 |
89 | return view, mask
90 |
91 | def as_tensors_list(self):
92 | raise NotImplementedError()
93 |
94 |
95 |
96 | def _select_strides(lengths, quantiles):
97 | if lengths.size(0) < 5_000:
98 | return _get_quantiles(lengths, quantiles)
99 |
100 | sample = torch.randint(0, lengths.size(0), size=(2_000,))
101 |
102 | return _get_quantiles(lengths[sample], quantiles)
103 |
104 | def _get_quantiles(lengths, quantiles):
105 | return torch.quantile(lengths.float(), torch.tensor(quantiles, device=lengths.device)).int().tolist()
106 |
107 |
108 | def _create_view(tensor, stride, inner_dims):
109 | outdim = tensor.size(0) - stride + 1
110 | size = (outdim, stride, *inner_dims)
111 |
112 | inner_dim_prod = int(np.prod(inner_dims))
113 | multidim_stride = [inner_dim_prod, inner_dim_prod] + [1] * len(inner_dims)
114 |
115 | return torch.as_strided(tensor, size=size, stride=multidim_stride)
116 |
117 |
118 | def _create_mask(lengths, stride, like=None, use_gpu=True):
119 | if use_gpu:
120 | mask = torch.arange(stride).cuda() + 1
121 | mask = mask.unsqueeze(0) <= lengths.cuda().unsqueeze(-1)
122 | else:
123 | mask = torch.arange(stride) + 1
124 | mask = mask.unsqueeze(0) <= lengths.unsqueeze(-1)
125 |
126 | if like is not None:
127 | for _ in range(like.dim() - mask.dim()):
128 | mask = mask.unsqueeze(-1)
129 |
130 | return mask
131 |
--------------------------------------------------------------------------------
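A toy illustration of the packed-tensor trick used by `_create_view` and `_create_mask`: overlapping windows of length `stride` are laid over the flat tensor with `torch.as_strided`, the rows of interest are picked out via their offsets, and a mask marks which positions inside each window are real. The numbers are made up.

import torch

packed = torch.arange(10)                  # concatenation of rows with lengths [3, 2, 4], plus slack
lengths = torch.tensor([3, 2, 4])
offsets = torch.tensor([0, 3, 5])          # where each row starts inside `packed`
stride = int(lengths.max())                # 4

view = torch.as_strided(packed, size=(packed.size(0) - stride + 1, stride), stride=(1, 1))
rows = view[offsets]                                            # (3, 4) windows, one per row
mask = (torch.arange(stride) + 1).unsqueeze(0) <= lengths.unsqueeze(-1)
padded = rows * mask                                            # zero out positions past each length

--------------------------------------------------------------------------------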
/third_party/colbert/tests/e2e_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from collections import namedtuple
4 | from datasets import load_dataset
5 | from third_party.utility.utils.dpr import has_answer, DPR_normalize
6 | import tqdm
7 |
8 | from third_party.colbert import Indexer, Searcher
9 | from third_party.colbert.infra import ColBERTConfig, RunConfig, Run
10 |
11 | SquadExample = namedtuple("SquadExample", "id title context question answers")
12 |
13 |
14 | def build_index_and_init_searcher(checkpoint, collection, experiment_dir):
15 |     nbits = 1  # encode each dimension with 1 bit
16 | doc_maxlen = 180 # truncate passages at 180 tokens
17 | experiment = f"e2etest.nbits={nbits}"
18 |
19 | with Run().context(RunConfig(nranks=1)):
20 | config = ColBERTConfig(
21 | doc_maxlen=doc_maxlen,
22 | nbits=nbits,
23 | root=experiment_dir,
24 | experiment=experiment,
25 | )
26 | indexer = Indexer(checkpoint, config=config)
27 | indexer.index(name=experiment, collection=collection, overwrite=True)
28 |
29 | config = ColBERTConfig(
30 | root=experiment_dir,
31 | experiment=experiment,
32 | )
33 | searcher = Searcher(
34 | index=experiment,
35 | config=config,
36 | )
37 |
38 | return searcher
39 |
40 |
41 | def success_at_k(searcher, examples, k):
42 | scores = []
43 | for ex in tqdm.tqdm(examples):
44 | scores.append(evaluate_retrieval_example(searcher, ex, k))
45 | return sum(scores) / len(scores)
46 |
47 |
48 | def evaluate_retrieval_example(searcher, ex, k):
49 | results = searcher.search(ex.question, k=k)
50 | for passage_id, passage_rank, passage_score in zip(*results):
51 | passage = searcher.collection[passage_id]
52 | score = has_answer([DPR_normalize(ans) for ans in ex.answers], passage)
53 | if score:
54 | return 1
55 | return 0
56 |
57 |
58 | def get_squad_split(squad, split="validation"):
59 | fields = squad[split].features
60 | data = zip(*[squad[split][field] for field in fields])
61 | return [
62 | SquadExample(eid, title, context, question, answers["text"])
63 | for eid, title, context, question, answers in data
64 | ]
65 |
66 |
67 | def main(args):
68 | checkpoint = args.checkpoint
69 | collection = args.collection
70 | experiment_dir = args.expdir
71 |
72 | # Start the test
73 | k = 5
74 | searcher = build_index_and_init_searcher(checkpoint, collection, experiment_dir)
75 |
76 | squad = load_dataset("squad")
77 | squad_dev = get_squad_split(squad)
78 | success_rate = success_at_k(searcher, squad_dev[:1000], k)
79 | assert success_rate > 0.93, f"success rate at {success_rate} is lower than expected"
80 | print(f"test passed with success rate {success_rate}")
81 |
82 |
83 | if __name__ == "__main__":
84 | parser = argparse.ArgumentParser(description="Start end-to-end test.")
85 | parser.add_argument(
86 | "--checkpoint", type=str, required=True, help="Model checkpoint"
87 | )
88 | parser.add_argument(
89 | "--collection", type=str, required=True, help="Path to collection"
90 | )
91 | parser.add_argument(
92 | "--expdir", type=str, required=True, help="Experiment directory"
93 | )
94 | args = parser.parse_args()
95 | main(args)
96 |
--------------------------------------------------------------------------------
/third_party/colbert/tests/index_coalesce_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import ujson
4 | from tqdm import tqdm
5 |
6 | import argparse
7 |
8 | def main(args):
9 | # TODO: compare residual, codes, and doclens
10 | # Get the number of chunks in the multi-file index
11 | single_path = args.single
12 | multi_path = args.multi
13 |
14 | # Get num_chunks and num_embeddings from metadata
15 | filepath = os.path.join(multi_path, 'metadata.json')
16 | with open(filepath, 'r') as f:
17 | metadata = ujson.load(f)
18 | num_chunks = metadata['num_chunks']
19 | print(f"Num_chunks = {num_chunks}")
20 |
21 | num_embeddings = metadata['num_embeddings']
22 | print(f"Num_embeddings = {num_embeddings}")
23 |
24 | dim = metadata['config']['dim']
25 | nbits = metadata['config']['nbits']
26 |
27 | ## Load and compare doclens ##
28 | # load multi-file doclens
29 | print("Loading doclens from multi-file index")
30 | multi_doclens = []
31 | for chunk_idx in tqdm(range(num_chunks)):
32 | with open(os.path.join(multi_path, f"doclens.{chunk_idx}.json"), 'r') as f:
33 | chunk = ujson.load(f)
34 | multi_doclens.extend(chunk)
35 |
36 | # load single-file doclens
37 | print("Loading doclens from single-file index")
38 | single_doclens = []
39 | for _ in tqdm(range(1)):
40 | with open(os.path.join(single_path, "doclens.0.json"), 'r') as f:
41 | single_doclens = ujson.load(f)
42 |
43 | # compare doclens
44 | if (multi_doclens != single_doclens):
45 | print("Doclens do not match!")
46 | print("Multi-file doclens size = {}".format(len(multi_doclens)))
47 | print("Single-file doclens size = {}".format(len(single_doclens)))
48 | else:
49 | print("Doclens match")
50 |
51 | ## Load and compare codes ##
52 | # load multi-file codes
53 | print("Loading codes from multi-file index")
54 | multi_codes = torch.empty(num_embeddings, dtype=torch.int32)
55 | offset = 0
56 | for chunk_idx in tqdm(range(num_chunks)):
57 | chunk = torch.load(os.path.join(multi_path, f"{chunk_idx}.codes.pt"))
58 | endpos = offset + chunk.size(0)
59 | multi_codes[offset:endpos] = chunk
60 | offset = endpos
61 |
62 | # load single-file codes
63 | print("Loading codes from single-file index")
64 | single_codes = []
65 | for _ in tqdm(range(1)):
66 | single_codes = torch.load(os.path.join(single_path, "0.codes.pt"))
67 |
68 | if (single_codes.size(0) != num_embeddings):
69 | print("Codes are the wrong size!")
70 |
71 | # compare codes
72 | if torch.equal(multi_codes, single_codes):
73 | print("Codes match")
74 | else:
75 | print("Codes do not match!")
76 |
77 | ## Load and compare residuals ##
78 | # load multi-file residuals
79 | print("Loading residuals from multi-file index")
80 | multi_residuals = torch.empty(num_embeddings, dim // 8 * nbits, dtype=torch.uint8)
81 | offset = 0
82 | for chunk_idx in tqdm(range(num_chunks)):
83 | chunk = torch.load(os.path.join(multi_path, f"{chunk_idx}.residuals.pt"))
84 | endpos = offset + chunk.size(0)
85 | multi_residuals[offset:endpos] = chunk
86 | offset = endpos
87 |
88 | # load single-file residuals
89 | print("Loading residuals from single-file index")
90 | single_residuals = []
91 | for _ in tqdm(range(1)):
92 | single_residuals = torch.load(os.path.join(single_path, "0.residuals.pt"))
93 |
94 | # compare residuals
95 | if torch.equal(multi_residuals, single_residuals):
96 | print("Residuals match")
97 | else:
98 | print("Residuals do not match!")
99 |
100 | if __name__ == "__main__":
101 | parser = argparse.ArgumentParser(description="Compare single-file and multi-file indexes.")
102 | parser.add_argument(
103 | "--single", type=str, required=True, help="Path to single-file index."
104 | )
105 | parser.add_argument(
106 | "--multi", type=str, required=True, help="Path to multi-file index."
107 | )
108 |
109 | args = parser.parse_args()
110 | main(args)
111 |
112 | print("Exiting test")
113 |
--------------------------------------------------------------------------------
/third_party/colbert/trainer.py:
--------------------------------------------------------------------------------
1 | from third_party.colbert.infra.run import Run
2 | from third_party.colbert.infra.launcher import Launcher
3 | from third_party.colbert.infra.config import ColBERTConfig, RunConfig
4 |
5 | from third_party.colbert.training.training import train
6 |
7 |
8 | class Trainer:
9 | def __init__(self, triples, queries, collection, config=None):
10 | self.config = ColBERTConfig.from_existing(config, Run().config)
11 |
12 | self.triples = triples
13 | self.queries = queries
14 | self.collection = collection
15 |
16 | def configure(self, **kw_args):
17 | self.config.configure(**kw_args)
18 |
19 | def train(self, checkpoint='bert-base-uncased'):
20 | """
21 | Note that config.checkpoint is ignored. Only the supplied checkpoint here is used.
22 | """
23 |
24 | # Resources don't come from the config object. They come from the input parameters.
25 | # TODO: After the API stabilizes, make this "self.config.assign()" to emphasize this distinction.
26 | self.configure(triples=self.triples, queries=self.queries, collection=self.collection)
27 | self.configure(checkpoint=checkpoint)
28 |
29 | launcher = Launcher(train)
30 |
31 | self._best_checkpoint_path = launcher.launch(self.config, self.triples, self.queries, self.collection)
32 |
33 |
34 | def best_checkpoint_path(self):
35 | return self._best_checkpoint_path
36 |
37 |
--------------------------------------------------------------------------------
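A hedged sketch of how `Trainer` is typically driven, mirroring upstream ColBERT usage; all paths and hyperparameter values below are placeholders:

from third_party.colbert.infra import Run, RunConfig, ColBERTConfig
from third_party.colbert.trainer import Trainer

if __name__ == '__main__':
    with Run().context(RunConfig(nranks=1, experiment="example")):
        config = ColBERTConfig(bsize=32, doc_maxlen=180, nway=2)

        trainer = Trainer(triples="/path/to/triples.jsonl",   # placeholder resources
                          queries="/path/to/queries.tsv",
                          collection="/path/to/collection.tsv",
                          config=config)

        trainer.train(checkpoint="bert-base-uncased")         # config.checkpoint is ignored here
        print(f"Best checkpoint: {trainer.best_checkpoint_path()}")

--------------------------------------------------------------------------------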
/third_party/colbert/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/training/__init__.py
--------------------------------------------------------------------------------
/third_party/colbert/training/eager_batcher.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ujson
3 |
4 | from functools import partial
5 | from third_party.colbert.utils.utils import print_message
6 | from third_party.colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples
7 |
8 | from third_party.colbert.utils.runs import Run
9 |
10 |
11 | class EagerBatcher():
12 | def __init__(self, args, rank=0, nranks=1):
13 | self.rank, self.nranks = rank, nranks
14 | self.bsize, self.accumsteps = args.bsize, args.accumsteps
15 |
16 | self.query_tokenizer = QueryTokenizer(args.query_maxlen)
17 | self.doc_tokenizer = DocTokenizer(args.doc_maxlen)
18 | self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)
19 |
20 | self.triples_path = args.triples
21 | self._reset_triples()
22 |
23 | def _reset_triples(self):
24 | self.reader = open(self.triples_path, mode='r', encoding="utf-8")
25 | self.position = 0
26 |
27 | def __iter__(self):
28 | return self
29 |
30 | def __next__(self):
31 | queries, positives, negatives = [], [], []
32 |
33 | for line_idx, line in zip(range(self.bsize * self.nranks), self.reader):
34 | if (self.position + line_idx) % self.nranks != self.rank:
35 | continue
36 |
37 | query, pos, neg = line.strip().split('\t')
38 |
39 | queries.append(query)
40 | positives.append(pos)
41 | negatives.append(neg)
42 |
43 | self.position += line_idx + 1
44 |
45 | if len(queries) < self.bsize:
46 | raise StopIteration
47 |
48 | return self.collate(queries, positives, negatives)
49 |
50 | def collate(self, queries, positives, negatives):
51 | assert len(queries) == len(positives) == len(negatives) == self.bsize
52 |
53 | return self.tensorize_triples(queries, positives, negatives, self.bsize // self.accumsteps)
54 |
55 | def skip_to_batch(self, batch_idx, intended_batch_size):
56 | self._reset_triples()
57 |
58 | Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')
59 |
60 | _ = [self.reader.readline() for _ in range(batch_idx * intended_batch_size)]
61 |
62 | return None
63 |
--------------------------------------------------------------------------------
/third_party/colbert/training/lazy_batcher.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ujson
3 |
4 | from functools import partial
5 | from third_party.colbert.infra.config.config import ColBERTConfig
6 | from third_party.colbert.utils.utils import print_message, zipstar
7 | from third_party.colbert.modeling.tokenization import QueryTokenizer, DocTokenizer, tensorize_triples
8 | from third_party.colbert.evaluation.loaders import load_collection
9 |
10 | from third_party.colbert.data.collection import Collection
11 | from third_party.colbert.data.queries import Queries
12 | from third_party.colbert.data.examples import Examples
13 |
14 | # from third_party.colbert.utils.runs import Run
15 |
16 |
17 | class LazyBatcher():
18 | def __init__(self, config: ColBERTConfig, triples, queries, collection, rank=0, nranks=1):
19 | self.bsize, self.accumsteps = config.bsize, config.accumsteps
20 | self.nway = config.nway
21 |
22 | self.query_tokenizer = QueryTokenizer(config)
23 | self.doc_tokenizer = DocTokenizer(config)
24 | self.tensorize_triples = partial(tensorize_triples, self.query_tokenizer, self.doc_tokenizer)
25 | self.position = 0
26 |
27 | self.triples = Examples.cast(triples, nway=self.nway).tolist(rank, nranks)
28 | self.queries = Queries.cast(queries)
29 | self.collection = Collection.cast(collection)
30 | assert len(self.triples) > 0, "Received no triples on which to train."
31 | assert len(self.queries) > 0, "Received no queries on which to train."
32 | assert len(self.collection) > 0, "Received no collection on which to train."
33 |
34 | def __iter__(self):
35 | return self
36 |
37 | def __len__(self):
38 | return len(self.triples)
39 |
40 | def __next__(self):
41 | offset, endpos = self.position, min(self.position + self.bsize, len(self.triples))
42 | self.position = endpos
43 |
44 | if offset + self.bsize > len(self.triples):
45 | raise StopIteration
46 |
47 | all_queries, all_passages, all_scores = [], [], []
48 |
49 | for position in range(offset, endpos):
50 | query, *pids = self.triples[position]
51 | pids = pids[:self.nway]
52 |
53 | query = self.queries[query]
54 |
55 | try:
56 | pids, scores = zipstar(pids)
57 | except:
58 | scores = []
59 |
60 | passages = [self.collection[pid] for pid in pids]
61 |
62 | all_queries.append(query)
63 | all_passages.extend(passages)
64 | all_scores.extend(scores)
65 |
66 | assert len(all_scores) in [0, len(all_passages)], len(all_scores)
67 |
68 | return self.collate(all_queries, all_passages, all_scores)
69 |
70 | def collate(self, queries, passages, scores):
71 | assert len(queries) == self.bsize
72 | assert len(passages) == self.nway * self.bsize
73 |
74 | return self.tensorize_triples(queries, passages, scores, self.bsize // self.accumsteps, self.nway)
75 |
76 | # def skip_to_batch(self, batch_idx, intended_batch_size):
77 | # Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')
78 | # self.position = intended_batch_size * batch_idx
79 |
--------------------------------------------------------------------------------
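The try/except around `zipstar` above is what lets a triple carry either bare pids or `(pid, score)` pairs. A toy illustration of the two cases, assuming `zipstar` unzips a list of tuples the way `zip(*...)` would, which is how it is used here:

from third_party.colbert.utils.utils import zipstar

entry = [(17, 24.5), (93, 11.0)]          # nway=2 entry of (pid, score) pairs
pids, scores = zipstar(entry)             # -> [17, 93] and [24.5, 11.0]

entry = [17, 93]                          # bare pids: unzipping fails, scores fall back to []
try:
    pids, scores = zipstar(entry)
except Exception:
    pids, scores = entry, []

--------------------------------------------------------------------------------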
/third_party/colbert/training/rerank_batcher.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ujson
3 |
4 | from functools import partial
5 | from third_party.colbert.infra.config.config import ColBERTConfig
6 | from third_party.colbert.utils.utils import flatten, print_message, zipstar
7 | from third_party.colbert.modeling.reranker.tokenizer import RerankerTokenizer
8 |
9 | from third_party.colbert.data.collection import Collection
10 | from third_party.colbert.data.queries import Queries
11 | from third_party.colbert.data.examples import Examples
12 |
13 | # from third_party.colbert.utils.runs import Run
14 |
15 |
16 | class RerankBatcher():
17 | def __init__(self, config: ColBERTConfig, triples, queries, collection, rank=0, nranks=1):
18 | self.bsize, self.accumsteps = config.bsize, config.accumsteps
19 | self.nway = config.nway
20 |
21 | assert self.accumsteps == 1, "The tensorizer doesn't support larger accumsteps yet --- but it's easy to add."
22 |
23 | self.tokenizer = RerankerTokenizer(total_maxlen=config.doc_maxlen, base=config.checkpoint)
24 | self.position = 0
25 |
26 | self.triples = Examples.cast(triples, nway=self.nway).tolist(rank, nranks)
27 | self.queries = Queries.cast(queries)
28 | self.collection = Collection.cast(collection)
29 |
30 | assert len(self.triples) > 0, "Received no triples on which to train."
31 | assert len(self.queries) > 0, "Received no queries on which to train."
32 | assert len(self.collection) > 0, "Received no collection on which to train."
33 |
34 | def __iter__(self):
35 | return self
36 |
37 | def __len__(self):
38 | return len(self.triples)
39 |
40 | def __next__(self):
41 | offset, endpos = self.position, min(self.position + self.bsize, len(self.triples))
42 | self.position = endpos
43 |
44 | if offset + self.bsize > len(self.triples):
45 | raise StopIteration
46 |
47 | all_queries, all_passages, all_scores = [], [], []
48 |
49 | for position in range(offset, endpos):
50 | query, *pids = self.triples[position]
51 | pids = pids[:self.nway]
52 |
53 | query = self.queries[query]
54 |
55 | try:
56 | pids, scores = zipstar(pids)
57 | except:
58 | scores = []
59 |
60 | passages = [self.collection[pid] for pid in pids]
61 |
62 | all_queries.append(query)
63 | all_passages.extend(passages)
64 | all_scores.extend(scores)
65 |
66 | assert len(all_scores) in [0, len(all_passages)], len(all_scores)
67 |
68 | return self.collate(all_queries, all_passages, all_scores)
69 |
70 | def collate(self, queries, passages, scores):
71 | assert len(queries) == self.bsize
72 | assert len(passages) == self.nway * self.bsize
73 |
74 | queries = flatten([[query] * self.nway for query in queries])
75 | return [(self.tokenizer.tensorize(queries, passages), scores)]
76 |
77 | # def skip_to_batch(self, batch_idx, intended_batch_size):
78 | # Run.warn(f'Skipping to batch #{batch_idx} (with intended_batch_size = {intended_batch_size}) for training.')
79 | # self.position = intended_batch_size * batch_idx
80 |
--------------------------------------------------------------------------------
/third_party/colbert/training/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | # from third_party.colbert.utils.runs import Run
5 | from third_party.colbert.utils.utils import print_message, save_checkpoint
6 | from third_party.colbert.parameters import SAVED_CHECKPOINTS
7 | from third_party.colbert.infra.run import Run
8 |
9 |
10 | def print_progress(scores):
11 | positive_avg, negative_avg = round(scores[:, 0].mean().item(), 2), round(scores[:, 1].mean().item(), 2)
12 | print("#>>> ", positive_avg, negative_avg, '\t\t|\t\t', positive_avg - negative_avg)
13 |
14 |
15 | def manage_checkpoints(args, colbert, optimizer, batch_idx, savepath=None, consumed_all_triples=False):
16 | # arguments = dict(args)
17 |
18 | # TODO: Call provenance() on the values that support it??
19 |
20 | checkpoints_path = savepath or os.path.join(Run().path_, 'checkpoints')
21 | name = None
22 |
23 | try:
24 | save = colbert.save
25 | except:
26 | save = colbert.module.save
27 |
28 | if not os.path.exists(checkpoints_path):
29 | os.makedirs(checkpoints_path)
30 |
31 | path_save = None
32 |
33 | if consumed_all_triples or (batch_idx % 2000 == 0):
34 | # name = os.path.join(path, "colbert.dnn")
35 | # save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments)
36 | path_save = os.path.join(checkpoints_path, "colbert")
37 |
38 | if batch_idx in SAVED_CHECKPOINTS:
39 | # name = os.path.join(path, "colbert-{}.dnn".format(batch_idx))
40 | # save_checkpoint(name, 0, batch_idx, colbert, optimizer, arguments)
41 | path_save = os.path.join(checkpoints_path, f"colbert-{batch_idx}")
42 |
43 | if path_save:
44 | print(f"#> Saving a checkpoint to {path_save} ..")
45 |
46 | checkpoint = {}
47 | checkpoint['batch'] = batch_idx
48 | # checkpoint['epoch'] = 0
49 | # checkpoint['model_state_dict'] = model.state_dict()
50 | # checkpoint['optimizer_state_dict'] = optimizer.state_dict()
51 | # checkpoint['arguments'] = arguments
52 |
53 | save(path_save)
54 |
55 | return path_save
56 |
--------------------------------------------------------------------------------
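The save conditions in `manage_checkpoints` boil down to a simple predicate; a small sketch for clarity (the helper name is made up):

from third_party.colbert.parameters import SAVED_CHECKPOINTS

def should_save(batch_idx, consumed_all_triples=False):
    # A rolling "colbert" checkpoint every 2000 batches (or once all triples are consumed),
    # plus a named "colbert-{batch_idx}" checkpoint at the milestones in SAVED_CHECKPOINTS.
    return consumed_all_triples or batch_idx % 2000 == 0 or batch_idx in SAVED_CHECKPOINTS

assert should_save(2000) and should_save(32_000) and not should_save(1999)

--------------------------------------------------------------------------------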
/third_party/colbert/utilities/annotate_em.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import sys
4 | import git
5 | import tqdm
6 | import ujson
7 | import random
8 |
9 | from argparse import ArgumentParser
10 | from multiprocessing import Pool
11 |
12 | from third_party.colbert.utils.utils import groupby_first_item, print_message
13 | from third_party.utility.utils.qa_loaders import load_qas_, load_collection_
14 | from third_party.utility.utils.save_metadata import format_metadata, get_metadata
15 | from third_party.utility.evaluate.annotate_EM_helpers import *
16 |
17 | from third_party.colbert.infra.run import Run
18 | from third_party.colbert.data.collection import Collection
19 | from third_party.colbert.data.ranking import Ranking
20 |
21 |
22 | class AnnotateEM:
23 | def __init__(self, collection, qas):
24 | # TODO: These should just be Queries! But Queries needs to support looking up answers as qid2answers below.
25 | qas = load_qas_(qas)
26 | collection = Collection.cast(collection) # .tolist() #load_collection_(collection, retain_titles=True)
27 |
28 | self.parallel_pool = Pool(30)
29 |
30 | print_message('#> Tokenize the answers in the Q&As in parallel...')
31 | qas = list(self.parallel_pool.map(tokenize_all_answers, qas))
32 |
33 | qid2answers = {qid: tok_answers for qid, _, tok_answers in qas}
34 | assert len(qas) == len(qid2answers), (len(qas), len(qid2answers))
35 |
36 | self.qas, self.collection = qas, collection
37 | self.qid2answers = qid2answers
38 |
39 | def annotate(self, ranking):
40 | rankings = Ranking.cast(ranking)
41 |
42 | # print(len(rankings), rankings[0])
43 |
44 | print_message('#> Lookup passages from PIDs...')
45 | expanded_rankings = [(qid, pid, rank, self.collection[pid], self.qid2answers[qid])
46 | for qid, pid, rank, *_ in rankings.tolist()]
47 |
48 | print_message('#> Assign labels in parallel...')
49 | labeled_rankings = list(self.parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings)))
50 |
51 | # Dump output.
52 | self.qid2rankings = groupby_first_item(labeled_rankings)
53 |
54 | self.num_judged_queries, self.num_ranked_queries = check_sizes(self.qid2answers, self.qid2rankings)
55 |
56 | # Evaluation metrics and depths.
57 | self.success, self.counts = self._compute_labels(self.qid2answers, self.qid2rankings)
58 |
59 | print(rankings.provenance(), self.success)
60 |
61 | return Ranking(data=self.qid2rankings, provenance=("AnnotateEM", rankings.provenance()))
62 |
63 | def _compute_labels(self, qid2answers, qid2rankings):
64 | cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all']
65 | success = {cutoff: 0.0 for cutoff in cutoffs}
66 | counts = {cutoff: 0.0 for cutoff in cutoffs}
67 |
68 | for qid in qid2answers:
69 | if qid not in qid2rankings:
70 | continue
71 |
72 |             prev_rank = 0  # ranks should start at one (i.e., not zero)
73 | labels = []
74 |
75 | for pid, rank, label in qid2rankings[qid]:
76 | assert rank == prev_rank+1, (qid, pid, (prev_rank, rank))
77 | prev_rank = rank
78 |
79 | labels.append(label)
80 |
81 | for cutoff in cutoffs:
82 | if cutoff != 'all':
83 | success[cutoff] += sum(labels[:cutoff]) > 0
84 | counts[cutoff] += sum(labels[:cutoff])
85 | else:
86 | success[cutoff] += sum(labels) > 0
87 | counts[cutoff] += sum(labels)
88 |
89 | return success, counts
90 |
91 | def save(self, new_path):
92 | print_message("#> Dumping output to", new_path, "...")
93 |
94 | Ranking(data=self.qid2rankings).save(new_path)
95 |
96 | # Dump metrics.
97 | with Run().open(f'{new_path}.metrics', 'w') as f:
98 | d = {'num_ranked_queries': self.num_ranked_queries, 'num_judged_queries': self.num_judged_queries}
99 |
100 | extra = '__WARNING' if self.num_judged_queries != self.num_ranked_queries else ''
101 | d[f'success{extra}'] = {k: v / self.num_judged_queries for k, v in self.success.items()}
102 | d[f'counts{extra}'] = {k: v / self.num_judged_queries for k, v in self.counts.items()}
103 | # d['arguments'] = get_metadata(args) # TODO: Need arguments...
104 |
105 | f.write(format_metadata(d) + '\n')
106 |
107 |
108 | if __name__ == '__main__':
109 | r = sys.argv[2]
110 |
111 | a = AnnotateEM(collection='/dfs/scratch0/okhattab/OpenQA/collection.tsv',
112 | qas=sys.argv[1])
113 | a.annotate(ranking=r)
114 |
--------------------------------------------------------------------------------
/third_party/colbert/utilities/create_triples.py:
--------------------------------------------------------------------------------
1 | import random
2 | from third_party.colbert.infra.provenance import Provenance
3 |
4 | from third_party.utility.utils.save_metadata import save_metadata
5 | from third_party.utility.supervision.triples import sample_for_query
6 |
7 | from third_party.colbert.utils.utils import print_message
8 |
9 | from third_party.colbert.data.ranking import Ranking
10 | from third_party.colbert.data.examples import Examples
11 |
12 | MAX_NUM_TRIPLES = 40_000_000
13 |
14 |
15 | class Triples:
16 | def __init__(self, ranking, seed=12345):
17 | random.seed(seed) # TODO: Use internal RNG instead..
18 | self.seed = seed
19 |
20 | ranking = Ranking.cast(ranking)
21 | self.ranking_provenance = ranking.provenance()
22 | self.qid2rankings = ranking.todict()
23 |
24 | def create(self, positives, depth):
25 | assert all(len(x) == 2 for x in positives)
26 | assert all(maxBest <= maxDepth for maxBest, maxDepth in positives), positives
27 |
28 | self.positives = positives
29 | self.depth = depth
30 |
31 | Triples = []
32 | NonEmptyQIDs = 0
33 |
34 | for processing_idx, qid in enumerate(self.qid2rankings):
35 | l = sample_for_query(qid, self.qid2rankings[qid], positives, depth, False, None)
36 | NonEmptyQIDs += (len(l) > 0)
37 | Triples.extend(l)
38 |
39 | if processing_idx % (10_000) == 0:
40 | print_message(f"#> Done with {processing_idx+1} questions!\t\t "
41 | f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unqiue QIDs.")
42 |
43 | print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..")
44 | print_message(f"#> len(Triples) = {len(Triples)}")
45 |
46 | if len(Triples) > MAX_NUM_TRIPLES:
47 | Triples = random.sample(Triples, MAX_NUM_TRIPLES)
48 |
49 | ### Prepare the triples ###
50 | print_message("#> Shuffling the triples...")
51 | random.shuffle(Triples)
52 |
53 | self.Triples = Examples(data=Triples)
54 |
55 | return Triples
56 |
57 | def save(self, new_path):
58 | provenance = Provenance()
59 | provenance.source = 'Triples::create'
60 | provenance.seed = self.seed
61 | provenance.positives = self.positives
62 | provenance.depth = self.depth
63 | provenance.ranking = self.ranking_provenance
64 |
65 | Examples(data=self.Triples, provenance=provenance).save(new_path)
66 |
--------------------------------------------------------------------------------
/third_party/colbert/utilities/minicorpus.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | from third_party.colbert.utils.utils import create_directory
5 |
6 | from third_party.colbert.data.collection import Collection
7 | from third_party.colbert.data.queries import Queries
8 | from third_party.colbert.data.ranking import Ranking
9 |
10 |
11 | def sample_minicorpus(name, factor, topk=30, maxdev=3000):
12 | """
13 | Factor:
14 | * nano=1
15 | * micro=10
16 | * mini=100
17 | * small=100 with topk=100
18 | * medium=150 with topk=300
19 | """
20 |
21 | random.seed(12345)
22 |
23 | # Load collection
24 | collection = Collection(path='/dfs/scratch0/okhattab/OpenQA/collection.tsv')
25 |
26 | # Load train and dev queries
27 | qas_train = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/qas.json').qas()
28 | qas_dev = Queries(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/qas.json').qas()
29 |
30 | # Load train and dev C3 rankings
31 | ranking_train = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/train/rankings/C3.tsv.annotated').todict()
32 | ranking_dev = Ranking(path='/dfs/scratch0/okhattab/OpenQA/NQ/dev/rankings/C3.tsv.annotated').todict()
33 |
34 | # Sample NT and ND queries from each, keep only the top-k passages for those
35 | sample_train = random.sample(list(qas_train.keys()), min(len(qas_train.keys()), 300*factor))
36 | sample_dev = random.sample(list(qas_dev.keys()), min(len(qas_dev.keys()), maxdev, 30*factor))
37 |
38 | train_pids = [pid for qid in sample_train for qpids in ranking_train[qid][:topk] for pid in qpids]
39 | dev_pids = [pid for qid in sample_dev for qpids in ranking_dev[qid][:topk] for pid in qpids]
40 |
41 | sample_pids = sorted(list(set(train_pids + dev_pids)))
42 | print(f'len(sample_pids) = {len(sample_pids)}')
43 |
44 | # Save the new query sets: train and dev
45 | ROOT = f'/future/u/okhattab/root/unit/data/NQ-{name}'
46 |
47 | create_directory(os.path.join(ROOT, 'train'))
48 | create_directory(os.path.join(ROOT, 'dev'))
49 |
50 | new_train = Queries(data={qid: qas_train[qid] for qid in sample_train})
51 | new_train.save(os.path.join(ROOT, 'train/questions.tsv'))
52 | new_train.save_qas(os.path.join(ROOT, 'train/qas.json'))
53 |
54 | new_dev = Queries(data={qid: qas_dev[qid] for qid in sample_dev})
55 | new_dev.save(os.path.join(ROOT, 'dev/questions.tsv'))
56 | new_dev.save_qas(os.path.join(ROOT, 'dev/qas.json'))
57 |
58 | # Save the new collection
59 | print(f"Saving to {os.path.join(ROOT, 'collection.tsv')}")
60 | Collection(data=[collection[pid] for pid in sample_pids]).save(os.path.join(ROOT, 'collection.tsv'))
61 |
62 | print('#> Done!')
63 |
64 |
65 | if __name__ == '__main__':
66 | sample_minicorpus('medium', 150, topk=300)
67 |
--------------------------------------------------------------------------------
/third_party/colbert/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/colbert/utils/__init__.py
--------------------------------------------------------------------------------
/third_party/colbert/utils/amp.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from contextlib import contextmanager
4 | from third_party.colbert.utils.utils import NullContextManager
5 |
6 |
7 | class MixedPrecisionManager():
8 | def __init__(self, activated):
9 | self.activated = activated
10 |
11 | if self.activated:
12 | self.scaler = torch.cuda.amp.GradScaler()
13 |
14 | def context(self):
15 | return torch.cuda.amp.autocast() if self.activated else NullContextManager()
16 |
17 | def backward(self, loss):
18 | if self.activated:
19 | self.scaler.scale(loss).backward()
20 | else:
21 | loss.backward()
22 |
23 | def step(self, colbert, optimizer, scheduler=None):
24 | if self.activated:
25 | self.scaler.unscale_(optimizer)
26 | torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0, error_if_nonfinite=False)
27 |
28 | self.scaler.step(optimizer)
29 | self.scaler.update()
30 | else:
31 | torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0)
32 | optimizer.step()
33 |
34 | if scheduler is not None:
35 | scheduler.step()
36 |
37 | optimizer.zero_grad()
38 |
--------------------------------------------------------------------------------
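A minimal sketch of how `MixedPrecisionManager` is meant to wrap one training step; `colbert_model`, `optimizer`, and `loader` are placeholders:

import torch
from third_party.colbert.utils.amp import MixedPrecisionManager

amp = MixedPrecisionManager(activated=torch.cuda.is_available())

for batch in loader:                       # placeholder data loader
    with amp.context():                    # autocast when activated, a no-op context otherwise
        loss = colbert_model(batch)        # placeholder forward pass returning a scalar loss

    amp.backward(loss)                     # scaled backward when AMP is on
    amp.step(colbert_model, optimizer)     # unscale, clip, optimizer step, scaler update, zero grads

--------------------------------------------------------------------------------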
/third_party/colbert/utils/coalesce.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | from tqdm import tqdm
5 | import ujson
6 | import shutil
7 |
8 |
9 | def main(args):
10 | in_file = args.input
11 | out_file = args.output
12 |
13 | # Get num_chunks from metadata
14 | filepath = os.path.join(in_file, 'metadata.json')
15 | with open(filepath, 'r') as f:
16 | metadata = ujson.load(f)
17 | num_chunks = metadata['num_chunks']
18 | print(f"Num_chunks = {num_chunks}")
19 |
20 | # Create output dir if not already created
21 | if not os.path.exists(out_file):
22 | os.makedirs(out_file)
23 |
24 | ## Coalesce doclens ##
25 | print("Coalescing doclens files...")
26 |
27 | temp = []
28 | # read into one large list
29 | for i in tqdm(range(num_chunks)):
30 | filepath = os.path.join(in_file, f'doclens.{i}.json')
31 | with open(filepath, 'r') as f:
32 | chunk = ujson.load(f)
33 | temp.extend(chunk)
34 |
35 | # write to output json
36 | filepath = os.path.join(out_file, 'doclens.0.json')
37 | with open(filepath, 'w') as f:
38 | ujson.dump(temp, f)
39 |
40 | ## Coalesce codes ##
41 | print("Coalescing codes files...")
42 |
43 | temp = torch.empty(0, dtype=torch.int32)
44 | # read into one large tensor
45 | for i in tqdm(range(num_chunks)):
46 | filepath = os.path.join(in_file, f'{i}.codes.pt')
47 | chunk = torch.load(filepath)
48 | temp = torch.cat((temp, chunk))
49 |
50 | # save length of index
51 | index_len = temp.size()[0]
52 |
53 | # write to output tensor
54 | filepath = os.path.join(out_file, '0.codes.pt')
55 | torch.save(temp, filepath)
56 |
57 | ## Coalesce residuals ##
58 | print("Coalescing residuals files...")
59 |
60 | # Allocate all the memory needed in the beginning. Starting from torch.empty() and concatenating repeatedly results in excessive memory use and is much much slower.
61 |     temp = torch.zeros((metadata['num_embeddings'], int(metadata['config']['dim'] * metadata['config']['nbits'] // 8)), dtype=torch.uint8)
62 | cur_offset = 0
63 | # read into one large tensor
64 | for i in tqdm(range(num_chunks)):
65 | filepath = os.path.join(in_file, f'{i}.residuals.pt')
66 | chunk = torch.load(filepath)
67 |         temp[cur_offset : cur_offset + len(chunk)] = chunk
68 | cur_offset += len(chunk)
69 |
70 | print("Saving residuals to output directory (this may take a few minutes)...")
71 |
72 | # write to output tensor
73 | filepath = os.path.join(out_file, '0.residuals.pt')
74 | torch.save(temp, filepath)
75 |
76 | # save metadata.json
77 | metadata['num_chunks'] = 1
78 | filepath = os.path.join(out_file, 'metadata.json')
79 | with open(filepath, 'w') as f:
80 | ujson.dump(metadata, f, indent=4)
81 |
82 | metadata_0 = {}
83 | metadata_0["num_embeddings"] = metadata["num_embeddings"]
84 | metadata_0["passage_offset"] = 0
85 | metadata_0["embedding_offset"] = 0
86 |
87 | filepath = os.path.join(in_file, str(num_chunks-1) + '.metadata.json')
88 | with open(filepath, 'r') as f:
89 | metadata_last = ujson.load(f)
90 | metadata_0["num_passages"] = int(metadata_last["num_passages"]) + int(metadata_last["passage_offset"])
91 |
92 | filepath = os.path.join(out_file, '0.metadata.json')
93 | with open(filepath, 'w') as f:
94 | ujson.dump(metadata_0, f, indent=4)
95 |
96 | filepath = os.path.join(in_file, 'plan.json')
97 | with open(filepath, 'r') as f:
98 | plan = ujson.load(f)
99 | plan['num_chunks'] = 1
100 | filepath = os.path.join(out_file, 'plan.json')
101 | with open(filepath, 'w') as f:
102 | ujson.dump(plan, f, indent=4)
103 |
104 | other_files = ['avg_residual.pt', 'buckets.pt', 'centroids.pt', 'ivf.pt', 'ivf.pid.pt']
105 | for filename in other_files:
106 | filepath = os.path.join(in_file, filename)
107 | if os.path.isfile(filepath):
108 | shutil.copy(filepath, out_file)
109 |
110 | print("Saved index to output directory {}.".format(out_file))
111 | print("Number of embeddings = {}".format(index_len))
112 |
113 |
114 | if __name__ == "__main__":
115 | parser = argparse.ArgumentParser(description="Coalesce multi-file index into a single file.")
116 | parser.add_argument(
117 | "--input", type=str, required=True, help="Path to input index directory"
118 | )
119 | parser.add_argument(
120 | "--output", type=str, required=True, help="Path to output index directory"
121 | )
122 |
123 | args = parser.parse_args()
124 | main(args)
125 |
--------------------------------------------------------------------------------
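To make the preallocation comment above concrete, this is roughly the pattern it is contrasting with (shapes are illustrative):

import torch

chunks = [torch.full((1000, 16), i, dtype=torch.uint8) for i in range(8)]  # stand-ins for {i}.residuals.pt

# Repeated torch.cat reallocates and re-copies everything seen so far on each iteration.
slow = torch.empty((0, 16), dtype=torch.uint8)
for chunk in chunks:
    slow = torch.cat((slow, chunk))

# Preallocating once and copying each chunk into its slice (what coalesce.py does for residuals).
fast = torch.zeros((sum(c.size(0) for c in chunks), 16), dtype=torch.uint8)
offset = 0
for chunk in chunks:
    fast[offset:offset + chunk.size(0)] = chunk
    offset += chunk.size(0)

assert torch.equal(slow, fast)

--------------------------------------------------------------------------------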
/third_party/colbert/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import torch
4 | import numpy as np
5 | import datetime
6 |
7 | ALREADY_INITALIZED = False
8 |
9 | # TODO: Consider torch.distributed.is_initialized() instead
10 |
11 |
12 | def init(rank):
13 | nranks = 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE'])
14 | nranks = max(1, nranks)
15 | is_distributed = (nranks > 1) or ('WORLD_SIZE' in os.environ)
16 |
17 | global ALREADY_INITALIZED
18 | if ALREADY_INITALIZED:
19 | return nranks, is_distributed
20 |
21 | ALREADY_INITALIZED = True
22 |
23 | if is_distributed and torch.cuda.is_available():
24 | num_gpus = torch.cuda.device_count()
25 | print(f'nranks = {nranks} \t num_gpus = {num_gpus} \t device={rank % num_gpus}')
26 |
27 | torch.cuda.set_device(rank % num_gpus)
28 |
29 | # increase the timeout for indexing large datasets with unbalanced split sizes across devices
30 | torch.distributed.init_process_group(backend='nccl', init_method='env://', timeout=datetime.timedelta(hours=1))
31 |
32 | return nranks, is_distributed
33 |
34 |
35 | def barrier(rank):
36 | nranks = 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE'])
37 | nranks = max(1, nranks)
38 |
39 | if rank >= 0 and nranks > 1:
40 | torch.distributed.barrier(device_ids=[rank % torch.cuda.device_count()])
41 |
--------------------------------------------------------------------------------
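A hedged sketch of how `init` and `barrier` are meant to be called from a per-process entry point launched with `torchrun` (which sets WORLD_SIZE and RANK); the per-rank work is a placeholder:

import os
import torch
import third_party.colbert.utils.distributed as distributed

def main():
    rank = int(os.environ.get("RANK", 0))
    nranks, is_distributed = distributed.init(rank)   # sets the CUDA device and joins the NCCL group

    # ... per-rank work (indexing or training) goes here ...

    if is_distributed and torch.distributed.is_initialized():
        distributed.barrier(rank)                      # wait for all ranks before exiting

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------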
/third_party/colbert/utils/logging.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import ujson
4 | # import mlflow
5 | import traceback
6 |
7 | # from torch.utils.tensorboard import SummaryWriter
8 | from third_party.colbert.utils.utils import print_message, create_directory
9 |
10 |
11 | class Logger():
12 | def __init__(self, rank, run):
13 | self.rank = rank
14 | self.is_main = self.rank in [-1, 0]
15 | self.run = run
16 | self.logs_path = os.path.join(self.run.path, "logs/")
17 |
18 | if self.is_main:
19 | # self._init_mlflow()
20 | # self.initialized_tensorboard = False
21 | create_directory(self.logs_path)
22 |
23 | # def _init_mlflow(self):
24 | # mlflow.set_tracking_uri('file://' + os.path.join(self.run.experiments_root, "logs/mlruns/"))
25 | # mlflow.set_experiment('/'.join([self.run.experiment, self.run.script]))
26 |
27 | # mlflow.set_tag('experiment', self.run.experiment)
28 | # mlflow.set_tag('name', self.run.name)
29 | # mlflow.set_tag('path', self.run.path)
30 |
31 | # def _init_tensorboard(self):
32 | # root = os.path.join(self.run.experiments_root, "logs/tensorboard/")
33 | # logdir = '__'.join([self.run.experiment, self.run.script, self.run.name])
34 | # logdir = os.path.join(root, logdir)
35 |
36 | # self.writer = SummaryWriter(log_dir=logdir)
37 | # self.initialized_tensorboard = True
38 |
39 | def _log_exception(self, etype, value, tb):
40 | if not self.is_main:
41 | return
42 |
43 | output_path = os.path.join(self.logs_path, 'exception.txt')
44 | trace = ''.join(traceback.format_exception(etype, value, tb)) + '\n'
45 | print_message(trace, '\n\n')
46 |
47 | self.log_new_artifact(output_path, trace)
48 |
49 | def _log_all_artifacts(self):
50 | if not self.is_main:
51 | return
52 |
53 | # mlflow.log_artifacts(self.logs_path)
54 |
55 | def _log_args(self, args):
56 | if not self.is_main:
57 | return
58 |
59 | # for key in vars(args):
60 | # value = getattr(args, key)
61 | # if type(value) in [int, float, str, bool]:
62 | # mlflow.log_param(key, value)
63 |
64 | # with open(os.path.join(self.logs_path, 'args.json'), 'w') as output_metadata:
65 | # # TODO: Call provenance() on the values that support it
66 | # ujson.dump(args.input_arguments.__dict__, output_metadata, indent=4)
67 | # output_metadata.write('\n')
68 |
69 | with open(os.path.join(self.logs_path, 'args.txt'), 'w') as output_metadata:
70 | output_metadata.write(' '.join(sys.argv) + '\n')
71 |
72 | def log_metric(self, name, value, step, log_to_mlflow=True):
73 | if not self.is_main:
74 | return
75 |
76 | # if not self.initialized_tensorboard:
77 | # self._init_tensorboard()
78 |
79 | # if log_to_mlflow:
80 | # mlflow.log_metric(name, value, step=step)
81 | # self.writer.add_scalar(name, value, step)
82 |
83 | def log_new_artifact(self, path, content):
84 | with open(path, 'w') as f:
85 | f.write(content)
86 |
87 | # mlflow.log_artifact(path)
88 |
89 | def warn(self, *args):
90 | msg = print_message('[WARNING]', '\t', *args)
91 |
92 | with open(os.path.join(self.logs_path, 'warnings.txt'), 'a') as output_metadata:
93 | output_metadata.write(msg + '\n\n\n')
94 |
95 | def info_all(self, *args):
96 | print_message('[' + str(self.rank) + ']', '\t', *args)
97 |
98 | def info(self, *args):
99 | if self.is_main:
100 | print_message(*args)
101 |
--------------------------------------------------------------------------------
/third_party/colbert/utils/runs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import __main__
5 | import traceback
6 | # import mlflow
7 |
8 | import third_party.colbert.utils.distributed as distributed
9 |
10 | from contextlib import contextmanager
11 | from third_party.colbert.utils.logging import Logger
12 | from third_party.colbert.utils.utils import timestamp, create_directory, print_message
13 |
14 |
15 | class _RunManager():
16 | def __init__(self):
17 | self.experiments_root = None
18 | self.experiment = None
19 | self.path = None
20 | self.script = self._get_script_name()
21 | self.name = self._generate_default_run_name()
22 | self.original_name = self.name
23 | self.exit_status = 'FINISHED'
24 |
25 | self._logger = None
26 | self.start_time = time.time()
27 |
28 | def init(self, rank, root, experiment, name):
29 | assert '/' not in experiment, experiment
30 | assert '/' not in name, name
31 |
32 | self.experiments_root = os.path.abspath(root)
33 | self.experiment = experiment
34 | self.name = name
35 | self.path = os.path.join(self.experiments_root, self.experiment, self.script, self.name)
36 |
37 | if rank < 1:
38 | if os.path.exists(self.path):
39 | print('\n\n')
40 | print_message("It seems that ", self.path, " already exists.")
41 | print_message("Do you want to overwrite it? \t yes/no \n")
42 |
43 | # TODO: This should timeout and exit (i.e., fail) given no response for 60 seconds.
44 |
45 | response = input()
46 | if response.strip() != 'yes':
47 | assert not os.path.exists(self.path), self.path
48 | else:
49 | create_directory(self.path)
50 |
51 | distributed.barrier(rank)
52 |
53 | self._logger = Logger(rank, self)
54 | self._log_args = self._logger._log_args
55 | self.warn = self._logger.warn
56 | self.info = self._logger.info
57 | self.info_all = self._logger.info_all
58 | self.log_metric = self._logger.log_metric
59 | self.log_new_artifact = self._logger.log_new_artifact
60 |
61 | def _generate_default_run_name(self):
62 | return timestamp()
63 |
64 | def _get_script_name(self):
65 | return os.path.basename(__main__.__file__) if '__file__' in dir(__main__) else 'none'
66 |
67 | @contextmanager
68 | def context(self, consider_failed_if_interrupted=True):
69 | try:
70 | yield
71 |
72 | except KeyboardInterrupt as ex:
73 | print('\n\nInterrupted\n\n')
74 | self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
75 | self._logger._log_all_artifacts()
76 |
77 | if consider_failed_if_interrupted:
78 | self.exit_status = 'KILLED' # mlflow.entities.RunStatus.KILLED
79 |
80 | sys.exit(128 + 2)
81 |
82 | except Exception as ex:
83 | self._logger._log_exception(ex.__class__, ex, ex.__traceback__)
84 | self._logger._log_all_artifacts()
85 |
86 | self.exit_status = 'FAILED' # mlflow.entities.RunStatus.FAILED
87 |
88 | raise ex
89 |
90 | finally:
91 | total_seconds = str(time.time() - self.start_time) + '\n'
92 | original_name = str(self.original_name)
93 | name = str(self.name)
94 |
95 | self.log_new_artifact(os.path.join(self._logger.logs_path, 'elapsed.txt'), total_seconds)
96 | self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.original.txt'), original_name)
97 | self.log_new_artifact(os.path.join(self._logger.logs_path, 'name.txt'), name)
98 |
99 | self._logger._log_all_artifacts()
100 |
101 | # mlflow.end_run(status=self.exit_status)
102 |
103 |
104 | Run = _RunManager()
105 |
--------------------------------------------------------------------------------
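A minimal usage sketch for the Run singleton defined above (not part of the repository): the experiment root and run name are placeholder values, and rank=-1 stands for a single, non-distributed process so that the barrier is skipped and the process is treated as the main rank.

    import os
    from third_party.colbert.utils.runs import Run

    # Placeholder root/experiment/name; Run.init creates the run directory and a Logger.
    Run.init(rank=-1, root='./experiments', experiment='demo', name='2024-01-01_00.00.00')

    with Run.context():
        Run.info('starting the run')              # printed only on the main rank
        Run.log_metric('loss', 0.42, step=1)      # the mlflow/tensorboard hooks are commented out upstream
        Run.log_new_artifact(os.path.join(Run.path, 'notes.txt'), 'hello\n')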
/third_party/colbert_ai.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | MANIFEST.in
3 | README.md
4 | setup.py
5 | colbert/__init__.py
6 | colbert/index.py
7 | colbert/index_updater.py
8 | colbert/indexer.py
9 | colbert/parameters.py
10 | colbert/searcher.py
11 | colbert/trainer.py
12 | colbert/data/__init__.py
13 | colbert/data/collection.py
14 | colbert/data/dataset.py
15 | colbert/data/examples.py
16 | colbert/data/queries.py
17 | colbert/data/ranking.py
18 | colbert/evaluation/__init__.py
19 | colbert/evaluation/load_model.py
20 | colbert/evaluation/loaders.py
21 | colbert/evaluation/metrics.py
22 | colbert/indexing/__init__.py
23 | colbert/indexing/collection_encoder.py
24 | colbert/indexing/collection_indexer.py
25 | colbert/indexing/index_manager.py
26 | colbert/indexing/index_saver.py
27 | colbert/indexing/loaders.py
28 | colbert/indexing/utils.py
29 | colbert/indexing/codecs/__init__.py
30 | colbert/indexing/codecs/decompress_residuals.cpp
31 | colbert/indexing/codecs/decompress_residuals.cu
32 | colbert/indexing/codecs/packbits.cpp
33 | colbert/indexing/codecs/packbits.cu
34 | colbert/indexing/codecs/residual.py
35 | colbert/indexing/codecs/residual_embeddings.py
36 | colbert/indexing/codecs/residual_embeddings_strided.py
37 | colbert/infra/__init__.py
38 | colbert/infra/launcher.py
39 | colbert/infra/provenance.py
40 | colbert/infra/run.py
41 | colbert/infra/config/__init__.py
42 | colbert/infra/config/base_config.py
43 | colbert/infra/config/config.py
44 | colbert/infra/config/core_config.py
45 | colbert/infra/config/settings.py
46 | colbert/modeling/__init__.py
47 | colbert/modeling/base_colbert.py
48 | colbert/modeling/checkpoint.py
49 | colbert/modeling/colbert.py
50 | colbert/modeling/hf_colbert.py
51 | colbert/modeling/segmented_maxsim.cpp
52 | colbert/modeling/reranker/__init__.py
53 | colbert/modeling/reranker/electra.py
54 | colbert/modeling/reranker/tokenizer.py
55 | colbert/modeling/tokenization/__init__.py
56 | colbert/modeling/tokenization/doc_tokenization.py
57 | colbert/modeling/tokenization/query_tokenization.py
58 | colbert/modeling/tokenization/utils.py
59 | colbert/ranking/__init__.py
60 | colbert/search/__init__.py
61 | colbert/search/candidate_generation.py
62 | colbert/search/decompress_residuals.cpp
63 | colbert/search/filter_pids.cpp
64 | colbert/search/index_loader.py
65 | colbert/search/index_storage.py
66 | colbert/search/segmented_lookup.cpp
67 | colbert/search/strided_tensor.py
68 | colbert/search/strided_tensor_core.py
69 | colbert/training/__init__.py
70 | colbert/training/eager_batcher.py
71 | colbert/training/lazy_batcher.py
72 | colbert/training/rerank_batcher.py
73 | colbert/training/training.py
74 | colbert/training/utils.py
75 | colbert/utils/__init__.py
76 | colbert/utils/amp.py
77 | colbert/utils/coalesce.py
78 | colbert/utils/distributed.py
79 | colbert/utils/logging.py
80 | colbert/utils/parser.py
81 | colbert/utils/runs.py
82 | colbert/utils/utils.py
83 | colbert_ai.egg-info/PKG-INFO
84 | colbert_ai.egg-info/SOURCES.txt
85 | colbert_ai.egg-info/dependency_links.txt
86 | colbert_ai.egg-info/requires.txt
87 | colbert_ai.egg-info/top_level.txt
88 | utility/__init__.py
89 | utility/evaluate/__init__.py
90 | utility/evaluate/annotate_EM.py
91 | utility/evaluate/annotate_EM_helpers.py
92 | utility/evaluate/evaluate_lotte_rankings.py
93 | utility/evaluate/msmarco_passages.py
94 | utility/preprocess/__init__.py
95 | utility/preprocess/docs2passages.py
96 | utility/preprocess/queries_split.py
97 | utility/rankings/__init__.py
98 | utility/rankings/dev_subsample.py
99 | utility/rankings/merge.py
100 | utility/rankings/split_by_offset.py
101 | utility/rankings/split_by_queries.py
102 | utility/rankings/tune.py
103 | utility/supervision/__init__.py
104 | utility/supervision/self_training.py
105 | utility/supervision/triples.py
106 | utility/utils/__init__.py
107 | utility/utils/dpr.py
108 | utility/utils/qa_loaders.py
109 | utility/utils/save_metadata.py
--------------------------------------------------------------------------------
/third_party/colbert_ai.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/third_party/colbert_ai.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | bitarray
2 | datasets
3 | flask
4 | git-python
5 | python-dotenv
6 | ninja
7 | scipy
8 | tqdm
9 | transformers
10 | ujson
11 |
12 | [faiss-cpu]
13 | faiss-cpu>=1.7.0
14 |
15 | [faiss-gpu]
16 | faiss-gpu>=1.7.0
17 |
18 | [torch]
19 | torch==1.13.1
20 |
--------------------------------------------------------------------------------
/third_party/colbert_ai.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | colbert
2 | utility
3 |
--------------------------------------------------------------------------------
/third_party/conda_env.yml:
--------------------------------------------------------------------------------
1 | name: colbert
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - python=3.8
9 | - faiss-gpu
10 | - pip=21.0
11 | - pytorch=1.13.1
12 | - torchaudio=0.13.1
13 | - torchvision=0.14.1
14 | - cudatoolkit=11.3
15 | - gcc=9.4.0
16 | - gxx=9.4.0
17 | - pip:
18 | - bitarray
19 | - datasets
20 | - gitpython
21 | - jupyter
22 | - jupyterlab
23 | - ninja
24 | - scipy
25 | - tqdm
26 | - transformers
27 | - ujson
28 | - flask
29 | - python-dotenv
30 |
--------------------------------------------------------------------------------
/third_party/conda_env_cpu.yml:
--------------------------------------------------------------------------------
1 | name: colbert
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - python=3.8
8 | - faiss-cpu
9 | - pip=21.0
10 | - pytorch-cpu=1.13
11 | - pip:
12 | - bitarray
13 | - datasets
14 | - gitpython
15 | - ninja
16 | - scipy
17 | - tqdm
18 | - transformers
19 | - ujson
20 | - flask
21 | - python-dotenv
22 |
--------------------------------------------------------------------------------
/third_party/server.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, render_template, request
2 | from functools import lru_cache
3 | import math
4 | import os
5 | from dotenv import load_dotenv
6 |
7 | from third_party.colbert.infra import Run, RunConfig, ColBERTConfig
8 | from third_party.colbert import Searcher
9 |
10 | load_dotenv()
11 |
12 | INDEX_NAME = os.getenv("INDEX_NAME")
13 | INDEX_ROOT = os.getenv("INDEX_ROOT")
14 | app = Flask(__name__)
15 |
16 | searcher = Searcher(index=INDEX_NAME, index_root=INDEX_ROOT)
17 | counter = {"api" : 0}
18 |
19 | @lru_cache(maxsize=1000000)
20 | def api_search_query(query, k):
21 | print(f"Query={query}")
22 | if k is None: k = 10
23 | k = min(int(k), 100)
24 | pids, ranks, scores = searcher.search(query, k=100)
25 | pids, ranks, scores = pids[:k], ranks[:k], scores[:k]
26 | passages = [searcher.collection[pid] for pid in pids]
27 | probs = [math.exp(score) for score in scores]
28 | probs = [prob / sum(probs) for prob in probs]
29 | topk = []
30 | for pid, rank, score, prob in zip(pids, ranks, scores, probs):
31 | text = searcher.collection[pid]
32 | d = {'text': text, 'pid': pid, 'rank': rank, 'score': score, 'prob': prob}
33 | topk.append(d)
34 | topk = list(sorted(topk, key=lambda p: (-1 * p['score'], p['pid'])))
35 | return {"query" : query, "topk": topk}
36 |
37 | @app.route("/api/search", methods=["GET"])
38 | def api_search():
39 | if request.method == "GET":
40 | counter["api"] += 1
41 | print("API request count:", counter["api"])
42 | return api_search_query(request.args.get("query"), request.args.get("k"))
43 | else:
44 | return ('', 405)
45 |
46 | if __name__ == "__main__":
47 | app.run("0.0.0.0", int(os.getenv("PORT")))
48 |
49 |
--------------------------------------------------------------------------------
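server.py above exposes a single GET endpoint backed by a cached Searcher. A client-side sketch, assuming the requests package is installed and that the server was started with PORT=8893 (both are assumptions, not taken from the repository):

    import requests

    # Hypothetical host/port; the server reads INDEX_NAME, INDEX_ROOT and PORT from the environment.
    resp = requests.get(
        'http://localhost:8893/api/search',
        params={'query': 'who wrote the declaration of independence', 'k': 5},
    )
    for hit in resp.json()['topk']:
        print(hit['rank'], round(hit['score'], 2), hit['text'][:80])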
/third_party/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as f:
4 | long_description = f.read()
5 |
6 | package_data = {
7 | "": ["*.cpp", "*.cu"],
8 | }
9 |
10 | setuptools.setup(
11 | name="colbert-ai",
12 | version="0.2.20",
13 | author="Omar Khattab",
14 | author_email="okhattab@stanford.edu",
15 | description="Efficient and Effective Passage Search via Contextualized Late Interaction over BERT",
16 | long_description=long_description,
17 | long_description_content_type="text/markdown",
18 | url="https://github.com/stanford-futuredata/ColBERT",
19 | packages=setuptools.find_packages(),
20 | python_requires=">=3.8",
21 | install_requires=[
22 | "bitarray",
23 | "datasets",
24 | "flask",
25 | "git-python",
26 | "python-dotenv",
27 | "ninja",
28 | "scipy",
29 | "tqdm",
30 | "transformers",
31 | "ujson",
32 | ],
33 | extras_require={
34 | "faiss-gpu": ["faiss-gpu>=1.7.0"],
35 | "faiss-cpu": ["faiss-cpu>=1.7.0"],
36 | "torch": ["torch==1.13.1"],
37 | },
38 | include_package_data=True,
39 | package_data=package_data,
40 | )
41 |
--------------------------------------------------------------------------------
/third_party/utility/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/__init__.py
--------------------------------------------------------------------------------
/third_party/utility/evaluate/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/evaluate/__init__.py
--------------------------------------------------------------------------------
/third_party/utility/evaluate/annotate_EM.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import git
4 | import tqdm
5 | import ujson
6 | import random
7 |
8 | from argparse import ArgumentParser
9 | from multiprocessing import Pool
10 |
11 | from third_party.colbert.utils.utils import print_message, load_ranking, groupby_first_item
12 | from third_party.utility.utils.qa_loaders import load_qas_, load_collection_
13 | from third_party.utility.utils.save_metadata import format_metadata, get_metadata
14 | from third_party.utility.evaluate.annotate_EM_helpers import *
15 |
16 |
17 | # TODO: Tokenize passages in advance, especially if the ranked list is long! This requires changes to the has_answer input, slightly.
18 |
19 | def main(args):
20 | qas = load_qas_(args.qas)
21 | collection = load_collection_(args.collection, retain_titles=True)
22 | rankings = load_ranking(args.ranking)
23 | parallel_pool = Pool(30)
24 |
25 | print_message('#> Tokenize the answers in the Q&As in parallel...')
26 | qas = list(parallel_pool.map(tokenize_all_answers, qas))
27 |
28 | qid2answers = {qid: tok_answers for qid, _, tok_answers in qas}
29 | assert len(qas) == len(qid2answers), (len(qas), len(qid2answers))
30 |
31 | print_message('#> Lookup passages from PIDs...')
32 | expanded_rankings = [(qid, pid, rank, collection[pid], qid2answers[qid])
33 | for qid, pid, rank, *_ in rankings]
34 |
35 | print_message('#> Assign labels in parallel...')
36 | labeled_rankings = list(parallel_pool.map(assign_label_to_passage, enumerate(expanded_rankings)))
37 |
38 | # Dump output.
39 | print_message("#> Dumping output to", args.output, "...")
40 | qid2rankings = groupby_first_item(labeled_rankings)
41 |
42 | num_judged_queries, num_ranked_queries = check_sizes(qid2answers, qid2rankings)
43 |
44 | # Evaluation metrics and depths.
45 | success, counts = compute_and_write_labels(args.output, qid2answers, qid2rankings)
46 |
47 | # Dump metrics.
48 | with open(args.output_metrics, 'w') as f:
49 | d = {'num_ranked_queries': num_ranked_queries, 'num_judged_queries': num_judged_queries}
50 |
51 | extra = '__WARNING' if num_judged_queries != num_ranked_queries else ''
52 | d[f'success{extra}'] = {k: v / num_judged_queries for k, v in success.items()}
53 | d[f'counts{extra}'] = {k: v / num_judged_queries for k, v in counts.items()}
54 | d['arguments'] = get_metadata(args)
55 |
56 | f.write(format_metadata(d) + '\n')
57 |
58 | print('\n\n')
59 | print(args.output)
60 | print(args.output_metrics)
61 | print("#> Done\n")
62 |
63 |
64 | if __name__ == "__main__":
65 | random.seed(12345)
66 |
67 | parser = ArgumentParser(description='.')
68 |
69 | # Input / Output Arguments
70 | parser.add_argument('--qas', dest='qas', required=True, type=str)
71 | parser.add_argument('--collection', dest='collection', required=True, type=str)
72 | parser.add_argument('--ranking', dest='ranking', required=True, type=str)
73 |
74 | args = parser.parse_args()
75 |
76 | args.output = f'{args.ranking}.annotated'
77 | args.output_metrics = f'{args.ranking}.annotated.metrics'
78 |
79 | assert not os.path.exists(args.output), args.output
80 |
81 | main(args)
82 |
--------------------------------------------------------------------------------
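annotate_EM.py is a command-line script (--qas, --collection, --ranking); the output paths are derived from the ranking path as in its __main__ block. A programmatic sketch with placeholder file paths:

    from argparse import Namespace
    from third_party.utility.evaluate.annotate_EM import main

    # Placeholder inputs: a JSONL file of {qid, question, answers}, a TSV collection
    # (id \t passage \t title), and a TSV ranking (qid \t pid \t rank ...).
    args = Namespace(qas='nq.dev.jsonl', collection='collection.tsv', ranking='nq.dev.ranking.tsv')
    args.output = f'{args.ranking}.annotated'
    args.output_metrics = f'{args.ranking}.annotated.metrics'
    main(args)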
/third_party/utility/evaluate/annotate_EM_helpers.py:
--------------------------------------------------------------------------------
1 | from third_party.colbert.utils.utils import print_message
2 | from third_party.utility.utils.dpr import DPR_normalize, has_answer
3 |
4 |
5 | def tokenize_all_answers(args):
6 | qid, question, answers = args
7 | return qid, question, [DPR_normalize(ans) for ans in answers]
8 |
9 |
10 | def assign_label_to_passage(args):
11 | idx, (qid, pid, rank, passage, tokenized_answers) = args
12 |
13 | if idx % (1*1000*1000) == 0:
14 | print(idx)
15 |
16 | return qid, pid, rank, has_answer(tokenized_answers, passage)
17 |
18 |
19 | def check_sizes(qid2answers, qid2rankings):
20 | num_judged_queries = len(qid2answers)
21 | num_ranked_queries = len(qid2rankings)
22 |
23 | print_message('num_judged_queries =', num_judged_queries)
24 | print_message('num_ranked_queries =', num_ranked_queries)
25 |
26 | if num_judged_queries != num_ranked_queries:
27 | assert num_ranked_queries <= num_judged_queries
28 |
29 | print('\n\n')
30 | print_message('[WARNING] num_judged_queries != num_ranked_queries')
31 | print('\n\n')
32 |
33 | return num_judged_queries, num_ranked_queries
34 |
35 |
36 | def compute_and_write_labels(output_path, qid2answers, qid2rankings):
37 | cutoffs = [1, 5, 10, 20, 30, 50, 100, 1000, 'all']
38 | success = {cutoff: 0.0 for cutoff in cutoffs}
39 | counts = {cutoff: 0.0 for cutoff in cutoffs}
40 |
41 | with open(output_path, 'w') as f:
42 | for qid in qid2answers:
43 | if qid not in qid2rankings:
44 | continue
45 |
46 | prev_rank = 0 # ranks should start at one (i.e., and not zero)
47 | labels = []
48 |
49 | for pid, rank, label in qid2rankings[qid]:
50 | assert rank == prev_rank+1, (qid, pid, (prev_rank, rank))
51 | prev_rank = rank
52 |
53 | labels.append(label)
54 | line = '\t'.join(map(str, [qid, pid, rank, int(label)])) + '\n'
55 | f.write(line)
56 |
57 | for cutoff in cutoffs:
58 | if cutoff != 'all':
59 | success[cutoff] += sum(labels[:cutoff]) > 0
60 | counts[cutoff] += sum(labels[:cutoff])
61 | else:
62 | success[cutoff] += sum(labels) > 0
63 | counts[cutoff] += sum(labels)
64 |
65 | return success, counts
66 |
67 |
68 | # def dump_metrics(f, nqueries, cutoffs, success, counts):
69 | # for cutoff in cutoffs:
70 | # success_log = "#> P@{} = {}".format(cutoff, success[cutoff] / nqueries)
71 | # counts_log = "#> D@{} = {}".format(cutoff, counts[cutoff] / nqueries)
72 | # print('\n'.join([success_log, counts_log]) + '\n')
73 |
74 | # f.write('\n'.join([success_log, counts_log]) + '\n\n')
75 |
--------------------------------------------------------------------------------
/third_party/utility/evaluate/evaluate_lotte_rankings.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 | import jsonlines
4 | import os
5 | import sys
6 |
7 |
8 | def evaluate_dataset(query_type, dataset, split, k, data_rootdir, rankings_rootdir):
9 | data_path = os.path.join(data_rootdir, dataset, split)
10 | rankings_path = os.path.join(
11 | rankings_rootdir, split, f"{dataset}.{query_type}.ranking.tsv"
12 | )
13 | if not os.path.exists(rankings_path):
14 | print(f"[query_type={query_type}, dataset={dataset}] Success@{k}: ???")
15 | return
16 | rankings = defaultdict(list)
17 | with open(rankings_path, "r") as f:
18 | for line in f:
19 | items = line.strip().split("\t")
20 | qid, pid, rank = items[:3]
21 | qid = int(qid)
22 | pid = int(pid)
23 | rank = int(rank)
24 | rankings[qid].append(pid)
25 | assert rank == len(rankings[qid])
26 |
27 | success = 0
28 | qas_path = os.path.join(data_path, f"qas.{query_type}.jsonl")
29 |
30 | num_total_qids = 0
31 | with jsonlines.open(qas_path, mode="r") as f:
32 | for line in f:
33 | qid = int(line["qid"])
34 | num_total_qids += 1
35 | if qid not in rankings:
36 | print(f"WARNING: qid {qid} not found in {rankings_path}!", file=sys.stderr)
37 | continue
38 | answer_pids = set(line["answer_pids"])
39 | if len(set(rankings[qid][:k]).intersection(answer_pids)) > 0:
40 | success += 1
41 | print(
42 | f"[query_type={query_type}, dataset={dataset}] "
43 | f"Success@{k}: {success / num_total_qids * 100:.1f}"
44 | )
45 |
46 |
47 | def main(args):
48 | for query_type in ["search", "forum"]:
49 | for dataset in [
50 | "writing",
51 | "recreation",
52 | "science",
53 | "technology",
54 | "lifestyle",
55 | "pooled",
56 | ]:
57 | evaluate_dataset(
58 | query_type,
59 | dataset,
60 | args.split,
61 | args.k,
62 | args.data_dir,
63 | args.rankings_dir,
64 | )
65 | print()
66 |
67 |
68 | if __name__ == "__main__":
69 | parser = argparse.ArgumentParser(description="LoTTE evaluation script")
70 | parser.add_argument("--k", type=int, default=5, help="Success@k")
71 | parser.add_argument(
72 | "-s", "--split", choices=["dev", "test"], required=True, help="Split"
73 | )
74 | parser.add_argument(
75 | "-d", "--data_dir", type=str, required=True, help="Path to LoTTE data directory"
76 | )
77 | parser.add_argument(
78 | "-r",
79 | "--rankings_dir",
80 | type=str,
81 | required=True,
82 | help="Path to LoTTE rankings directory",
83 | )
84 | args = parser.parse_args()
85 | main(args)
86 |
--------------------------------------------------------------------------------
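The LoTTE evaluator above expects per-split ranking files named <dataset>.<query_type>.ranking.tsv and the qas.*.jsonl files shipped with LoTTE. A sketch with placeholder directories, equivalent to running the module with --k 5 -s dev -d ... -r ...:

    from argparse import Namespace
    from third_party.utility.evaluate.evaluate_lotte_rankings import main

    # Placeholder directories for the LoTTE data and the produced rankings.
    main(Namespace(k=5, split='dev', data_dir='/path/to/lotte', rankings_dir='/path/to/rankings'))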
/third_party/utility/evaluate/msmarco_passages.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluate MS MARCO Passages ranking.
3 | """
4 |
5 | import os
6 | import math
7 | import tqdm
8 | import ujson
9 | import random
10 |
11 | from argparse import ArgumentParser
12 | from collections import defaultdict
13 | from third_party.colbert.utils.utils import print_message, file_tqdm
14 |
15 |
16 | def main(args):
17 | qid2positives = defaultdict(list)
18 | qid2ranking = defaultdict(list)
19 | qid2mrr = {}
20 | qid2recall = {depth: {} for depth in [50, 200, 1000, 5000, 10000]}
21 |
22 | with open(args.qrels) as f:
23 | print_message(f"#> Loading QRELs from {args.qrels} ..")
24 | for line in file_tqdm(f):
25 | qid, _, pid, label = map(int, line.strip().split())
26 | assert label == 1
27 |
28 | qid2positives[qid].append(pid)
29 |
30 | with open(args.ranking) as f:
31 | print_message(f"#> Loading ranked lists from {args.ranking} ..")
32 | for line in file_tqdm(f):
33 | qid, pid, rank, *score = line.strip().split('\t')
34 | qid, pid, rank = int(qid), int(pid), int(rank)
35 |
36 | if len(score) > 0:
37 | assert len(score) == 1
38 | score = float(score[0])
39 | else:
40 | score = None
41 |
42 | qid2ranking[qid].append((rank, pid, score))
43 |
44 | assert set.issubset(set(qid2ranking.keys()), set(qid2positives.keys()))
45 |
46 | num_judged_queries = len(qid2positives)
47 | num_ranked_queries = len(qid2ranking)
48 |
49 | if num_judged_queries != num_ranked_queries:
50 | print()
51 | print_message("#> [WARNING] num_judged_queries != num_ranked_queries")
52 | print_message(f"#> {num_judged_queries} != {num_ranked_queries}")
53 | print()
54 |
55 | print_message(f"#> Computing MRR@10 for {num_judged_queries} queries.")
56 |
57 | for qid in tqdm.tqdm(qid2positives):
58 | ranking = qid2ranking[qid]
59 | positives = qid2positives[qid]
60 |
61 | for rank, (_, pid, _) in enumerate(ranking):
62 | rank = rank + 1 # 1-indexed
63 |
64 | if pid in positives:
65 | if rank <= 10:
66 | qid2mrr[qid] = 1.0 / rank
67 | break
68 |
69 | for rank, (_, pid, _) in enumerate(ranking):
70 | rank = rank + 1 # 1-indexed
71 |
72 | if pid in positives:
73 | for depth in qid2recall:
74 | if rank <= depth:
75 | qid2recall[depth][qid] = qid2recall[depth].get(qid, 0) + 1.0 / len(positives)
76 |
77 | assert len(qid2mrr) <= num_ranked_queries, (len(qid2mrr), num_ranked_queries)
78 |
79 | print()
80 | mrr_10_sum = sum(qid2mrr.values())
81 | print_message(f"#> MRR@10 = {mrr_10_sum / num_judged_queries}")
82 | print_message(f"#> MRR@10 (only for ranked queries) = {mrr_10_sum / num_ranked_queries}")
83 | print()
84 |
85 | for depth in qid2recall:
86 | assert len(qid2recall[depth]) <= num_ranked_queries, (len(qid2recall[depth]), num_ranked_queries)
87 |
88 | print()
89 | metric_sum = sum(qid2recall[depth].values())
90 | print_message(f"#> Recall@{depth} = {metric_sum / num_judged_queries}")
91 | print_message(f"#> Recall@{depth} (only for ranked queries) = {metric_sum / num_ranked_queries}")
92 | print()
93 |
94 | if args.annotate:
95 | print_message(f"#> Writing annotations to {args.output} ..")
96 |
97 | with open(args.output, 'w') as f:
98 | for qid in tqdm.tqdm(qid2positives):
99 | ranking = qid2ranking[qid]
100 | positives = qid2positives[qid]
101 |
102 | for rank, (_, pid, score) in enumerate(ranking):
103 | rank = rank + 1 # 1-indexed
104 | label = int(pid in positives)
105 |
106 | line = [qid, pid, rank, score, label]
107 | line = [x for x in line if x is not None]
108 | line = '\t'.join(map(str, line)) + '\n'
109 | f.write(line)
110 |
111 |
112 | if __name__ == "__main__":
113 | parser = ArgumentParser(description="msmarco_passages.")
114 |
115 | # Input Arguments.
116 | parser.add_argument('--qrels', dest='qrels', required=True, type=str)
117 | parser.add_argument('--ranking', dest='ranking', required=True, type=str)
118 | parser.add_argument('--annotate', dest='annotate', default=False, action='store_true')
119 |
120 | args = parser.parse_args()
121 |
122 | if args.annotate:
123 | args.output = f'{args.ranking}.annotated'
124 | assert not os.path.exists(args.output), args.output
125 |
126 | main(args)
127 |
--------------------------------------------------------------------------------
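msmarco_passages.py computes MRR@10 and Recall@{50, 200, 1000, 5000, 10000} from a qrels file and a TSV ranking. A programmatic sketch with placeholder paths (annotate=False, so no annotated copy is written):

    from argparse import Namespace
    from third_party.utility.evaluate.msmarco_passages import main

    # Placeholder paths; the qrels file must cover every qid present in the ranking.
    main(Namespace(qrels='qrels.dev.small.tsv', ranking='msmarco.dev.ranking.tsv', annotate=False))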
/third_party/utility/preprocess/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/preprocess/__init__.py
--------------------------------------------------------------------------------
/third_party/utility/preprocess/queries_split.py:
--------------------------------------------------------------------------------
1 | """
2 | Divide a query set into two.
3 | """
4 |
5 | import os
6 | import math
7 | import ujson
8 | import random
9 |
10 | from argparse import ArgumentParser
11 | from collections import OrderedDict
12 | from third_party.colbert.utils.utils import print_message
13 |
14 |
15 | def main(args):
16 | random.seed(12345)
17 |
18 | """
19 | Load the queries
20 | """
21 | Queries = OrderedDict()
22 |
23 | print_message(f"#> Loading queries from {args.input}..")
24 | with open(args.input) as f:
25 | for line in f:
26 | qid, query = line.strip().split('\t')
27 |
28 | assert qid not in Queries
29 | Queries[qid] = query
30 |
31 | """
32 | Apply the splitting
33 | """
34 | size_a = len(Queries) - args.holdout
35 | size_b = args.holdout
36 | size_a, size_b = max(size_a, size_b), min(size_a, size_b)
37 |
38 | assert size_a > 0 and size_b > 0, (len(Queries), size_a, size_b)
39 |
40 | print_message(f"#> Deterministically splitting the queries into ({size_a}, {size_b})-sized splits.")
41 |
42 | keys = list(Queries.keys())
43 | sample_b_indices = sorted(list(random.sample(range(len(keys)), size_b)))
44 | sample_a_indices = sorted(list(set.difference(set(list(range(len(keys)))), set(sample_b_indices))))
45 |
46 | assert len(sample_a_indices) == size_a
47 | assert len(sample_b_indices) == size_b
48 |
49 | sample_a = [keys[idx] for idx in sample_a_indices]
50 | sample_b = [keys[idx] for idx in sample_b_indices]
51 |
52 | """
53 | Write the output
54 | """
55 |
56 | output_path_a = f'{args.input}.a'
57 | output_path_b = f'{args.input}.b'
58 |
59 | assert not os.path.exists(output_path_a), output_path_a
60 | assert not os.path.exists(output_path_b), output_path_b
61 |
62 | print_message(f"#> Writing the splits out to {output_path_a} and {output_path_b} ...")
63 |
64 | for output_path, sample in [(output_path_a, sample_a), (output_path_b, sample_b)]:
65 | with open(output_path, 'w') as f:
66 | for qid in sample:
67 | query = Queries[qid]
68 | line = '\t'.join([qid, query]) + '\n'
69 | f.write(line)
70 |
71 |
72 | if __name__ == "__main__":
73 | parser = ArgumentParser(description="queries_split.")
74 |
75 | # Input Arguments.
76 | parser.add_argument('--input', dest='input', required=True)
77 | parser.add_argument('--holdout', dest='holdout', required=True, type=int)
78 |
79 | args = parser.parse_args()
80 |
81 | main(args)
82 |
--------------------------------------------------------------------------------
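queries_split.py deterministically splits a qid \t query TSV into two files, <input>.a and <input>.b, holding out --holdout queries. A sketch with a placeholder input path:

    from argparse import Namespace
    from third_party.utility.preprocess.queries_split import main

    # Placeholder input; writes queries.train.tsv.a and queries.train.tsv.b next to it.
    main(Namespace(input='queries.train.tsv', holdout=500))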
/third_party/utility/rankings/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/rankings/__init__.py
--------------------------------------------------------------------------------
/third_party/utility/rankings/dev_subsample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ujson
3 | import random
4 |
5 | from argparse import ArgumentParser
6 |
7 | from third_party.colbert.utils.utils import print_message, create_directory, load_ranking, groupby_first_item
8 | from third_party.utility.utils.qa_loaders import load_qas_
9 |
10 |
11 | def main(args):
12 | print_message("#> Loading all..")
13 | qas = load_qas_(args.qas)
14 | rankings = load_ranking(args.ranking)
15 | qid2rankings = groupby_first_item(rankings)
16 |
17 | print_message("#> Subsampling all..")
18 | qas_sample = random.sample(qas, args.sample)
19 |
20 | with open(args.output, 'w') as f:
21 | for qid, *_ in qas_sample:
22 | for items in qid2rankings[qid]:
23 | items = [qid] + items
24 | line = '\t'.join(map(str, items)) + '\n'
25 | f.write(line)
26 |
27 | print('\n\n')
28 | print(args.output)
29 | print("#> Done.")
30 |
31 |
32 | if __name__ == "__main__":
33 | random.seed(12345)
34 |
35 | parser = ArgumentParser(description='Subsample the dev set.')
36 | parser.add_argument('--qas', dest='qas', required=True, type=str)
37 | parser.add_argument('--ranking', dest='ranking', required=True)
38 | parser.add_argument('--output', dest='output', required=True)
39 |
40 | parser.add_argument('--sample', dest='sample', default=1500, type=int)
41 |
42 | args = parser.parse_args()
43 |
44 | assert not os.path.exists(args.output), args.output
45 | create_directory(os.path.dirname(args.output))
46 |
47 | main(args)
48 |
--------------------------------------------------------------------------------
/third_party/utility/rankings/merge.py:
--------------------------------------------------------------------------------
1 | """
2 | Merge two or more ranking files by score.
3 | """
4 |
5 | import os
6 | import tqdm
7 |
8 | from argparse import ArgumentParser
9 | from collections import defaultdict
10 | from third_party.colbert.utils.utils import print_message, file_tqdm
11 |
12 |
13 | def main(args):
14 | Rankings = defaultdict(list)
15 |
16 | for path in args.input:
17 | print_message(f"#> Loading the rankings in {path} ..")
18 |
19 | with open(path) as f:
20 | for line in file_tqdm(f):
21 | qid, pid, rank, score = line.strip().split('\t')
22 | qid, pid, rank = map(int, [qid, pid, rank])
23 | score = float(score)
24 |
25 | Rankings[qid].append((score, rank, pid))
26 |
27 | with open(args.output, 'w') as f:
28 | print_message(f"#> Writing the output rankings to {args.output} ..")
29 |
30 | for qid in tqdm.tqdm(Rankings):
31 | ranking = sorted(Rankings[qid], reverse=True)
32 |
33 | for rank, (score, original_rank, pid) in enumerate(ranking):
34 | rank = rank + 1 # 1-indexed
35 |
36 | if (args.depth > 0) and (rank > args.depth):
37 | break
38 |
39 | line = [qid, pid, rank, score]
40 | line = '\t'.join(map(str, line)) + '\n'
41 | f.write(line)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = ArgumentParser(description="merge_rankings.")
46 |
47 | # Input Arguments.
48 | parser.add_argument('--input', dest='input', required=True, nargs='+')
49 | parser.add_argument('--output', dest='output', required=True, type=str)
50 |
51 | parser.add_argument('--depth', dest='depth', required=True, type=int)
52 |
53 | args = parser.parse_args()
54 |
55 | assert not os.path.exists(args.output), args.output
56 |
57 | main(args)
58 |
--------------------------------------------------------------------------------
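merge.py pools several per-query ranked lists and re-ranks them by score, keeping the top --depth entries per query. A sketch with placeholder paths; each input line must be qid \t pid \t rank \t score:

    from argparse import Namespace
    from third_party.utility.rankings.merge import main

    main(Namespace(input=['run_a.ranking.tsv', 'run_b.ranking.tsv'],
                   output='merged.ranking.tsv',
                   depth=1000))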
/third_party/utility/rankings/split_by_offset.py:
--------------------------------------------------------------------------------
1 | """
2 | Split the ranked lists after retrieval with a merged query set.
3 | """
4 |
5 | import os
6 | import random
7 |
8 | from argparse import ArgumentParser
9 |
10 |
11 | def main(args):
12 | output_paths = ['{}.{}'.format(args.ranking, split) for split in args.names]
13 | assert all(not os.path.exists(path) for path in output_paths), output_paths
14 |
15 | output_files = [open(path, 'w') for path in output_paths]
16 |
17 | with open(args.ranking) as f:
18 | for line in f:
19 | qid, pid, rank, *other = line.strip().split('\t')
20 | qid = int(qid)
21 | split_output_path = output_files[qid // args.gap - 1]
22 | qid = qid % args.gap
23 |
24 | split_output_path.write('\t'.join([str(x) for x in [qid, pid, rank, *other]]) + '\n')
25 |
26 | print(f.name)
27 |
28 | _ = [f.close() for f in output_files]
29 |
30 | print("#> Done!")
31 |
32 |
33 | if __name__ == "__main__":
34 | random.seed(12345)
35 |
36 | parser = ArgumentParser(description='Split a merged ranking file by qid offset.')
37 | parser.add_argument('--ranking', dest='ranking', required=True)
38 |
39 | parser.add_argument('--names', dest='names', required=False, default=['train', 'dev', 'test'], type=str, nargs='+') # order matters!
40 | parser.add_argument('--gap', dest='gap', required=False, default=1_000_000_000, type=int) # larger than any individual query set
41 |
42 | args = parser.parse_args()
43 |
44 | main(args)
45 |
--------------------------------------------------------------------------------
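split_by_offset.py undoes the qid offsetting used when several query sets are retrieved together: the i-th set (0-indexed) is assumed to have had (i + 1) * gap added to its qids, so qid // gap - 1 recovers the set index and qid % gap the original qid. A sketch with a placeholder ranking path:

    from argparse import Namespace
    from third_party.utility.rankings.split_by_offset import main

    # Writes merged.ranking.tsv.train / .dev / .test (the order of names matters).
    main(Namespace(ranking='merged.ranking.tsv',
                   names=['train', 'dev', 'test'],
                   gap=1_000_000_000))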
/third_party/utility/rankings/split_by_queries.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import tqdm
4 | import ujson
5 | import random
6 |
7 | from argparse import ArgumentParser
8 | from collections import OrderedDict
9 | from third_party.colbert.utils.utils import print_message, file_tqdm
10 |
11 |
12 | def main(args):
13 | qid_to_file_idx = {}
14 |
15 | for qrels_idx, qrels in enumerate(args.all_queries):
16 | with open(qrels) as f:
17 | for line in f:
18 | qid, *_ = line.strip().split('\t')
19 | qid = int(qid)
20 |
21 | assert qid_to_file_idx.get(qid, qrels_idx) == qrels_idx, (qid, qrels_idx)
22 | qid_to_file_idx[qid] = qrels_idx
23 |
24 | all_outputs_paths = [f'{args.ranking}.{idx}' for idx in range(len(args.all_queries))]
25 |
26 | assert all(not os.path.exists(path) for path in all_outputs_paths)
27 |
28 | all_outputs = [open(path, 'w') for path in all_outputs_paths]
29 |
30 | with open(args.ranking) as f:
31 | print_message(f"#> Loading ranked lists from {f.name} ..")
32 |
33 | last_file_idx = -1
34 |
35 | for line in file_tqdm(f):
36 | qid, *_ = line.strip().split('\t')
37 |
38 | file_idx = qid_to_file_idx[int(qid)]
39 |
40 | if file_idx != last_file_idx:
41 | print_message(f"#> Switched to file #{file_idx} at {all_outputs[file_idx].name}")
42 |
43 | last_file_idx = file_idx
44 |
45 | all_outputs[file_idx].write(line)
46 |
47 | print()
48 |
49 | for f in all_outputs:
50 | print(f.name)
51 | f.close()
52 |
53 | print("#> Done!")
54 |
55 |
56 | if __name__ == "__main__":
57 | random.seed(12345)
58 |
59 | parser = ArgumentParser(description='.')
60 |
61 | # Input Arguments
62 | parser.add_argument('--ranking', dest='ranking', required=True, type=str)
63 | parser.add_argument('--all-queries', dest='all_queries', required=True, type=str, nargs='+')
64 |
65 | args = parser.parse_args()
66 |
67 | main(args)
68 |
--------------------------------------------------------------------------------
/third_party/utility/rankings/tune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ujson
3 | import random
4 |
5 | from argparse import ArgumentParser
6 | from third_party.colbert.utils.utils import print_message, create_directory
7 | from third_party.utility.utils.save_metadata import save_metadata
8 |
9 |
10 | def main(args):
11 | AllMetrics = {}
12 | Scores = {}
13 |
14 | for path in args.paths:
15 | with open(path) as f:
16 | metric = ujson.load(f)
17 | AllMetrics[path] = metric
18 |
19 | for k in args.metric:
20 | metric = metric[k]
21 |
22 | assert type(metric) is float
23 | Scores[path] = metric
24 |
25 | MaxKey = max(Scores, key=Scores.get)
26 |
27 | MaxCKPT = int(MaxKey.split('/')[-2].split('.')[-1])
28 | MaxARGS = os.path.join(os.path.dirname(MaxKey), 'logs', 'args.json')
29 |
30 | with open(MaxARGS) as f:
31 | logs = ujson.load(f)
32 | MaxCHECKPOINT = logs['checkpoint']
33 |
34 | assert MaxCHECKPOINT.endswith(f'colbert-{MaxCKPT}.dnn'), (MaxCHECKPOINT, MaxCKPT)
35 |
36 | with open(args.output, 'w') as f:
37 | f.write(MaxCHECKPOINT)
38 |
39 | args.Scores = Scores
40 | args.AllMetrics = AllMetrics
41 |
42 | save_metadata(f'{args.output}.meta', args)
43 |
44 | print('\n\n', args, '\n\n')
45 | print(args.output)
46 | print_message("#> Done.")
47 |
48 |
49 | if __name__ == "__main__":
50 | random.seed(12345)
51 |
52 | parser = ArgumentParser(description='.')
53 |
54 | # Input / Output Arguments
55 | parser.add_argument('--metric', dest='metric', required=True, type=str) # e.g., success.20
56 | parser.add_argument('--paths', dest='paths', required=True, type=str, nargs='+')
57 | parser.add_argument('--output', dest='output', required=True, type=str)
58 |
59 | args = parser.parse_args()
60 |
61 | args.metric = args.metric.split('.')
62 |
63 | assert not os.path.exists(args.output), args.output
64 | create_directory(os.path.dirname(args.output))
65 |
66 | main(args)
67 |
--------------------------------------------------------------------------------
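tune.py picks the checkpoint whose metrics file maximizes the requested metric (e.g. success.20) and writes the winning checkpoint path to --output. The sketch below is heavily assumption-laden: the metrics files, the <run>.<step> directory naming, and the logs/args.json file containing a 'checkpoint' entry are conventions about how training runs are laid out, not anything defined in this file:

    from argparse import Namespace
    from third_party.utility.rankings.tune import main

    # Placeholder paths; each *.metrics file is JSON, and each run directory is
    # assumed to contain logs/args.json with a 'checkpoint' field.
    args = Namespace(metric='success.20'.split('.'),
                     paths=['exp/train.py/run1.1000/dev.ranking.tsv.annotated.metrics',
                            'exp/train.py/run1.2000/dev.ranking.tsv.annotated.metrics'],
                     output='exp/best_checkpoint.txt')
    main(args)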
/third_party/utility/supervision/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/supervision/__init__.py
--------------------------------------------------------------------------------
/third_party/utility/supervision/self_training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import git
4 | import tqdm
5 | import ujson
6 | import random
7 |
8 | from argparse import ArgumentParser
9 | from third_party.colbert.utils.utils import print_message, load_ranking, groupby_first_item
10 |
11 |
12 | MAX_NUM_TRIPLES = 40_000_000
13 |
14 |
15 | def sample_negatives(negatives, num_sampled, biased=False):
16 | num_sampled = min(len(negatives), num_sampled)
17 |
18 | if biased:
19 | assert num_sampled % 2 == 0
20 | num_sampled_top100 = num_sampled // 2
21 | num_sampled_rest = num_sampled - num_sampled_top100
22 |
23 | return random.sample(negatives[:100], num_sampled_top100) + random.sample(negatives[100:], num_sampled_rest)
24 |
25 | return random.sample(negatives, num_sampled)
26 |
27 |
28 | def sample_for_query(qid, ranking, npositives, depth_positive, depth_negative, cutoff_negative):
29 | """
30 | Requires that the ranks are sorted per qid.
31 | """
32 | assert npositives <= depth_positive < cutoff_negative < depth_negative
33 |
34 | positives, negatives, triples = [], [], []
35 |
36 | for pid, rank, *_ in ranking:
37 | assert rank >= 1, f"ranks should start at 1 \t\t got rank = {rank}"
38 |
39 | if rank > depth_negative:
40 | break
41 |
42 | if rank <= depth_positive:
43 | positives.append(pid)
44 | elif rank > cutoff_negative:
45 | negatives.append(pid)
46 |
47 | num_sampled = 100
48 |
49 | for neg in sample_negatives(negatives, num_sampled):
50 | positives_ = random.sample(positives, npositives)
51 | positives_ = positives_[0] if npositives == 1 else positives_
52 | triples.append((qid, positives_, neg))
53 |
54 | return triples
55 |
56 |
57 | def main(args):
58 | rankings = load_ranking(args.ranking, types=[int, int, int, float, int])
59 |
60 | print_message("#> Group by QID")
61 | qid2rankings = groupby_first_item(tqdm.tqdm(rankings))
62 |
63 | Triples = []
64 | NonEmptyQIDs = 0
65 |
66 | for processing_idx, qid in enumerate(qid2rankings):
67 | l = sample_for_query(qid, qid2rankings[qid], args.positives, args.depth_positive, args.depth_negative, args.cutoff_negative)
68 | NonEmptyQIDs += (len(l) > 0)
69 | Triples.extend(l)
70 |
71 | if processing_idx % (10_000) == 0:
72 | print_message(f"#> Done with {processing_idx+1} questions!\t\t "
73 | f"{str(len(Triples) / 1000)}k triples for {NonEmptyQIDs} unique QIDs.")
74 |
75 | print_message(f"#> Sub-sample the triples (if > {MAX_NUM_TRIPLES})..")
76 | print_message(f"#> len(Triples) = {len(Triples)}")
77 |
78 | if len(Triples) > MAX_NUM_TRIPLES:
79 | Triples = random.sample(Triples, MAX_NUM_TRIPLES)
80 |
81 | ### Prepare the triples ###
82 | print_message("#> Shuffling the triples...")
83 | random.shuffle(Triples)
84 |
85 | print_message("#> Writing {}M examples to file.".format(len(Triples) / 1000.0 / 1000.0))
86 |
87 | with open(args.output, 'w') as f:
88 | for example in Triples:
89 | ujson.dump(example, f)
90 | f.write('\n')
91 |
92 | with open(f'{args.output}.meta', 'w') as f:
93 | args.cmd = ' '.join(sys.argv)
94 | args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha
95 | ujson.dump(args.__dict__, f, indent=4)
96 | f.write('\n')
97 |
98 | print('\n\n', args, '\n\n')
99 | print(args.output)
100 | print_message("#> Done.")
101 |
102 |
103 | if __name__ == "__main__":
104 | random.seed(12345)
105 |
106 | parser = ArgumentParser(description='Create training triples from ranked list.')
107 |
108 | # Input / Output Arguments
109 | parser.add_argument('--ranking', dest='ranking', required=True, type=str)
110 | parser.add_argument('--output', dest='output', required=True, type=str)
111 |
112 | # Weak Supervision Arguments.
113 | parser.add_argument('--positives', dest='positives', required=True, type=int)
114 | parser.add_argument('--depth+', dest='depth_positive', required=True, type=int)
115 |
116 | parser.add_argument('--depth-', dest='depth_negative', required=True, type=int)
117 | parser.add_argument('--cutoff-', dest='cutoff_negative', required=True, type=int)
118 |
119 | args = parser.parse_args()
120 |
121 | assert not os.path.exists(args.output), args.output
122 |
123 | main(args)
124 |
--------------------------------------------------------------------------------
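self_training.py converts an annotated ranked list (five tab-separated columns) into training triples, treating the top-ranked passages as positives and sampling up to 100 negatives per query between --cutoff- and --depth-. A sketch with placeholder paths; note that main() also records the current git commit in the .meta file, so it expects to run inside a git checkout:

    from argparse import Namespace
    from third_party.utility.supervision.self_training import main

    # Ranking lines must be qid \t pid \t rank \t score \t label
    # (types=[int, int, int, float, int], as passed to load_ranking above).
    main(Namespace(ranking='train.ranking.annotated.tsv',
                   output='triples.self_training.jsonl',
                   positives=1, depth_positive=10,
                   cutoff_negative=50, depth_negative=1000))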
/third_party/utility/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aimagelab/ReT/3fee6aabccbf5f2e2577446c8f2de0010219f496/third_party/utility/utils/__init__.py
--------------------------------------------------------------------------------
/third_party/utility/utils/qa_loaders.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ujson
3 |
4 | from collections import defaultdict
5 | from third_party.colbert.utils.utils import print_message, file_tqdm
6 |
7 |
8 | def load_collection_(path, retain_titles):
9 | with open(path) as f:
10 | collection = []
11 |
12 | for line in file_tqdm(f):
13 | _, passage, title = line.strip().split('\t')
14 |
15 | if retain_titles:
16 | passage = title + ' | ' + passage
17 |
18 | collection.append(passage)
19 |
20 | return collection
21 |
22 |
23 | def load_qas_(path):
24 | print_message("#> Loading the reference QAs from", path)
25 |
26 | triples = []
27 |
28 | with open(path) as f:
29 | for line in f:
30 | qa = ujson.loads(line)
31 | triples.append((qa['qid'], qa['question'], qa['answers']))
32 |
33 | return triples
34 |
--------------------------------------------------------------------------------
/third_party/utility/utils/save_metadata.py:
--------------------------------------------------------------------------------
1 | from third_party.colbert.utils.utils import dotdict
2 | import os
3 | import sys
4 | # import git
5 | import time
6 | import copy
7 | import ujson
8 | import socket
9 |
10 |
11 | def get_metadata_only():
12 | args = dotdict()
13 |
14 | args.hostname = socket.gethostname()
15 | # try:
16 | # args.git_branch = git.Repo(search_parent_directories=True).active_branch.name
17 | # args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha
18 | # args.git_commit_datetime = str(git.Repo(search_parent_directories=True).head.object.committed_datetime)
19 | # except git.exc.InvalidGitRepositoryError as e:
20 | # pass
21 | args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)')
22 | args.cmd = ' '.join(sys.argv)
23 |
24 | return args
25 |
26 |
27 | def get_metadata(args):
28 | args = copy.deepcopy(args)
29 |
30 | args.hostname = socket.gethostname()
31 | # args.git_branch = git.Repo(search_parent_directories=True).active_branch.name
32 | # args.git_hash = git.Repo(search_parent_directories=True).head.object.hexsha
33 | # args.git_commit_datetime = str(git.Repo(search_parent_directories=True).head.object.committed_datetime)
34 | args.current_datetime = time.strftime('%b %d, %Y ; %l:%M%p %Z (%z)')
35 | args.cmd = ' '.join(sys.argv)
36 |
37 | try:
38 | args.input_arguments = copy.deepcopy(args.input_arguments.__dict__)
39 | except:
40 | args.input_arguments = None
41 |
42 | return dict(args.__dict__)
43 |
44 | # TODO: No reason for deepcopy. But: (a) Call provenance() on objects that can, (b) Only save simple, small objects. No massive lists or models or weird stuff!
45 | # With that, I think we don't even need (necessarily) to restrict things to input_arguments.
46 |
47 | def format_metadata(metadata):
48 | assert type(metadata) == dict
49 |
50 | return ujson.dumps(metadata, indent=4)
51 |
52 |
53 | def save_metadata(path, args):
54 | assert not os.path.exists(path), path
55 |
56 | with open(path, 'w') as output_metadata:
57 | data = get_metadata(args)
58 | output_metadata.write(format_metadata(data) + '\n')
59 |
60 | return data
61 |
--------------------------------------------------------------------------------
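get_metadata_only() above returns a dotdict of lightweight provenance fields (hostname, timestamp, command line); the git fields are commented out in this copy. A small sketch; dict() is needed because format_metadata() asserts a plain dict:

    from third_party.utility.utils.save_metadata import get_metadata_only, format_metadata

    meta = get_metadata_only()
    print(format_metadata(dict(meta)))   # {"hostname": ..., "current_datetime": ..., "cmd": ...}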