├── stage1 ├── grad_cache │ ├── pytorch_lightning │ │ ├── requirements.txt │ │ ├── readme.md │ │ ├── pl_example.py │ │ └── pl_gradcache.py │ ├── __init__.py │ ├── cachex │ │ ├── __init__.py │ │ ├── tree_utils.py │ │ ├── training.py │ │ └── functional.py │ ├── context_managers.py │ ├── loss.py │ ├── functional.py │ └── grad_cache.py ├── dataset_process.py ├── train.py ├── dataset.py ├── grad_cache_custom.py ├── loss.py ├── eval_metric.py ├── custom_trainer.py └── trainer_headers.py ├── requirements.txt ├── cp-retrieval-server ├── templates │ ├── problem.html │ ├── base.html │ ├── stats.html │ └── index.html ├── download.py ├── example.py ├── compute_embs.py └── app.py ├── .gitignore ├── stage2 ├── dataset_process.py ├── trainer_custom.py └── train.py ├── TestCases.md └── README.md /stage1/grad_cache/pytorch_lightning/requirements.txt: -------------------------------------------------------------------------------- 1 | lightning 2 | pytorch_metric_learning 3 | -------------------------------------------------------------------------------- /stage1/grad_cache/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .grad_cache import GradCache 3 | except ModuleNotFoundError: 4 | pass 5 | -------------------------------------------------------------------------------- /stage1/grad_cache/cachex/__init__.py: -------------------------------------------------------------------------------- 1 | from .functional import chunk_encode, cache_grad, unchunk_args, grad_cached 2 | from .tree_utils import tree_chunk, tree_unchunk 3 | 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sentence-transformers==3.4.1 2 | transformers==4.52.4 3 | torch==2.6.0 4 | torchaudio==2.6.0 5 | torchvision==0.21.0 6 | pandas 7 | tensorboard 8 | datasets 9 | accelerate 10 | einops 11 | Flask 
def tree_chunk(tree: Any, n_chunk: int, axis: int = 0) -> Any:
    """Split every leaf of ``tree`` into ``n_chunk`` chunks along ``axis``.

    Each leaf is reshaped so that dimension ``axis`` is replaced by the two
    dimensions ``(n_chunk, -1)``; the size of that dimension must therefore be
    divisible by ``n_chunk``.

    Args:
        tree: A pytree whose leaves expose ``.shape`` and ``.reshape``.
        n_chunk: Number of chunks to create along ``axis``.
        axis: Non-negative index of the dimension to split.

    Returns:
        A pytree with the same structure and one extra dimension per leaf.
    """
    def _chunk(leaf):
        return leaf.reshape(leaf.shape[:axis] + (n_chunk, -1) + leaf.shape[axis + 1:])

    # ``jax.tree_map`` was a deprecated alias that newer JAX releases removed;
    # ``jax.tree_util.tree_map`` is the stable spelling.
    return jax.tree_util.tree_map(_chunk, tree)


def tree_unchunk(tree: Any, axis: int = 0) -> Any:
    """Merge the chunk dimension at ``axis`` with the dimension following it.

    This is the inverse of :func:`tree_chunk` for the same ``axis``.

    Args:
        tree: A pytree whose leaves expose ``.shape`` and ``.reshape``.
        axis: Non-negative index of the chunk dimension to fold away.

    Returns:
        A pytree with the same structure and one dimension fewer per leaf.
    """
    def _unchunk(leaf):
        return leaf.reshape(leaf.shape[:axis] + (-1,) + leaf.shape[axis + 2:])

    return jax.tree_util.tree_map(_unchunk, tree)

{{ title }}

8 |

{{ source }}

9 | 10 |
11 | 12 |
13 | {{ text_html }} 14 |
15 |
def cache_train_step(loss_fn, state, ss, tt, axis='device'):
    """Run one gradient-cached contrastive training step.

    The two input groups are encoded chunk by chunk, the mean loss is
    differentiated with respect to the resulting representations, and those
    cached representation gradients are then pushed back through the encoder
    chunk by chunk to accumulate parameter gradients.

    Args:
        loss_fn: Callable ``loss_fn(x, y, axis=axis)``; its output is averaged
            with ``jnp.mean``.
        state: Training state exposing ``apply_fn``, ``params`` and
            ``apply_gradients`` (presumably a flax ``TrainState`` -- confirm
            against the caller).
        ss: Keyword inputs for the first ("source") encoder pass.
        tt: Keyword inputs for the second ("target") encoder pass.
        axis: Mapped axis name used for ``jax.lax.pmean`` and forwarded to
            ``loss_fn``.

    Returns:
        Tuple ``(loss, new_state)`` after applying the averaged gradients.
    """
    def encode_with_params(params, **kwargs):
        return state.apply_fn(params=params, **kwargs)

    encode_fn = chunk_encode(partial(encode_with_params, state.params))
    grad_fn = cache_grad(encode_with_params)

    s_reps = encode_fn(**ss)
    t_reps = encode_fn(**tt)

    # The loss sees un-chunked representations; gradients come back chunked.
    @unchunk_args(axis=0, argnums=(0, 1))
    def grad_cache_fn(xx, yy):
        return jnp.mean(loss_fn(xx, yy, axis=axis))
    loss, (s_grads, t_grads) = jax.value_and_grad(grad_cache_fn, argnums=(0, 1))(s_reps, t_reps)

    # ``jax.tree_map`` was removed from newer JAX releases; use the stable
    # ``jax.tree_util.tree_map`` to build the zero-initialised accumulator.
    grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
    grads = grad_fn(state.params, grads, s_grads, **ss)
    grads = grad_fn(state.params, grads, t_grads, **tt)

    loss, grads = jax.lax.pmean([loss, grads], axis)
    new_state = state.apply_gradients(grads=grads)
    return loss, new_state
{% block content %}{% endblock %}
24 | 25 | 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | models/ 3 | results/ 4 | dataset/ 5 | *.npy 6 | *.jsonl 7 | *.json 8 | 9 | # Python 相关 10 | __pycache__/ 11 | *.py[cod] 12 | *.pyo 13 | *.pyd 14 | *.so 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | .pytest_cache/ 33 | 34 | # 虚拟环境 35 | venv/ 36 | env/ 37 | .venv/ 38 | ENV/ 39 | env.bak/ 40 | venv.bak/ 41 | 42 | # 操作系统生成的文件 43 | .DS_Store 44 | Thumbs.db 45 | 46 | # IDE 和编辑器生成的文件 47 | .idea/ 48 | .vscode/ 49 | *.suo 50 | *.ntvs* 51 | *.njsproj 52 | *.sln 53 | *.sw? 54 | 55 | # 依赖目录 56 | node_modules/ 57 | bower_components/ 58 | vendor/ 59 | 60 | # 编译生成的文件 61 | *.log 62 | *.out 63 | *.exe 64 | *.dll 65 | *.dylib 66 | 67 | # 调试文件 68 | *.debug 69 | *.pdb 70 | 71 | # 打包文件 72 | *.zip 73 | *.tar.gz 74 | *.rar 75 | 76 | # 环境文件 77 | .env 78 | .env.local 79 | .env.development 80 | .env.test 81 | .env.production 82 | 83 | # 日志文件 84 | logs/ 85 | *.log 86 | npm-debug.log* 87 | yarn-debug.log* 88 | yarn-error.log* 89 | 90 | # 缓存目录 91 | .cache/ 92 | .temp/ 93 | 94 | # 测试生成的报告 95 | coverage/ 96 | *.lcov 97 | 98 | # 系统文件 99 | *.swp 100 | *.swo -------------------------------------------------------------------------------- /cp-retrieval-server/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | # 修改为镜像源 3 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' 4 | 5 | from huggingface_hub import snapshot_download 6 | 7 | download_dir = './' 8 | 9 | # download probs and embeddings from Hugging Face 10 | repo_id = 'coldchair16/CPRet-Embeddings' 11 | snapshot_download( 12 | repo_id=repo_id, 13 | repo_type='dataset', 14 | 
if __name__ == "__main__":
    # Minimal retrieval demo: embed one query and print the closest problem
    # from the pre-computed embedding matrix.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path",
                        type=str,
                        default='coldchair16/CPRetriever-Prob',
                        help="Path to the SentenceTransformer model")

    args = parser.parse_args()

    model_path = args.model_path
    model = SentenceTransformer(model_path, trust_remote_code=True)
    model.tokenizer.model_max_length = 1024
    model.max_seq_length = 1024

    # Row-normalise the corpus embeddings so a dot product is cosine similarity.
    embs = np.load('./probs_embs.npy')
    embs = embs / np.linalg.norm(embs, axis=1, keepdims=True)
    probs_path = './probs.jsonl'
    probs = []
    # Explicit UTF-8: problem statements contain non-ASCII text and the
    # platform default encoding is not guaranteed to handle it.
    with open(probs_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                probs.append(json.loads(line))

    text = '''
You are given a sequence that supports the following operations:
1. Flatten a specified range.
2. Calculate the sum of the counts of distinct numbers in all subsegments of
length k.'''

    # Ask for a normalised NumPy vector directly. The previous code requested a
    # torch tensor and fed it to np.dot, which fails when the model runs on GPU.
    text_emb = model.encode(text, convert_to_numpy=True, normalize_embeddings=True)

    sim_mat = np.dot(embs, text_emb)
    print(sim_mat.shape)
    rank = np.argsort(sim_mat, axis=0)[::-1]

    p = probs[int(rank[0])]
    print(p['title'], p['url'], p['source'], p['text'])
def process_dataset(dataset, max_length=1024):
    """Drop over-long solutions from the 'train' and 'test' splits.

    Both splits are converted to plain lists with ``to_list()`` and passed
    through ``filted_by_length``; the resulting sizes are printed.

    Args:
        dataset: Mapping with 'train' and 'test' entries, each exposing
            ``to_list()`` (presumably a HuggingFace ``DatasetDict``).
        max_length (int): Maximum token length allowed for each solution.

    Returns:
        Tuple: ``(train_samples, test_samples)`` as filtered lists.
    """
    raw_train, raw_test = (dataset[split].to_list() for split in ('train', 'test'))

    kept_train = filted_by_length(raw_train, max_length)
    kept_test = filted_by_length(raw_test, max_length)

    print(f"Train data size: {len(kept_train)}")
    print(f"Test data size: {len(kept_test)}")

    return kept_train, kept_test
class SimpleContrastiveLoss:
    """Cross-entropy contrastive loss over dot-product similarity scores.

    Every query owns ``n_hard_negatives + 1`` consecutive rows of ``y``; when
    no explicit target is supplied, the first row of each group is treated as
    that query's positive.
    """

    def __init__(self, n_hard_negatives: int = 0):
        # Number of candidate rows in ``y`` that belong to each query.
        self.target_per_qry = 1 + n_hard_negatives

    def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'):
        """Return ``cross_entropy(x @ y^T, target)``.

        Args:
            x: Query representations.
            y: Candidate representations.
            target: Optional class indices; built from the group layout when
                omitted.
            reduction: Forwarded to ``F.cross_entropy``.
        """
        if target is None:
            n_candidates = x.size(0) * self.target_per_qry
            assert n_candidates == y.size(0)
            # Positive of query i sits at row i * target_per_qry.
            target = torch.arange(0, n_candidates, self.target_per_qry, device=x.device)

        scores = x @ y.transpose(0, 1)
        return F.cross_entropy(scores, target, reduction=reduction)
def filted_by_length(data, max_length=1024):
    """
    Filter out solutions in each sample whose token length exceeds max_length.

    The input samples are left untouched: every retained sample is a shallow
    copy whose 'solutions' and 'solutions_length' lists are rebuilt. Samples
    with no remaining solution are dropped.

    Args:
        data (list): A list of data samples, each containing 'solutions' and 'solutions_length'.
        max_length (int): Maximum allowed token length for a solution.

    Returns:
        list: A filtered list of samples, each retaining only solutions with acceptable length.
    """
    new_data = []
    for sample in data:
        # Pair each solution with its length so both lists stay aligned.
        kept = [
            (solution, length)
            for solution, length in zip(sample['solutions'], sample['solutions_length'])
            if length <= max_length
        ]
        if not kept:
            continue

        # Copy instead of mutating the caller's dict (the previous version
        # rewrote every input sample in place, even the ones it dropped).
        filtered = dict(sample)
        filtered['solutions'] = [solution for solution, _ in kept]
        filtered['solutions_length'] = [length for _, length in kept]
        new_data.append(filtered)

    return new_data
def process_eval_dataset(queries, corpus, qrels):
    """Convert retrieval evaluation records into id-keyed lookup structures.

    Args:
        queries: Iterable of records with '_id' and 'text' keys.
        corpus: Iterable of records with '_id' and 'text' keys.
        qrels: Iterable of records with 'query-id' and 'corpus-id' keys.

    Returns:
        Tuple ``(queries, corpus, relevant_docs)`` where ``queries`` maps
        qid -> question text, ``corpus`` maps cid -> document text and
        ``relevant_docs`` maps qid -> set of relevant cids. All ids are
        converted to ``str``.
    """
    # Single pass per input, so one-shot iterables (generators, streaming
    # datasets) work too; the former double iteration yielded empty dicts.
    queries = {str(row['_id']): row['text'] for row in queries}
    corpus = {str(row['_id']): row['text'] for row in corpus}

    relevant_docs = {}
    for qrel in qrels:
        qid = str(qrel["query-id"])
        cid = str(qrel["corpus-id"])
        relevant_docs.setdefault(qid, set()).add(cid)
    return queries, corpus, relevant_docs
generator.manual_seed(self.args.seed) 23 | 24 | data_collator = self.data_collator 25 | batch_samplers = [] 26 | 27 | # 👇 Key logic: iterate over (name, dataset) pairs, assign batch size based on name 28 | for name, dataset in self.train_dataset.items(): 29 | bs = self.batch_sizes.get(name, self.args.train_batch_size) 30 | sampler = self.get_batch_sampler( 31 | dataset, 32 | batch_size=bs, 33 | drop_last=self.args.dataloader_drop_last, 34 | valid_label_columns=data_collator.valid_label_columns, 35 | generator=generator, 36 | ) 37 | batch_samplers.append(sampler) 38 | 39 | concat = ConcatDataset(self.train_dataset.values()) 40 | batch_sampler = self.get_multi_dataset_batch_sampler( 41 | dataset=concat, 42 | batch_samplers=batch_samplers, 43 | generator=generator, 44 | seed=self.args.seed, 45 | ) 46 | 47 | dl_kwargs = dict( 48 | collate_fn=data_collator, 49 | num_workers=self.args.dataloader_num_workers, 50 | pin_memory=self.args.dataloader_pin_memory, 51 | persistent_workers=self.args.dataloader_persistent_workers, 52 | prefetch_factor=self.args.dataloader_prefetch_factor, 53 | batch_sampler=batch_sampler, 54 | ) 55 | 56 | self.accelerator.even_batches = False 57 | self._train_dataloader = self.accelerator.prepare(DataLoader(concat, **dl_kwargs)) 58 | return self._train_dataloader 59 | 60 | # For single dataset or IterableDataset, fall back to default behavior 61 | return super().get_train_dataloader() 62 | -------------------------------------------------------------------------------- /stage1/grad_cache/cachex/functional.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Any, Callable 2 | from functools import partial 3 | 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from .tree_utils import tree_unchunk 8 | 9 | Array = jax.Array 10 | 11 | 12 | def grad_with_cache(f, **grad_kwargs): 13 | def cache_f(params, cache, *args, **kwargs): 14 | return jnp.sum(f(params, *args, **kwargs) * cache) 15 | 
def cache_grad_scan_fn(f, params, acc, x):
    """Scan body that adds one chunk's parameter gradient to the accumulator.

    Args:
        f: Encoder called as ``f(params=..., **kwargs)``.
        params: Parameter pytree the gradient is taken with respect to.
        acc: Running gradient pytree (the scan carry), same structure as
            ``params``.
        x: Pair ``(cached_grad, kwargs)`` holding this chunk's cached
            representation gradient and its encoder inputs.

    Returns:
        Tuple ``(acc, None)``: the updated carry and no per-step output.
    """
    cached_grad, kwargs = x

    def fwd_fn(w):
        return f(params=w, **kwargs)

    chunk_grad = grad_with_cache(fwd_fn)(params, cached_grad)
    # ``jax.tree_multimap`` was removed from JAX; ``tree_map`` accepts several
    # trees and applies the function leaf-wise across them.
    acc = jax.tree_util.tree_map(lambda u, v: u + v, acc, chunk_grad)
    return acc, None
76 | 77 | Returns: 78 | Decorated gradient cached `f` that expects input to have an extra leading sub-batch dimension, potentially produced by `tree_chunk`. 79 | 80 | A example of usage on a apply function that takes multiple arguments: 81 | 82 | >>> @cachex.grad_cached 83 | ... def fwd(params, batch): 84 | ... return apply(params, **batch) 85 | 86 | >>> src = cachex.tree_chunk(src, 8) 87 | >>> tgt = cachex.tree_chunk(tgt, 8) 88 | 89 | >>> def compute_loss(params, src, tgt): 90 | ... h_src = fwd(params, src) 91 | ... h_tgt = fwd(params, tgt) 92 | ... return loss(h_src, h_tgt) 93 | 94 | >>> grads = jax.grad(compute_loss)(params, src, tgt) 95 | 96 | Here the `compute_loss` function can typically be dropped into a larger training step function. 97 | """ 98 | def cached_f(params, batch): 99 | def scan_f(_, sub_batch): 100 | return None, f(params, sub_batch) 101 | _, reps = jax.lax.scan(jax.checkpoint(scan_f, policy=policy, prevent_cse=prevent_cse), None, batch) 102 | return jnp.reshape(reps, (-1,) + reps.shape[2:]) 103 | return cached_f -------------------------------------------------------------------------------- /stage1/grad_cache/functional.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | from typing import Callable, Union, Tuple, Any 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch import distributed as dist 7 | 8 | from .context_managers import RandContext 9 | 10 | 11 | def cached(func: Callable[..., Tensor]): 12 | """ 13 | A decorator that takes a model call function into a cached compatible version. 14 | :param func: A function that calls the model and return representation tensor. 15 | :return: A function that returns 1) representation leaf tensors for cache construction, 2) a closure function for 16 | the 2nd forward and the cached backward. Call 2) with 1) as argument after calling backward on the loss Tensor. 
def _cat_tensor_list(xx):
    """Concatenate ``xx`` on dim 0 if it is a non-empty list of tensors; otherwise return it unchanged."""
    if isinstance(xx, list) and len(xx) > 0 and all(isinstance(x, Tensor) for x in xx):
        return torch.cat(xx)
    else:
        return xx


def cat_input_tensor(func: Callable[..., Tensor]):
    """
    A decorator that concatenates positional and keyword arguments of type List[Tensor] into a single Tensor
    on the 0 dimension. This can come in handy dealing with results of representation tensors from multiple
    cached forward.
    :param func: A loss function
    :return: Decorated loss function for cached results.
    """
    @wraps(func)
    def cat_f(*args, **kwargs):
        args_cat = [_cat_tensor_list(x) for x in args]
        # Iterate (name, value) pairs; iterating ``kwargs.values()`` and
        # unpacking each value as a pair broke every keyword argument.
        kwargs_cat = {k: _cat_tensor_list(v) for k, v in kwargs.items()}
        return func(*args_cat, **kwargs_cat)
    return cat_f


def _maybe_gather_tensor(t: Any, axis: int):
    """All-gather ``t`` across ranks and concatenate on ``axis``; non-tensors are returned unchanged."""
    if not isinstance(t, Tensor):
        return t
    gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, t)
    # Put the local tensor itself back into its slot so this rank's slice of
    # the result is the original tensor rather than the gathered copy.
    gathered[dist.get_rank()] = t
    return torch.cat(gathered, dim=axis)


def gather_input_tensor(func: Callable[..., Tensor], axis=0):
    """
    A decorator that all-gather positional and keyword arguments of type Tensor and concatenate them on axis.
    Intended to be used with distributed contrastive learning loss.
    :param func: A loss function
    :param axis: The axis the gathered tensors are concatenated.
    :return: Decorated loss function for distributed training.
    """
    @wraps(func)
    def f(*args, **kwargs):
        args_gathered = [_maybe_gather_tensor(x, axis=axis) for x in args]
        # Same fix as in ``cat_input_tensor``: use ``items()`` to get the
        # keyword names together with their values.
        kwargs_gathered = {k: _maybe_gather_tensor(v, axis=axis) for k, v in kwargs.items()}
        return func(*args_gathered, **kwargs_gathered)
    return f

{{ t.site_name }} - {{ t.view_stats }}

35 | 36 | 37 |
38 | 39 | 40 |
41 | 42 | 43 | 44 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | {% for date, count in stats %} 56 | 57 | 58 | 59 | 60 | {% endfor %} 61 |
45 | {{ t.total_search_count or "总搜索次数" }}: 46 | {{ stats | map(attribute=1) | sum }} 47 |
{{ t.date }}{{ t.search_count }}
{{ date }}{{ count }}
62 |
63 | 64 |

{{ t.back }}

def str2bool(v):
    """Parse a command-line flag value into a bool.

    Booleans are returned unchanged. Strings are compared case-insensitively
    against the accepted true spellings ('yes', 'true', 't', 'y', '1') and
    false spellings ('no', 'false', 'f', 'n', '0').

    Raises:
        argparse.ArgumentTypeError: If the string matches neither set.
    """
    if isinstance(v, bool):
        return v

    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
__name__ == "__main__": 41 | args = parser.parse_args() 42 | 43 | model_name = args.model_path.split('/')[-1] 44 | 45 | # Load dataset from HuggingFace 46 | data = load_dataset('coldchair16/CPRet-data', 'PCPCD') 47 | traindata, testdata = process_dataset(data, args.max_length) 48 | 49 | model_kwargs = {"torch_dtype": torch.bfloat16} # TODO: Make sure your GPU and model support bfloat16 50 | 51 | model = SentenceTransformer(args.model_path, trust_remote_code=True, model_kwargs=model_kwargs, device='cpu') 52 | model.tokenizer.model_max_length = args.max_length 53 | model.max_seq_length = args.max_length 54 | 55 | dataset = ContrastiveDataset( 56 | traindata, model.tokenizer, max_length=args.max_length, multi_pos=args.multi_pos, 57 | use_data_augmentation=args.use_data_augmentation, 58 | ) 59 | val_dataset = ContrastiveDataset( 60 | testdata, model.tokenizer, max_length=args.max_length, multi_pos=args.multi_pos, 61 | use_data_augmentation=False, 62 | eval_mode=True, 63 | ) 64 | 65 | embedding_dim = model.get_sentence_embedding_dimension() 66 | print("Model embedding dimension: ", embedding_dim) 67 | 68 | if args.loss_type == 'GroupInfoNCE': 69 | loss_fn = GroupInfoNCELoss_gradcache_multipos(temperature=args.temperature) 70 | elif args.loss_type == 'InfoNCE': 71 | loss_fn = InfoNCELoss_gradcache(temperature=args.temperature) 72 | elif args.loss_type == 'InfoNCE_multipos': 73 | loss_fn = InfoNCELoss_gradcache_multipos(temperature=args.temperature) 74 | 75 | data_collator = ContrastiveDataCollator(model.tokenizer) 76 | 77 | save_name = args.model_path.split('/')[-1] 78 | save_name += f"-length{args.max_length}" 79 | save_name += f"-bs_per_device{args.per_device_batch_size}*gpus{os.environ['WORLD_SIZE']}" 80 | save_name += f"-epochs{args.num_train_epochs}" 81 | save_name += f"-temperature{args.temperature}" 82 | save_name += f"-lr{args.lr}" 83 | save_name += f"-multi_pos{args.multi_pos}" 84 | save_name += f"-use_data_augmentation_{args.use_data_augmentation}" 85 | 
save_name += f"-loss{args.loss_type}" 86 | 87 | print(f"Save name: {save_name}") 88 | 89 | training_args = TrainingArguments( 90 | output_dir="./results/" + save_name, 91 | per_device_train_batch_size=args.per_device_batch_size, 92 | num_train_epochs=args.num_train_epochs, 93 | logging_dir="./logs/" + save_name, 94 | logging_steps=1, 95 | eval_strategy="epoch", 96 | # eval_accumulation_steps=5, 97 | save_strategy="epoch", 98 | # save_steps=1, 99 | bf16=True, 100 | dataloader_num_workers=4, 101 | dataloader_drop_last=True, 102 | save_total_limit=1, 103 | remove_unused_columns=False, 104 | learning_rate=args.lr, 105 | weight_decay=0.01, 106 | max_grad_norm=1.0, 107 | gradient_accumulation_steps=1, 108 | lr_scheduler_type="constant", 109 | ) 110 | 111 | trainer = ContrastiveTrainer( 112 | multi_pos=args.multi_pos, 113 | chunk_size=args.chunk_size, 114 | model=model, 115 | args=training_args, 116 | train_dataset=dataset, 117 | eval_dataset=val_dataset, 118 | loss_fn=loss_fn, 119 | data_collator=data_collator, 120 | compute_metrics=compute_metrics_custom, 121 | ) 122 | 123 | if args.eval_only: 124 | eval_results = trainer.evaluate() 125 | print(eval_results) 126 | else: 127 | trainer.train() 128 | -------------------------------------------------------------------------------- /stage1/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import random 4 | from datasets import load_from_disk 5 | from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding 6 | 7 | 8 | # Define the ContrastiveDataset 9 | class ContrastiveDataset(Dataset): 10 | def __init__( 11 | self, 12 | data, tokenizer, 13 | max_length = 1024, 14 | eval_mode = False, 15 | prompt_question = "", 16 | prompt_code = "", 17 | eval_max_solutions_per_question = 20, 18 | multi_pos = 1, 19 | use_data_augmentation = True, 20 | ): 21 | self.data = data 22 | self.tokenizer = tokenizer 23 | 
class ContrastiveDataset(Dataset):
    """Map-style dataset yielding one tokenized query plus its solution documents.

    Each record of ``data`` is read through the keys ``'question'``,
    ``'question-main'`` (only when augmentation is active) and ``'solutions'``.
    Every text is tokenized with ``padding='max_length'`` and truncation, so
    all returned id/mask tensors have ``max_length`` as their last dimension.
    """

    def __init__(
        self,
        data, tokenizer,
        max_length = 1024,
        eval_mode = False,
        prompt_question = "",
        prompt_code = "",
        eval_max_solutions_per_question = 20,
        multi_pos = 1,
        use_data_augmentation = True,
    ):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.eval_mode = eval_mode
        # Prefixes prepended to every query / document text before tokenizing.
        self.prompt_question = prompt_question
        self.prompt_code = prompt_code
        self.multi_pos = multi_pos
        self.use_data_augmentation = use_data_augmentation
        self.eval_max_solutions_per_question = eval_max_solutions_per_question
        # Length is captured once at construction time.
        self.len = len(self.data)

    def __len__(self):
        return self.len

    def _tokenize(self, texts):
        # Single place for the tokenizer settings shared by queries and documents.
        return self.tokenizer(
            texts, return_tensors="pt", padding='max_length',
            max_length=self.max_length, truncation=True,
        )

    def __getitem__(self, idx):
        """Build the training/evaluation example for record ``idx``.

        Returned ``doc_input_ids`` / ``doc_attention_mask`` shapes:

        * eval mode: ``(eval_max_solutions_per_question, max_length)``; missing
          solutions are padded with empty strings and ``doc_id`` holds the
          number of real ones.
        * training, ``multi_pos == 1``: ``(max_length,)`` for one randomly
          chosen solution; ``doc_id`` is ``1``.
        * training, ``multi_pos > 1``: ``(multi_pos, max_length)``; solutions
          are repeated cyclically when there are fewer than ``multi_pos``,
          otherwise sampled without replacement. ``doc_id`` holds the total
          number of solutions of the record.
        """
        record = self.data[idx]

        # Query text: with augmentation (training only) pick one of two statements.
        if self.use_data_augmentation and not self.eval_mode:
            query_text = random.choice([record['question'], record['question-main']])
        else:
            query_text = record['question']
        query_enc = self._tokenize([self.prompt_question + query_text])

        solutions = record['solutions']
        if self.eval_mode:
            cap = self.eval_max_solutions_per_question
            kept = min(cap, len(solutions))
            docs = [self.prompt_code + s for s in solutions[:kept]]
            # Pad to a fixed count so evaluation batches can be stacked.
            docs += [""] * (cap - kept)
            doc_enc = self._tokenize(docs)
            doc_ids, doc_mask = doc_enc['input_ids'], doc_enc['attention_mask']
            doc_id = torch.tensor(kept)
        elif self.multi_pos == 1:
            picked = self.prompt_code + random.choice(solutions)
            doc_enc = self._tokenize([picked])
            # A single positive is returned without the leading document axis.
            doc_ids = doc_enc['input_ids'].squeeze(0)
            doc_mask = doc_enc['attention_mask'].squeeze(0)
            doc_id = torch.tensor(1)
        else:
            total = len(solutions)
            if total < self.multi_pos:
                repeats, remainder = divmod(self.multi_pos, total)
                chosen = solutions * repeats + solutions[:remainder]
            else:
                chosen = random.sample(solutions, self.multi_pos)
            doc_enc = self._tokenize([self.prompt_code + s for s in chosen])
            doc_ids, doc_mask = doc_enc['input_ids'], doc_enc['attention_mask']
            doc_id = torch.tensor(total)

        return {
            # Query tensors, shape (max_length,); standard Trainer field names.
            'input_ids': query_enc['input_ids'].squeeze(0),
            'attention_mask': query_enc['attention_mask'].squeeze(0),
            'id': torch.tensor(idx),

            # Document tensors; see the docstring for the per-mode shapes.
            'doc_input_ids': doc_ids,
            'doc_attention_mask': doc_mask,
            'doc_id': doc_id,
        }
class ContrastiveDataCollator:
    """Stack per-example tensors from ``ContrastiveDataset`` into a batch.

    Only the keys listed in ``_FIELDS`` are collated; any other key present in
    the examples is dropped. Whether a key is collated is decided by looking
    at the first example only.
    """

    # Output order matches the order in which the keys are emitted per example.
    _FIELDS = (
        'input_ids',
        'attention_mask',
        'id',
        'doc_input_ids',
        'doc_attention_mask',
        'doc_id',
    )

    def __init__(self, tokenizer, padding=True, truncation=True, max_length=None):
        # These settings are stored but not read by __call__: examples arrive
        # already padded to a fixed length.
        self.tokenizer = tokenizer
        self.padding = padding
        self.truncation = truncation
        self.max_length = max_length

    def __call__(self, features):
        """Return ``{key: torch.stack(values)}`` for every known key.

        :param features: non-empty list of per-example dicts of tensors.
        :return: dict of batched tensors with a new leading batch dimension.
        :raises IndexError: if ``features`` is empty.
        """
        first = features[0]
        return {
            key: torch.stack([f[key] for f in features])
            for key in self._FIELDS
            if key in first
        }
batch['doc_id'] = torch.stack([f['doc_id'] for f in features]) 117 | return batch -------------------------------------------------------------------------------- /stage1/grad_cache/pytorch_lightning/readme.md: -------------------------------------------------------------------------------- 1 | # PL_GradCache 2 | 3 | This is an experimental folder to provide example of using Grad Cache with PyTorch Lightning (pl), tested on pl version '2.2.0.post0' with Multi-GPUs and Mix-Precision (fp-16). Pytorch Metric Learning is required to install as well for contrastive loss calculation. 4 | 5 | - [Wandb Logging Experiments for Sanity Test](https://api.wandb.ai/links/xyznlp/nmf8d551) 6 | 7 | ### Installation 8 | 9 | After GradCache is installed, do 10 | 11 | ``` 12 | cd GradCache/src/grad_cache/pytorch_lightning 13 | python -m venv plgc 14 | . ./plgc/bin/activate 15 | pip3 install -U pip 16 | pip3 install -r requirements.txt 17 | ``` 18 | 19 | ### Reproducing Wandb Experiments 20 | 21 | ``` 22 | # 1-gpu 23 | python pl_example.py --gpus 1 --batch_size 16 24 | # 2-gpus 25 | python pl_example.py --gpus 2 --batch_size 8 26 | # 1-gpu, gradcache 27 | python pl_example.py --gpus 1 --batch_size 16 --use_gc --gc_minibatch_size 2 28 | # 2-gpus, gradcache 29 | python pl_example.py --gpus 2 --batch_size 8 --use_gc --gc_minibatch_size 2 30 | ``` 31 | 32 | Optionally, do mix-precision training with `--precision 16`, run different ddp_backend with `--ddp_backend {gloo/nccl/etc.}` 33 | 34 | ### Example 35 | 36 | Run `python pl_example.py` with the following flags. 37 | 38 | * `--use_gc` activates GradCache. 39 | * `--gc_minibatch_size {minibatch_size}` defines the batch size that each GPU needs to hold its memory into. If we specify `--gpus 2 --batch_size 8 --gc_minibatch 2`, for example, the model would be trained with batch size 8 * 2 = 16, the trainer would split each batch on each GPU (8 data samples) into 4 chunks of mini batches (2 data samples per mini batch). 
Set this to 1 gives the minimal possible gpu memory usage. 40 | 41 | ### Summary 42 | 43 | - Add `pl_gradcache.py` as customized GradCache on PyTorch Lightning. 44 | - Use manual backward in gradcache by calling `lightning_trainer.manual_backward(loss)` instead of using `loss.backward()` (this requires changing gradcache). 45 | - Set gradcache `no_sync_except_last=True` in multi-GPU case. 46 | 47 | ### Changes to the original GradCache 48 | 49 | #### File Change 50 | - `pl_gradcache.py` is the GradCache we will run on PyTorch Lightning (pl) with Distributed Data Parallel (ddp). 51 | 52 | #### Change in Optimization 53 | - In pt ddp setting, we need to first set `lightning_trainer.automatic_optimization=False` for us to customize calling backward. 54 | - See [the pl optimization doc](https://lightning.ai/docs/pytorch/stable/common/optimization.html) for implementation details, make sure that we are calling `self.optimizers()` instead of creating one by ourselves: if we do `self.optimizer = optimizer` in `self.configure_optimizers()`, this is not correct as it initializes a base optimizer in pt, but `self.optimizers()` is a wrapper for that. The base optimizer does not have the correct access to ddp and logging. 55 | - Then, replace all `loss.backward()` in GradCache with `lightning_trainer.manual_backward(loss)`. 56 | 57 | #### Change in GradCache 58 | - Set `no_sync_except_last=True` in Multi-GPU case to avoid unnecessary gradient reduction in the last step of gradcache. 59 | 60 | #### If you want to run GradCache in PyTorch Lightning with Multi-GPUs 61 | - In short, you are good to go by not worrying about this part. But here are the key changes in the original gradcache that are necessary for this to work. 62 | - we have two options to use gradcache in pl ddp setting and call `self.init_gc(scaler, gc_model)`. 63 | - We can set `gc_model=pytorch lightning trainer`. 
64 | - PyTorch Lightning would then wrap the base model (transformer) by their implementation of DDP. 65 | - In this case, just set `no_sync_except_last = False`, because lightning will handle gradient sync before `optimizer.step()`. 66 | - Setting `no_sync_except_last = True` in this case does not work, as the base model in gradcache is the transformer, which triggers the gradcache DDP assertion and a `model.no_sync` not-available error. 67 | - Or, we can just change gradcache instead (remove assert DDP and `model.no_sync`). 68 | - The only downside of this approach is that the training may take a little longer, because gradient sync is done on the full batch size (expected by pytorch lightning) instead of the last minibatch (expected by gradcache). But based on some sanity runs, it is ok (less than a 10% runtime increase). 69 | - We can set `gc_model=pytorch lightning trainer.strategy.model`, i.e. the wrapped base model by PyTorch DDP. 70 | - This is tricky as PyTorch Lightning uses a parameter `require_backward_grad_sync` to determine whether gradients would be synced across GPUs. 71 | - Firstly, Pytorch Lightning overrides the PyTorch DDP with its own implementation and sets `require_backward_grad_sync=False` before each training step (when `automatic optimization=False`). Then, it is set to True **after** each training step. 72 | - The issue here is that gradcache needs the gradient to be synced in the last backward step, which happens inside the training step hook of pytorch lightning. Thus, the only thing we can do is to set this variable manually before the last backward step in gradcache - we cannot set it outside of gradcache either, because the first backward of gradcache to do gradient checkpointing should NOT sync gradient (this is the point of gradcache essentially). 73 | - Thus, we do `model.require_backward_grad_sync=True` at the very end of gradcache - before the backward of the last minibatch surrogate.
74 | - The advantage of this is that we can do `no_sync_except_last` as what gradcache hopes us to do (no runtime increase). The downside is that we need to modify gradcache in a very hacky way. This is the default setup. 75 | -------------------------------------------------------------------------------- /TestCases.md: -------------------------------------------------------------------------------- 1 | 2 | # 🔍 Test Cases of CPRet 3 | 4 | ## 📌 System Launch Overview 5 | 6 | * Our open-source **problem retrieval platform** was launched on **May 21, 2025**: https://cpret.online/ (Before : http://1.94.255.218:5000/) 7 | * In less than a week, the platform has recorded nearly **2,000 search queries**. 8 | * The blog post introducing the system on **Codeforces** has received over **250 upvotes** and positive community feedback: https://codeforces.com/blog/entry/143098 9 | * **Use cases** include: 10 | 11 | * **Similar Problem Retrieval**: Assisting contestants in expanding their problem-solving perspective and addressing knowledge blind spots. 12 | * **Duplicate Problem Retrieval**: Helping problem setters identify previously seen ideas or solutions early on. 13 | 14 | --- 15 | 16 | ## 🧪 Test Case 1: Duplicate Retrieval in a Recent Contest 17 | 18 | **Target Contest** 19 | 20 | * **Name**: The 2025 CCPC National Invitational Contest (Northeast), The 19th Northeast Collegiate Programming Contest 21 | * **Url**: https://codeforces.com/gym/105924 22 | * **Date**: May 25, 2025 23 | * **Total Problems**: 12 24 | 25 | **Key Findings** 26 | 27 | * The system successfully identified **6 problems** with highly similar or identical historical counterparts. 28 | * Only the **top 3 results per query** were manually inspected, indicating a **duplicate rate of at least 50%**—likely higher in practice. 
29 | * Under current scoring rules, solving these 6 problems is enough to secure a **silver medal**, and just one additional easy problem would result in a **gold medal**, raising **serious fairness concerns**. 30 | 31 | **Detection Summary** 32 | 33 | | Contest Problem | Matched Historical Problem | Similarity Level | Rank | 34 | | ------------------------------------------------------------------------ | -------------------------------------------------------------------------------- | ---------------- | ---- | 35 | | [A. GD Ultimate Rhythm Lab](https://codeforces.com/gym/105924/problem/A) | [Nowcoder - 小睿睿的数列](https://ac.nowcoder.com/acm/problem/24479) | Same approach | 1 | 36 | | [D. Defend the Carrot](https://codeforces.com/gym/105924/problem/D) | [SPOJ - UOFTBB](https://www.spoj.com/problems/UOFTBB/) | Almost identical | 1 | 37 | | [E. Tree Edge Removal](https://codeforces.com/gym/105924/problem/E) | [Luogu - \[JRKSJ R7\] 茎](https://www.luogu.com.cn/problem/P8935) | Almost identical | 1 | 38 | | [F. Youthful Oath II](https://codeforces.com/gym/105924/problem/F) | [Codeforces - 80B Depression](https://codeforces.com/problemset/problem/80/B) | Almost identical | 1 | 39 | | [J. Kingdom: Memories](https://codeforces.com/gym/105924/problem/J) | [AtCoder - R Walk](https://atcoder.jp/contests/dp/tasks/dp_r) | Almost identical | 3 | 40 | | [L. Bathhouse](https://codeforces.com/gym/105924/problem/L) | [Codeforces - 219E Parking Lot](https://codeforces.com/problemset/problem/219/E) | Same approach | 2 | 41 | 42 | --- 43 | 44 | ## 🧪 Test Case 2: Similar Problem Retrieval – MEX Variants 45 | 46 | We conducted a query with the classic problem "**interval MEX**" to identify its **variants across different contests**, aiming to showcase the system’s utility for **idea expansion and knowledge transfer**. 47 | 48 | ### Query Problem Description 49 | 50 | > Given a sequence of $n$ natural numbers $a[1..n]$, answer $m$ queries. 
51 | > Each query specifies a range $[l, r]$, and asks for $\mathrm{mex}({a_l, a_{l+1}, \dots, a_r})$ — 52 | > the **minimum excluded value** in the subarray. 53 | > 54 | > This problem can be solved in **$O((n + m) log n)$** time using segment trees. 55 | 56 | ### Retrieval Results 57 | 58 | | Rank | Title | Description | 59 | | ---- | ---------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------- | 60 | | 1 | [Luogu P4137: RMQ Problem / mex](https://www.luogu.com.cn/problem/P4137) | Original problem | 61 | | 2 | [LOJ 6908: THUPC 2024 Prelim - "Matryoshka"](https://loj.ac/p/6908) | MEX of all subarrays of length $k$, then take the MEX of those | 62 | | 5 | [AtCoder ABC194E: Mex Min](https://atcoder.jp/contests/abc194/tasks/abc194_e) | MEX of all subarrays of length $k$, then take the **minimum** | 63 | | 6 | [Luogu P10032: Mex of Sequence](https://www.luogu.com.cn/problem/P10032) | Repeated operations: $a'[i] = \mathrm{mex}(a \setminus a[i])$ | 64 | | 11 | [Nowcoder 237670: “经典”问题](https://ac.nowcoder.com/acm/problem/237670) | MEX queries on permutations, optimized to $O(n + m)$ | 65 | | 14 | [Luogu P8087: 『JROI-5』Interval](https://www.luogu.com.cn/problem/P8087) | MEX of all subarrays of length $k$, then take the **maximum** | 66 | | 15 | [AtCoder ABC290C: Max MEX](https://atcoder.jp/contests/abc290/tasks/abc290_c) | MEX of all **subsequences** of length $k$, then take the **minimum** | 67 | | 16 | [Codeforces 1436E: Complicated Computations](https://codeforces.com/problemset/problem/1436/E) | MEX of all subarrays, then take the MEX again | 68 | | 23 | [AtCoder ABC330E: Mex and Update](https://atcoder.jp/contests/abc330/tasks/abc330_e) | Support element modification or querying full array MEX | 69 | | 24 | [Luogu P11837: Making Mexes B](https://www.luogu.com.cn/problem/P11837) | Minimum edits to ensure $\mathrm{mex}(a) = i$ | 70 | 
class GradCacheCustom(GradCache):
    """GradCache variant for a query/document bi-encoder with multiple positives.

    Differences from the parent class that are visible in this file:

    * the backward pass is delegated to a caller-supplied ``backward_fn``;
    * model calls and the loss computation run inside
      ``compute_loss_context_manager()``;
    * representations are read from the ``'sentence_embedding'`` entry of the
      model output;
    * ``cache_step`` takes a single batch dict and splits it into query and
      document inputs itself.
    """

    def __init__(
            self,
            multi_pos : int = 1,
            backward_fn = None,
            compute_loss_context_manager = None,
            *args,
            **kwargs,
    ):
        """
        :param multi_pos: number of documents per query; when greater than 1 the
            document tensors are flattened over their first two dimensions
            before chunking.
        :param backward_fn: callable invoked with a scalar tensor to run the
            backward pass. Must be provided before ``cache_step`` is called.
        :param compute_loss_context_manager: zero-argument callable returning a
            context manager entered around every model call and loss
            computation. Defaults to ``nullcontext`` when omitted.
        :param args: forwarded to ``GradCache.__init__``.
        :param kwargs: forwarded to ``GradCache.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.multi_pos = multi_pos
        self.backward_fn = backward_fn
        # The argument used to be discarded and replaced by None, which made
        # `with self.compute_loss_context_manager():` fail unless the attribute
        # was patched afterwards. Keep the supplied factory; fall back to a no-op.
        self.compute_loss_context_manager = (
            compute_loss_context_manager
            if compute_loss_context_manager is not None
            else nullcontext
        )
        # Assigned after the parent constructor, so it replaces whatever the
        # parent stored under this name.
        self.get_rep_fn = lambda x: x['sentence_embedding']

    def model_call(self, model: nn.Module, model_input):
        """
        Literally call the model's __call__ method.
        :param model: model to be called
        :param model_input: input to the model call, passed as one positional argument
        :return: model output
        """
        with self.compute_loss_context_manager():
            return model(model_input)

    def compute_loss(self, *reps: Tensor, **loss_kwargs) -> Tensor:
        """
        Compute the loss based on the representation tensors. The tensors should be ordered same as the list of models
        registered in this GradCache class instance.
        :param reps: Representations for computing the loss.
        :param loss_kwargs: Keyword arguments input to the loss function.
        :return: the loss tensor.
        """
        with self.compute_loss_context_manager():
            loss = self.loss_fn(*reps, **loss_kwargs)
        return loss

    def build_cache(self, *reps: Tensor, **loss_kwargs) -> tuple[List[Tensor], Tensor]:
        """
        Compute the gradient cache
        :param reps: Computed representations from all encoder models
        :param loss_kwargs: Extra keyword arguments to the loss function
        :return: A tuple of a) gradient cache for each encoder model, and b) loss tensor
        """
        # Detach from the (grad-free) first forward and make each tensor a leaf,
        # so the backward below stops at the representations.
        reps = [r.detach().requires_grad_() for r in reps]
        with autocast() if self.fp16 else nullcontext():
            loss = self.compute_loss(*reps, **loss_kwargs)

        if self.fp16:
            self.backward_fn(self.scaler.scale(loss))
        else:
            self.backward_fn(loss)

        cache = [r.grad for r in reps]

        return cache, loss.detach()

    def forward_backward(
            self,
            model: nn.Module,
            model_inputs,
            cached_gradients: List[Tensor],
            random_states: List[RandContext],
            no_sync_except_last: bool = False
    ):
        """
        Run the second forward and the backward pass to compute gradient for a model.
        :param model: Encoder model.
        :param model_inputs: Chunked input to the encoder model.
        :param cached_gradients: Chunked gradient cache tensor for each input.
        :param random_states: Each input's device random state during the first forward.
        :param no_sync_except_last: If True, under distributed setup, only trigger gradient reduction across processes
        for the last sub-batch's forward-backward pass.
        """
        if no_sync_except_last:
            sync_contexts = [model.no_sync for _ in range(len(model_inputs) - 1)] + [nullcontext]
        else:
            sync_contexts = [nullcontext for _ in range(len(model_inputs))]

        for x, state, gradient, sync_context in zip(model_inputs, random_states, cached_gradients, sync_contexts):
            with sync_context():
                with state:
                    y = self.model_call(model, x)
                reps = self.get_reps(y)

                # d(surrogate)/d(params) equals the cached gradient chained
                # through this chunk's representations.
                surrogate = torch.dot(reps.flatten(), gradient.flatten())
                self.backward_fn(surrogate)

    def cache_step(
            self,
            *model_inputs,
            no_sync_except_last: bool = False,
            **loss_kwargs
    ) -> Tensor:
        """
        Run a cached step to compute gradient over the inputs.
        :param model_inputs: Only the first positional argument is used. It must be a mapping with the keys
        'input_ids', 'attention_mask', 'doc_input_ids' and 'doc_attention_mask'.
        :param no_sync_except_last: If True, under distributed setup, for each model, only trigger gradient reduction
        across processes for the last sub-batch's forward-backward pass.
        :param loss_kwargs: Additional keyword arguments to the loss function.
        :return: The current's loss.
        """
        all_reps = []
        all_rnd_states = []

        if no_sync_except_last:
            assert all(map(lambda m: isinstance(m, nn.parallel.DistributedDataParallel), self.models)), \
                'Some of models are not wrapped in DistributedDataParallel. Make sure you are running DDP with ' \
                'proper initializations.'

        # Split the single batch dict into the query input and the document input.
        model_inputs = model_inputs[0]
        model_inputs = [
            {
                'input_ids': model_inputs['input_ids'],
                'attention_mask': model_inputs['attention_mask'],
            },
            {
                'input_ids': model_inputs['doc_input_ids'],
                'attention_mask': model_inputs['doc_attention_mask'],
            }
        ]

        if (self.multi_pos > 1):
            # Merge the first two dimensions of the document tensors so every
            # document becomes its own row for the encoder.
            for k in ['input_ids', 'attention_mask']:
                model_inputs[1][k] = model_inputs[1][k].flatten(0, 1)

        model_inputs = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(model_inputs, self.chunk_sizes)]

        for model, x in zip(self.models, model_inputs):
            model_reps, rnd_states = self.forward_no_grad(model, x)
            all_reps.append(model_reps)
            all_rnd_states.append(rnd_states)

        cache, loss = self.build_cache(*all_reps, **loss_kwargs)
        cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)]

        for model, x, model_cache, rnd_states in zip(
                self.models, model_inputs, cache, all_rnd_states):
            self.forward_backward(model, x, model_cache, rnd_states, no_sync_except_last=no_sync_except_last)

        return loss
class InfoNCELoss_gradcache(nn.Module):
    """InfoNCE loss over aligned query/document embeddings.

    Negatives for a query are every document in the (gathered) batch plus
    every other query.
    """

    def __init__(self, temperature: float = 0.07):
        super(InfoNCELoss_gradcache, self).__init__()
        self.temperature = temperature

    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, **loss_kwargs) -> torch.Tensor:
        """
        :param query_embeddings: 2-D tensor, one row per query.
        :param doc_embeddings: 2-D tensor, row ``i`` is the positive of query ``i``.
        :param loss_kwargs: accepted for interface compatibility; ignored.
        :return: scalar loss tensor.
        """
        # Convert bf16 to fp32 for numerical stability
        query_embeddings = query_embeddings.float()
        doc_embeddings = doc_embeddings.float()

        # Gather embeddings across distributed processes
        query_embeddings = self.gather_tensor(query_embeddings)
        doc_embeddings = self.gather_tensor(doc_embeddings)

        # Normalize the embeddings to unit vectors
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        doc_embeddings = F.normalize(doc_embeddings, p=2, dim=1)

        return self._compute_loss(query_embeddings, doc_embeddings)

    def _compute_loss(self, query_embeddings, doc_embeddings):
        """Return the mean InfoNCE loss for already-normalized embeddings."""
        # Positive similarity (dot product between aligned query-doc pairs)
        sim_u = torch.sum(query_embeddings * doc_embeddings, dim=1) / self.temperature
        sim_u = torch.exp(sim_u)

        # Similarities with all queries and docs (positive + negatives)
        sim_v = torch.sum(torch.exp(torch.matmul(query_embeddings, doc_embeddings.T) / self.temperature), dim=1) + \
                torch.sum(torch.exp(torch.matmul(query_embeddings, query_embeddings.T) / self.temperature), dim=1)
        # Remove self-similarity term (exactly exp(1/T) for unit vectors)
        sim_v = sim_v - torch.exp(torch.tensor(1 / self.temperature, device=sim_v.device))

        return -torch.log(sim_u / sim_v).mean()

    @staticmethod
    def _distributed_ready() -> bool:
        # True only when a default process group exists in this process.
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    def gather_tensor(self, t):
        """Concatenate ``t`` from all ranks along dim 0, keeping autograd.

        Without an initialized process group (e.g. a single-process run) the
        tensor is returned unchanged instead of raising.
        """
        if not self._distributed_ready():
            return t
        # `import torch` alone does not guarantee this submodule is loaded.
        import torch.distributed.nn as dist_nn
        return torch.cat(dist_nn.all_gather(t), dim=0)


# Version that uses sum of positive similarities in numerator.
# This class used to share its name with the class below and was therefore
# silently replaced at import time; it now has its own name so it can be used.
class InfoNCELoss_gradcache_multipos_sum(InfoNCELoss_gradcache):
    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, **loss_kwargs) -> torch.Tensor:
        """
        :param query_embeddings: ``[batch, dim]`` query embeddings.
        :param doc_embeddings: ``[batch * num_pos, dim]``; the rows of each
            query's positives must be contiguous.
        :return: scalar loss tensor.
        """
        query_embeddings = query_embeddings.float()
        doc_embeddings = doc_embeddings.float()

        query_embeddings = self.gather_tensor(query_embeddings)
        doc_embeddings = self.gather_tensor(doc_embeddings)

        # Reshape to [batch_size, num_pos, dim]
        doc_embeddings = doc_embeddings.reshape(query_embeddings.shape[0], -1, query_embeddings.shape[1])

        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        doc_embeddings = F.normalize(doc_embeddings, p=2, dim=2)

        sim_mat = torch.exp(query_embeddings @ query_embeddings.T / self.temperature)
        sim_v = torch.sum(sim_mat, dim=1) - torch.exp(torch.tensor(float(1.0) / self.temperature, device=sim_mat.device))

        sim_w = torch.einsum("ik, jnk -> ijn", query_embeddings, doc_embeddings) / self.temperature
        sim_w = torch.exp(sim_w)          # shape: [bs, bs, pos]
        sim_w = torch.sum(sim_w, dim=2)   # shape: [bs, bs]
        sim_u = sim_w.diagonal()          # positive similarities on the diagonal
        sim_w = torch.sum(sim_w, dim=1)   # total similarities

        return -torch.log(sim_u / (sim_w + sim_v)).mean()


# Version that separates query-doc and query-query terms
class InfoNCELoss_gradcache_multipos(InfoNCELoss_gradcache):
    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, **loss_kwargs) -> torch.Tensor:
        """
        Each positive gets its own loss term; the other positives of the same
        query are excluded from that term's denominator.

        :param query_embeddings: ``[batch, dim]`` query embeddings.
        :param doc_embeddings: ``[batch * num_pos, dim]``; the rows of each
            query's positives must be contiguous.
        :return: scalar loss tensor (mean over queries and positives).
        """
        query_embeddings = query_embeddings.float()
        doc_embeddings = doc_embeddings.float()

        query_embeddings = self.gather_tensor(query_embeddings)
        doc_embeddings = self.gather_tensor(doc_embeddings)

        doc_embeddings = doc_embeddings.reshape(query_embeddings.shape[0], -1, query_embeddings.shape[1])

        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        doc_embeddings = F.normalize(doc_embeddings, p=2, dim=2)

        sim_mat = torch.exp(query_embeddings @ query_embeddings.T / self.temperature)
        sim_v = torch.sum(sim_mat, dim=1) - torch.exp(torch.tensor(float(1.0) / self.temperature, device=sim_mat.device))

        # Positive similarities: shape [bs, pos]
        sim_a = torch.exp(torch.sum(query_embeddings.unsqueeze(1) * doc_embeddings, dim=2) / self.temperature)

        sim_w = torch.einsum("ik, jnk -> ijn", query_embeddings, doc_embeddings) / self.temperature
        sim_w = torch.exp(sim_w)  # shape: [bs, bs, pos]

        sim_p = torch.sum(sim_w, dim=2)  # [bs, bs]
        sim_u_origin = sim_p.diagonal()  # sum over a query's own positives
        sim_q = torch.sum(sim_p, dim=1)

        # Denominator = negatives only (all docs minus own positives, plus
        # other queries) + the single positive being scored.
        return torch.mean(-torch.log(sim_a / ((sim_q + sim_v - sim_u_origin).unsqueeze(1) + sim_a)))


# Version that implements Group-InfoNCE with regularization
class GroupInfoNCELoss_gradcache_multipos(InfoNCELoss_gradcache):
    def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, **loss_kwargs) -> torch.Tensor:
        """
        Contrast each query against the mean similarity to every group of
        documents, and add the within-group similarity variance scaled by
        ``1 / temperature ** 2`` as a penalty.

        :param query_embeddings: ``[batch, dim]`` query embeddings.
        :param doc_embeddings: ``[batch * num_pos, dim]``; the rows of each
            query's positives must be contiguous.
        :return: scalar tensor ``loss + loss_penalty / temperature ** 2``.
        """
        query_embeddings = query_embeddings.float()
        doc_embeddings = doc_embeddings.float()

        query_embeddings = self.gather_tensor(query_embeddings)
        doc_embeddings = self.gather_tensor(doc_embeddings)

        doc_embeddings = doc_embeddings.reshape(query_embeddings.shape[0], -1, query_embeddings.shape[1])

        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        doc_embeddings = F.normalize(doc_embeddings, p=2, dim=2)

        # sim_a: similarity between each query and its group of positive docs
        sim_a = torch.einsum('ik, ijk -> ij', query_embeddings, doc_embeddings)
        # NOTE(review): torch.var with its default correction is NaN for a
        # single positive per query -- confirm this loss is only used with
        # more than one positive.
        loss_penalty = torch.var(sim_a, dim=1).mean()  # penalty: variance within group

        sim_b = torch.mean(sim_a, dim=1)
        sim_b = torch.exp(sim_b / self.temperature)

        # query-to-query similarity matrix (excluding self-similarity)
        sim_p = torch.einsum('ik, jk -> ij', query_embeddings, query_embeddings)
        sim_p = torch.exp(sim_p / self.temperature)
        sim_p = torch.sum(sim_p, dim=1) - torch.exp(torch.tensor(1 / self.temperature, device=sim_p.device))

        # query-to-all-doc-groups similarity
        sim_q = torch.einsum('ik, jpk -> ijp', query_embeddings, doc_embeddings)
        sim_q = torch.mean(sim_q, dim=2)  # [bs, bs]
        sim_q = torch.exp(sim_q / self.temperature)
        sim_q = torch.sum(sim_q, dim=1)

        loss = -torch.log(sim_b / (sim_p + sim_q)).mean()

        # Log from one process only; get_rank() needs an initialized process
        # group, so single-process runs log unconditionally.
        if not self._distributed_ready() or torch.distributed.get_rank() == 0:
            print(f"\nloss_penalty: {loss_penalty} / T^2 = {loss_penalty / self.temperature ** 2}")
            print(f"loss: {loss}")

        return loss + loss_penalty / self.temperature ** 2
# Modified InformationRetrievalEvaluator to support directly passing in embeddings
class CustomRetrievalEvaluator(InformationRetrievalEvaluator):
    """Information-retrieval evaluator that works on precomputed embeddings.

    The stock evaluator encodes query/corpus texts itself; this variant takes
    embedding matrices plus per-query relevant document indices and reuses
    the inherited ``compute_metrics`` for the ranking metrics.
    """

    def __init__(
        self,
        corpus_chunk_size: int = 50000,
        mrr_at_k: list[int] = [10],
        ndcg_at_k: list[int] = [10, 100],
        accuracy_at_k: list[int] = [1, 3, 5, 10],
        precision_recall_at_k: list[int] = [1, 3, 5, 10],
        map_at_k: list[int] = [10, 100],
        show_progress_bar: bool = False,
        batch_size: int = 32,
        name: str = "",
        write_csv: bool = True,
        truncate_dim: int | None = None,
        score_functions: dict[str, Callable[[Tensor, Tensor], Tensor]] | None = {
            "cos": SimilarityFunction.to_similarity_fn("cosine")
        },
        main_score_function: str | SimilarityFunction | None = None,
    ) -> None:
        """Store the metric configuration without requiring query/corpus texts."""
        # Skip InformationRetrievalEvaluator.__init__ (its arguments do not
        # apply here) but still run the next initialiser in the MRO. The
        # one-argument form ``super(Cls)`` creates an unbound super object
        # and never initialises ``self``.
        super(InformationRetrievalEvaluator, self).__init__()
        self.corpus_chunk_size = corpus_chunk_size
        # Copy the cut-off lists so instances never alias the shared
        # default-argument objects.
        self.mrr_at_k = list(mrr_at_k)
        self.ndcg_at_k = list(ndcg_at_k)
        self.accuracy_at_k = list(accuracy_at_k)
        self.precision_recall_at_k = list(precision_recall_at_k)
        self.map_at_k = list(map_at_k)

        self.show_progress_bar = show_progress_bar
        self.batch_size = batch_size
        self.name = name
        self.write_csv = write_csv
        self.score_functions = dict(score_functions) if score_functions is not None else None
        self.score_function_names = sorted(self.score_functions) if self.score_functions else []
        self.main_score_function = SimilarityFunction(main_score_function) if main_score_function else None
        self.truncate_dim = truncate_dim

        if name:
            name = "_" + name

        self.csv_file: str = "Information-Retrieval_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps"]

        self._append_csv_headers(self.score_function_names)

    def calc_results(self, query_embeddings, doc_embeddings, r_docs):
        """Rank documents for every query and compute retrieval metrics.

        Args:
            query_embeddings: One embedding per query.
            doc_embeddings: One embedding per corpus document.
            r_docs: ``r_docs[i]`` holds the indices into ``doc_embeddings``
                that are relevant for query ``i``.

        Returns:
            A dict mapping each score-function name to the metrics returned
            by ``compute_metrics``. The ``"cos"`` entry additionally receives
            ``"<name>_mean"`` and ``"<name>_percentile@k"`` statistics of the
            query-to-relevant-document similarity.
        """
        n_queries = len(query_embeddings)
        n_docs = len(doc_embeddings)

        # Construct internal ID mappings and relevance ground truth
        self.queries_ids = [f"q{i}" for i in range(n_queries)]
        self.corpus_ids = [f"c{i}" for i in range(n_docs)]
        self.relevant_docs = {
            f"q{i}": {f"c{j}" for j in docs} for i, docs in enumerate(r_docs)
        }
        self.queries = [""] * n_queries
        self.corpus = [""] * n_docs

        # Determine the largest k needed
        max_k = max(
            max(self.mrr_at_k),
            max(self.ndcg_at_k),
            max(self.accuracy_at_k),
            max(self.precision_recall_at_k),
            max(self.map_at_k),
        )

        # One min-heap of (score, corpus_id) per query and score function
        queries_result_list = {
            name: [[] for _ in range(n_queries)]
            for name in self.score_functions
        }

        # Process corpus in chunks to avoid memory overflow
        for corpus_start_idx in trange(
            0, n_docs, self.corpus_chunk_size, desc="Corpus Chunks", disable=not self.show_progress_bar
        ):
            corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, n_docs)
            sub_corpus_embeddings = doc_embeddings[corpus_start_idx:corpus_end_idx]

            for name, score_function in self.score_functions.items():
                pair_scores = score_function(query_embeddings, sub_corpus_embeddings)

                top_k_values, top_k_idx = torch.topk(
                    pair_scores, min(max_k, len(pair_scores[0])), dim=1, largest=True, sorted=False
                )
                top_k_values = top_k_values.cpu().tolist()
                top_k_idx = top_k_idx.cpu().tolist()

                for heap, idx_row, value_row in zip(queries_result_list[name], top_k_idx, top_k_values):
                    for sub_corpus_id, score in zip(idx_row, value_row):
                        corpus_id = self.corpus_ids[corpus_start_idx + sub_corpus_id]
                        # Do not skip query==corpus as some setups use identical indexing
                        if len(heap) < max_k:
                            heapq.heappush(heap, (score, corpus_id))
                        else:
                            # Heap is full: keep only the max_k best scores.
                            heapq.heappushpop(heap, (score, corpus_id))

        # Convert heap tuples to result dicts
        for per_query in queries_result_list.values():
            for query_itr, heap in enumerate(per_query):
                per_query[query_itr] = [
                    {"corpus_id": corpus_id, "score": score} for score, corpus_id in heap
                ]

        # Compute standard IR metrics
        scores = {
            name: self.compute_metrics(queries_result_list[name])
            for name in self.score_functions
        }

        # Average similarity between each query and its relevant docs
        for name, score_func in self.score_functions.items():
            # The statistics are reported under the "cos" entry; create it on
            # demand so a configuration without a "cos" scorer does not raise.
            cos_scores = scores.setdefault("cos", {})
            v_list = np.array([
                score_func(query_embeddings[i].reshape(1, -1), doc_embeddings[r_docs[i]]).mean().item()
                for i in range(n_queries)
            ])
            cos_scores[f"{name}_mean"] = np.mean(v_list)
            cos_scores[f"{name}_percentile@k"] = {p: np.percentile(v_list, p) for p in (10, 50, 90)}

        return scores
def compute_metrics_custom(p):
    """Turn an ``EvalPrediction`` of embeddings into flat retrieval metrics.

    Args:
        p: Object whose ``predictions`` attribute unpacks into
            ``(query_embeddings, doc_embeddings, ids, doc_ids)``.
            ``doc_ids[i]`` is used as the number of valid documents stored in
            ``doc_embeddings[i]``.

    Returns:
        A flat ``{metric_name: value}`` dict built from the evaluator's
        ``"cos"`` results; nested dicts are flattened as
        ``"<outer>_<inner>"``.
    """
    query_embeddings, doc_embeddings, ids, doc_ids = p.predictions

    doc_embeddings_flatten = []
    rdocs = []
    tot = 0
    for i in range(len(ids)):
        num = int(doc_ids[i])  # number of documents for query i
        rdocs.append(list(range(tot, tot + num)))  # relevant doc indices for query i
        doc_embeddings_flatten.extend(doc_embeddings[i][:num])
        tot += num

    doc_embeddings_flatten = np.array(doc_embeddings_flatten)

    evaluator = CustomRetrievalEvaluator()
    results = evaluator.calc_results(query_embeddings, doc_embeddings_flatten, rdocs)["cos"]

    # Flatten nested results into a flat dictionary
    new_results = {}
    for key, value in results.items():
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                new_results[f"{key}_{sub_key}"] = sub_value
        else:
            new_results[key] = value

    # Log results only on rank 0. Without an initialised process group
    # get_rank() raises, so a single-process run is treated as rank 0.
    distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    if not distributed or torch.distributed.get_rank() == 0:
        print(results)

    return new_results
# grad_cache version
class ContrastiveTrainer(Trainer):
    """Trainer that optimises a contrastive loss through gradient caching.

    Training steps are delegated to a ``GradCacheCustom`` instance that is
    given the same model twice (query side and document side).
    """

    def __init__(
        self,
        loss_fn=None,
        multi_pos=1,
        chunk_size=8,
        *args,
        **kwargs,
    ):
        """Create the trainer and its gradient cache.

        Args:
            loss_fn: Callable taking ``(query_embeddings, doc_embeddings)``
                and returning the loss.
            multi_pos: Forwarded unchanged to ``GradCacheCustom``.
            chunk_size: Chunk size used for both model slots of the cache.
            *args, **kwargs: Forwarded to ``transformers.Trainer``.
        """
        super().__init__(*args, **kwargs)
        self.loss_fn = loss_fn
        self.grad_cache = GradCacheCustom(
            models=[self.model, self.model],
            chunk_sizes=[chunk_size] * 2,
            loss_fn=self.loss_fn,
            multi_pos=multi_pos,
        )

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the contrastive loss for one batch without gradient caching."""
        # Compute sentence embeddings for query and document
        query_embeddings = model.forward(
            {
                'input_ids': inputs['input_ids'],
                'attention_mask': inputs['attention_mask'],
            }
        )['sentence_embedding']
        doc_embeddings = model.forward(
            {
                'input_ids': inputs['doc_input_ids'],
                'attention_mask': inputs['doc_attention_mask'],
            }
        )['sentence_embedding']

        # Gather embeddings across all processes for contrastive loss
        query_embeddings_new = torch.cat(torch.distributed.nn.all_gather(query_embeddings), dim=0)
        doc_embeddings_new = torch.cat(torch.distributed.nn.all_gather(doc_embeddings), dim=0)

        loss = self.loss_fn(query_embeddings_new, doc_embeddings_new)

        # Optionally return intermediate outputs
        return (loss, (query_embeddings, doc_embeddings)) if return_outputs else loss

    # Evaluation step
    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        """Embed the query and each of its documents for evaluation.

        Returns:
            ``(None, None, None)`` when ``prediction_loss_only`` is set,
            otherwise ``(None, predictions, labels)`` where ``predictions``
            is ``(query_embeddings, doc_embeddings, inputs['id'],
            inputs['doc_id'])`` and ``labels`` is a zero tensor shaped like
            ``inputs['id']``.
        """
        # No loss is computed here, so there is nothing to do when only the
        # loss is requested; return before running any forward pass.
        if prediction_loss_only:
            return (None, None, None)

        device = next(model.parameters()).device
        inputs = {key: value.to(device) for key, value in inputs.items()}

        with torch.no_grad():
            query_embeddings = model.forward(
                {
                    'input_ids': inputs['input_ids'],
                    'attention_mask': inputs['attention_mask'],
                }
            )['sentence_embedding']

            # Encode every document slot (dimension 1) and stack once at the
            # end instead of re-concatenating a growing tensor per step.
            per_doc = [
                model.forward(
                    {
                        'input_ids': inputs['doc_input_ids'][:, i],
                        'attention_mask': inputs['doc_attention_mask'][:, i],
                    }
                )['sentence_embedding']
                for i in range(inputs['doc_input_ids'].shape[1])
            ]
            doc_embeddings = torch.stack(per_doc, dim=1) if per_doc else None

        # Return output format: (loss, predictions, labels)
        return None, \
            (query_embeddings, doc_embeddings, inputs['id'], inputs['doc_id']), \
            torch.zeros_like(inputs['id'])

    def training_step(
        self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
    ) -> torch.Tensor:
        """
        Perform a training step with support for gradient caching.

        Args:
            model (nn.Module): The model being trained.
            inputs (dict): A batch of input data.
            num_items_in_batch (optional): The number of training examples in the batch.

        Returns:
            torch.Tensor: The computed training loss.
        """
        model.train()
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)

        # SageMaker model parallelism path
        if is_sagemaker_mp_enabled():
            loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
            return loss_mb.reduce_mean().detach().to(self.args.device)

        self.grad_cache.compute_loss_context_manager = self.compute_loss_context_manager

        # Empty cache conditionally at configured steps
        if (
            self.args.torch_empty_cache_steps is not None
            and self.state.global_step % self.args.torch_empty_cache_steps == 0
        ):
            if is_torch_xpu_available():
                torch.xpu.empty_cache()
            elif is_torch_mlu_available():
                torch.mlu.empty_cache()
            elif is_torch_musa_available():
                torch.musa.empty_cache()
            elif is_torch_npu_available():
                torch.npu.empty_cache()
            elif is_torch_mps_available(min_version="2.0"):
                torch.mps.empty_cache()
            else:
                torch.cuda.empty_cache()

        kwargs = {}

        # Handle special optimizer requirements
        if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            kwargs["learning_rate"] = self._get_learning_rate()

        def reduce_loss(loss):
            # Mean-reduce loss in multi-GPU setting. The loss only exists
            # once the gradient cache has produced it, so the reduction is
            # applied where the loss is consumed, not up front.
            return loss.mean() if self.args.n_gpu > 1 else loss

        # Setup backward function depending on backend
        if self.use_apex:
            def apex_backward(loss):
                with amp.scale_loss(reduce_loss(loss), self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            self.grad_cache.backward_fn = apex_backward
        else:
            if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
                kwargs["scale_wrt_gas"] = False

            def accelerator_backward(loss):
                loss = reduce_loss(loss)
                # Normalize loss across accumulation steps unless the model handles this internally
                if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
                    loss = loss / self.args.gradient_accumulation_steps
                self.accelerator.backward(loss, **kwargs)
            self.grad_cache.backward_fn = accelerator_backward

        # Use gradient caching to compute and apply gradients
        loss = self.grad_cache.cache_step(inputs, num_items_in_batch=num_items_in_batch)
        return reduce_loss(loss).detach()

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """Save the regular checkpoint plus a SentenceTransformer-format copy.

        The extra copy is written to
        ``<output_dir>/sentence-transformer-checkpoint`` by rank 0 only.
        """
        super().save_model(output_dir, _internal_call)

        # The argument is optional; fall back to the configured output
        # directory so the path join below never receives None.
        if output_dir is None:
            output_dir = self.args.output_dir

        distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        rank = torch.distributed.get_rank() if distributed else 0

        # Only save model on rank 0 in distributed setup
        if rank == 0:
            print(f"Rank {rank} saving model.")
            save_dir = os.path.join(output_dir, "sentence-transformer-checkpoint")
            os.makedirs(save_dir, exist_ok=True)
            # Save in SentenceTransformer-compatible format
            self.model.save(save_dir, safe_serialization=False)

        # Synchronize all processes to ensure save is complete
        if distributed:
            torch.distributed.barrier()
def str2bool(v):
    """Parse a command-line style boolean.

    Args:
        v: A ``bool`` (returned unchanged) or any value whose string form is
            one of ``yes/true/t/y/1`` or ``no/false/f/n/0``. Matching ignores
            case and surrounding whitespace.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean
            spelling. Non-string input such as ``None`` is rejected with this
            error too, rather than failing with ``AttributeError``.
    """
    if isinstance(v, bool):
        return v
    text = str(v).strip().lower()
    if text in ('yes', 'true', 't', 'y', '1'):
        return True
    if text in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def main():
    """Fine-tune, or only evaluate, a SentenceTransformer on the chosen tasks.

    Reads its configuration from the module-level ``parser``, builds one
    ``InformationRetrievalEvaluator`` per evaluation task and one training
    dataset per training task, then runs ``CustomTrainer``.

    Raises:
        ValueError: If ``--train_task_list`` contains an unknown task name.
    """
    cli_args = parser.parse_args()
    dataset_id = cli_args.dataset_id
    eval_task_list = cli_args.eval_task_list
    train_task_list = cli_args.train_task_list
    max_length = cli_args.max_length
    eval_only = cli_args.eval_only
    model_path = cli_args.model_path
    triplet_margin = cli_args.triplet_margin
    cut = cli_args.cut
    lr = cli_args.lr
    epochs = cli_args.epochs

    model_name = model_path.split('/')[-1]

    model_kwargs = {'torch_dtype': torch.bfloat16}  # TODO: Make sure your GPU and model support bfloat16
    model = SentenceTransformer(model_path, trust_remote_code=True, model_kwargs=model_kwargs)

    model.tokenizer.model_max_length = max_length
    model.max_seq_length = max_length

    evaluator_list = []
    for task in eval_task_list:
        queries = load_dataset(dataset_id, f"{task}-queries", split='test')
        corpus = load_dataset(dataset_id, f"{task}-corpus", split='test')
        qrels = load_dataset(dataset_id, f"{task}-qrels", split='test')
        queries, corpus, relevant_docs = process_eval_dataset(queries, corpus, qrels)
        evaluator_list.append(
            InformationRetrievalEvaluator(
                queries=queries,
                corpus=corpus,
                relevant_docs=relevant_docs,
                name=task,
            )
        )

    train_dataset = {}
    for task in train_task_list:
        if task == 'PCD':
            data = load_dataset(dataset_id, "PCPCD")
            traindata_PCPCD, _ = process_dataset(data, max_length)
            anchor = []
            positive = []
            for d in traindata_PCPCD:
                anchor.append(d['question'])
                positive.append(random.choice(d['solutions']))
            dataset = Dataset.from_dict({
                'anchor': anchor,
                'positive': positive,
            })
        elif task in ('P2Dup', 'S2Full'):
            data = load_dataset(dataset_id, f"{task}-train-pairs", split='train')
            dataset = Dataset.from_dict({
                'anchor': data['query'],
                'positive': data['pos'],
                'negative': data['neg'],
            })
            if cut and task == 'S2Full':
                # Cap the subsample at the dataset size; random.sample raises
                # when asked for more items than are available.
                sample_size = min(1000, len(dataset))
                dataset = dataset.select(random.sample(range(len(dataset)), sample_size))
        else:
            # Without this guard an unknown name would silently reuse the
            # dataset built for the previous task.
            raise ValueError(
                f"Unknown training task {task!r}; expected one of ['PCD', 'P2Dup', 'S2Full']"
            )
        print(f"train task = {task} len = {len(dataset)}")
        train_dataset[task] = dataset

    output_dir = "".join([
        f"./results/{model_name}",
        '_' + '-'.join(train_task_list),
        f"_margin{triplet_margin}",
        f"_lr{lr}",
        f'_maxlength{max_length}',
        "_cut" if cut else "",
        f"_epochs{epochs}",
    ])
    print(f"output_dir : {output_dir}")

    training_args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        learning_rate=lr,
        lr_scheduler_type='constant',
        bf16=True,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        logging_steps=10,
        run_name='CPRet',  # Will be used in W&B if `wandb` is installed
        save_on_each_node=False,
        max_grad_norm=1.0,
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
    )

    losses = {
        'PCD': CachedMultipleNegativesRankingLoss(model, mini_batch_size=2, scale=14.28571429),
        'P2Dup': TripletLoss(model, TripletDistanceMetric.COSINE, triplet_margin=triplet_margin),
        'S2Full': TripletLoss(model, TripletDistanceMetric.COSINE, triplet_margin=triplet_margin),
    }

    # Set the batch size separately for each dataset
    batch_sizes = {
        'PCD': 32,
        'P2Dup': 1,
        'S2Full': 1,
    }

    trainer = CustomTrainer(
        batch_sizes=batch_sizes,
        model=model,
        train_dataset=train_dataset,
        args=training_args,
        loss=losses,
        callbacks=[TensorBoardCallback()],
        evaluator=evaluator_list,
    )
    if eval_only:
        eval_results = trainer.evaluate()
        print(eval_results)
    else:
        trainer.train()
class RandomDataset(torch.utils.data.Dataset):
    """Synthetic dataset of random input pairs sharing one random binary label."""

    def __init__(self, params):
        """Keep ``params``; ``data_size`` and ``input_dim`` are read from it."""
        self.params = params

    def __len__(self):
        return self.params.data_size

    def __getitem__(self, idx):
        """Return ``(inputs, labels)``; ``idx`` is ignored, every item is freshly sampled."""
        # Generate random float inputs with shape [2, input_dim] for contrastive learning
        input_data = torch.randn(2, self.params.input_dim)
        # Draw one label (0 or 1) and replicate it to shape [2]. repeat()
        # stays in tensor land instead of rebuilding a tensor from a Python
        # list of tensors.
        label = torch.randint(0, 2, (1,), dtype=torch.long).repeat(2)
        return input_data, label
class SimpleLitModel(pl.LightningModule):
    """Single linear encoder trained with SupConLoss, optionally through Grad Cache."""

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.loss = SupConLoss(temperature=params.temperature)
        if params.gpus > 1:
            self.loss = pml_dist.DistributedLossWrapper(self.loss)
        self.automatic_optimization = (not self.params.use_gc)  # needed when use_gc is on
        self.fp16 = (self.params.precision == 16)
        self.linear = torch.nn.Linear(params.input_dim, params.embed_dim)  # our simple model
        # Created in init_gc(); defined here so the check in training_step
        # fails as an assertion, not as a missing attribute.
        self.gc = None

    def init_gc(self, scaler, ddp_module):
        """Sets up the required components of GradCache. This method is called after the model is initialized."""
        assert self.params.use_gc
        if self.fp16 and self.params.use_gc:
            # pytorch lightning autocast wraps everything in it
            # it needs to be disabled in gradcache because we do forward twice, and one with no grad
            # then we do autocast manually in gradcache when we need to
            # original post: https://discuss.pytorch.org/t/autocast-and-torch-no-grad-unexpected-behaviour/93475/3
            # pl source code: your_venv_name/lib/python3.8/site-packages/lightning/pytorch/plugins/precision/amp.py::forward_context
            self.trainer.strategy.precision_plugin.forward_context = nullcontext

        print(f"*** initializing gradcache with ddp_module={type(ddp_module)}, minibatch_size={self.params.gc_minibatch_size}")
        self.gc = PLGradCache(
            models=[ddp_module],
            chunk_sizes=self.params.gc_minibatch_size,
            loss_fn=self.calculate_loss,
            fp16=self.fp16,
            scaler=(scaler if self.fp16 else None),  # needed when using automatic_optimization is off and fp16 is on
            backward_fn=self.manual_backward,  # needed when automatic_optimization is off
        )

    def train_dataloader(self):
        """Build the training loader from this module's own ``params``."""
        # Use self.params: the module-level ``params`` name only exists when
        # the file is executed as a script.
        train_dataset = RandomDataset(self.params)
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.params.batch_size,
            num_workers=self.params.num_workers,
            drop_last=True,
        )

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def calculate_loss(self, embeddings, labels):
        # embeddings shape [batch_size, 2, embed_dim]
        # labels shape [batch_size, 2]
        embeddings = embeddings.flatten(0, 1)
        labels = labels.flatten()
        return self.loss(embeddings, labels)

    def forward(self, inputs):  # needed for grad cache
        return self.linear(inputs)

    def on_train_start(self):  # initialize grad cache here
        if self.params.use_gc:
            self.init_gc(self.trainer.scaler, self.trainer.strategy.model)
            # self.init_gc(self.trainer.scaler, self.trainer.lightning_module)  # we can use this if nccl strategy is available

    def training_step(self, batch, batch_idx):
        """Run one optimisation step, manually when Grad Cache is enabled."""
        # inputs shape [batch_size, 2, input_dim]
        # labels shape [batch_size, 2]
        inputs, labels = batch
        if self.params.use_gc:
            assert self.gc is not None
            optimizer = self.optimizers()
            optimizer.zero_grad()
            loss = self.gc(
                inputs,
                no_sync_except_last=(self.params.gpus > 1),
                labels=labels.flatten(),
            )
            loss /= max(1, self.params.gpus)  # needed when automatic_optimization is off
            log_loss = loss
            optimizer.step()
        else:
            outputs = self.linear(inputs)
            loss = self.calculate_loss(outputs, labels)
            log_loss = loss / max(1, self.params.gpus)
        self.log(
            "train_loss",
            log_loss,
            on_step=True,
            on_epoch=True,
            sync_dist=self.params.use_gc,  # needed when automatic_optimization is off
        )
        print(f"batch_idx={batch_idx}, loss={loss}")
        return loss
def get_argument_parser():
    """Build the command-line parser for the Grad Cache Lightning example.

    Returns:
        An ``argparse.ArgumentParser`` with runtime, training, model and
        Grad Cache options, registered in the order listed below.
    """
    option_specs = [
        # runtime
        ("--random_seed", {"type": int, "default": 42}),
        ("--num_workers", {"type": int, "default": 4}),
        ("--gpus", {"type": int, "default": 0}),
        ("--precision", {"type": int, "default": 32}),
        ("--ddp_backend", {
            "type": str,
            "default": "nccl",
            "help": "torch distributed backend (Default: nccl), use 'gloo' if nccl doesn't work",
        }),
        ("--project_name", {"type": str, "default": "debug_gradcache"}),
        # training params
        ("--data_size", {"type": int, "default": 100}),
        ("--epochs", {"type": int, "default": 10}),
        ("--batch_size", {"type": int, "default": 16}),
        ("--temperature", {"type": float, "default": 0.1}),
        # model hyperparams
        ("--input_dim", {"type": int, "default": 784}),
        ("--embed_dim", {"type": int, "default": 512}),
        # grad cache params
        ("--use_gc", {
            "action": "store_true",
            "default": False,
            "help": "whether to use grad cache",
        }),
        ("--gc_minibatch_size", {
            "type": int,
            "default": 2,
            "help": "mini batch size of grad cache, must be provided if use_gc is on",
        }),
    ]

    arg_parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        arg_parser.add_argument(flag, **options)
    return arg_parser
def main(params):
    """Train ``SimpleLitModel`` with the settings in ``params``.

    Seeds the run, builds a W&B logger named after the GPU count, precision
    and Grad Cache usage, and fits the model with a Lightning ``Trainer``.
    """
    # set random seeds reproduceability
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # set different random seeds for each worker
    pl.seed_everything(seed=params.random_seed, workers=True)

    # weirdness with HuggingFace tokenizer when processing things in parallel
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch.multiprocessing.set_sharing_strategy("file_system")

    # load model
    model = SimpleLitModel(params)

    # load trainer
    name_parts = [f"gpus-{params.gpus}_precision-{params.precision}"]
    if params.use_gc:
        name_parts.append("gc")
    name_parts.append("pl")
    experiment_id = "_".join(name_parts)

    wandb_logger = WandbLogger(
        project=params.project_name,
        name=experiment_id,
    )
    ddp = DDPStrategy(process_group_backend=params.ddp_backend)

    uses_gpu = params.gpus > 0
    multi_gpu = params.gpus > 1
    trainer = pl.Trainer(
        accelerator="gpu" if uses_gpu else "cpu",
        strategy=ddp if multi_gpu else "auto",
        devices=params.gpus if uses_gpu else "auto",
        precision=params.precision,
        logger=wandb_logger,
        max_epochs=params.epochs,
        log_every_n_steps=1,
    )
    trainer.fit(model)
class PLGradCache(GradCache):
    """
    Gradient Cache class with PyTorch Lightning Support.
    Implements input chunking, first graph-less forward pass, Gradient Cache creation, second forward & backward gradient computation.
    Optimizer step is not included. Native torch automatic mixed precision is supported.
    Gradient unscaling and scaler update are handled internally.
    """

    def __init__(
        self,
        models: List[nn.Module],
        chunk_sizes: Union[int, List[int]],
        loss_fn: Callable[..., Tensor],
        split_input_fn: Callable[[Any, int], Any] = None,
        get_rep_fn: Callable[..., Tensor] = None,
        fp16: bool = False,
        scaler: GradScaler = None,
        backward_fn=None,  # [added]
    ):
        """
        Initialize the Gradient Cache class instance.
        :param models: A list of all encoder models to be updated by the current cache.
        :param chunk_sizes: An integer indicating chunk size. Or a list of integers of chunk size for each model.
        :param loss_fn: A loss function that takes arbitrary numbers of representation tensors and
        arbitrary numbers of keyword arguments as input. It should not in any case modify the input tensors' relations
        in the autograd graph, which are later relied upon to create the gradient cache.
        :param split_input_fn: An optional function that split generic model input into chunks. If not provided, this
        class will try its best to split the inputs of supported types. See `split_inputs` function.
        :param get_rep_fn: An optional function that takes generic model output and return representation tensors. If
        not provided, the generic output is assumed to be the representation tensor.
        :param fp16: If True, run mixed precision training, which requires scaler to also be set.
        :param scaler: A GradScaler object for automatic mixed precision training.
        :[added] param backward_fn: The `manual_backward` function of pytorch lightning trainer when automatic_optimization is disabled.
        """
        super().__init__(models, chunk_sizes, loss_fn, split_input_fn, get_rep_fn, fp16, scaler)
        self.backward_fn = backward_fn

    def build_cache(self, *reps: Tensor, **loss_kwargs) -> Tuple[List[Tensor], Tensor]:
        """
        Compute the gradient cache
        :param reps: Computed representations from all encoder models
        :param loss_kwargs: Extra keyword arguments to the loss function
        :return: A tuple of a) gradient cache for each encoder model, and b) loss tensor
        """
        # Detach so the backward pass stops at the representations; their
        # .grad becomes the cache.
        reps = [r.detach().requires_grad_() for r in reps]
        with autocast() if self.fp16 else nullcontext():
            loss = self.compute_loss(*reps, **loss_kwargs)

        self.backward_fn(loss)  # [modified]

        cache = [r.grad for r in reps]

        return cache, loss.detach()

    def _backward_surrogate(self, surrogate: Tensor) -> None:
        """Back-propagate one chunk's surrogate loss through ``backward_fn``."""
        if not self.fp16:
            self.backward_fn(surrogate)
            return

        # The scaler is switched off around this backward call (presumably
        # because the cached gradients already stem from a scaled loss --
        # confirm). Restore whatever state it had, also when backward raises,
        # instead of unconditionally enabling it afterwards.
        previous = self.scaler._enabled
        self.scaler._enabled = False
        try:
            self.backward_fn(surrogate)
        finally:
            self.scaler._enabled = previous

    def forward_backward(
        self,
        model: nn.Module,
        model_inputs,
        cached_gradients: List[Tensor],
        random_states: List[RandContext],
        no_sync_except_last: bool = False,
    ):
        """
        Run the second forward and the backward pass to compute gradient for a model.
        :param model: Encoder model.
        :param model_inputs: Chunked input to the encoder model.
        :param cached_gradients: Chunked gradient cache tensor for each input.
        :param random_states: Each input's device random state during the first forward.
        :param no_sync_except_last: If True, under distributed setup, only trigger gradient reduction across processes
        for the last sub-batch's forward-backward pass.
        """
        if isinstance(model, nn.parallel.DistributedDataParallel):  # [use ddp_model]
            n_chunks = len(model_inputs)
            if no_sync_except_last:
                sync_contexts = [model.no_sync] * (n_chunks - 1) + [nullcontext]
                sync_flags = [True] * n_chunks  # [added]
            else:
                sync_contexts = [nullcontext] * n_chunks
                sync_flags = [False] * n_chunks  # [added]

            # [modified]
            for x, state, gradient, sync_context, sync_flag in zip(
                model_inputs, random_states, cached_gradients, sync_contexts, sync_flags
            ):
                with sync_context():
                    with state:
                        y = self.model_call(model, x)
                    reps = self.get_reps(y)
                    # d(surrogate)/d(params) equals the cached gradient
                    # chained through this chunk's forward pass.
                    surrogate = torch.dot(reps.flatten(), gradient.flatten())
                    if sync_flag:
                        # NOTE(review): this re-enables grad sync for every
                        # chunk, including those inside no_sync() -- confirm
                        # this is intended.
                        model.require_backward_grad_sync = True
                    self._backward_surrogate(surrogate)
        else:  # [use base model (i.e. SimpleLitModel)]
            # [remove no_sync_except_last: pytorch lightning would handle gradient sync automatically]
            for x, state, gradient in zip(model_inputs, random_states, cached_gradients):
                with state:
                    y = self.model_call(model, x)
                reps = self.get_reps(y)
                surrogate = torch.dot(reps.flatten(), gradient.flatten())
                self._backward_surrogate(surrogate)

    def cache_step(
        self, *model_inputs, no_sync_except_last: bool = False, **loss_kwargs
    ) -> Tensor:
        """
        Run a cached step to compute gradient over the inputs.
        :param model_inputs: Input to each encoder model. Should be in similar order as the class's model.
        :param no_sync_except_last: If True, under distributed setup, for each model, only trigger gradient reduction
        across processes for the last sub-batch's forward-backward pass.
        :param loss_kwargs: Additional keyword arguments to the loss function.
        :return: The detached loss of the current step.
        """
        all_reps = []
        all_rnd_states = []

        # The DistributedDataParallel check for no_sync_except_last is done
        # per model in forward_backward().

        model_inputs = [
            self.split_inputs(x, chunk_size)
            for x, chunk_size in zip(model_inputs, self.chunk_sizes)
        ]

        # First pass: graph-less forward to collect representations and the
        # random state of each chunk.
        for model, x in zip(self.models, model_inputs):
            model_reps, rnd_states = self.forward_no_grad(model, x)
            all_reps.append(model_reps)
            all_rnd_states.append(rnd_states)

        # all_reps: len(self.models) x [batch_size, 2, embed_dim]
        # cache: len(self.models) x gc_minibatch x [(batch_size / gc_minibatch, 2, embed_dim]

        cache, loss = self.build_cache(*all_reps, **loss_kwargs)
        cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)]

        # Second pass: forward with graph, chunk by chunk, back-propagating
        # the cached representation gradients.
        for model, x, model_cache, rnd_states in zip(
            self.models, model_inputs, cache, all_rnd_states
        ):
            self.forward_backward(
                model,
                x,
                model_cache,
                rnd_states,
                no_sync_except_last=no_sync_except_last,
            )

        return loss

{{ t.site_name }}

4 | 5 | 8 | 9 | 12 | 13 | 16 | 17 |
18 | 19 |
20 | 22 | 25 | 29 | 🧪 {{ t.example_report }} 30 | 31 |
32 |
33 |
34 | 35 |
36 | 37 | 38 |
39 |
40 | 41 | {% for oj_name in available_ojs %} 42 |
43 | 46 | 47 |
48 | {% endfor %} 49 |
50 | 51 | 52 | 53 | {{ t.view_stats }} 54 |
55 | 56 | {% if results %} 57 |
58 |

{{ t.summary.format(total=total, page=page, max_page=max_page, elapsed=elapsed) | safe }} 60 |

61 | 62 | 84 | 85 | 119 | 120 | {% endif %} 121 | 122 | 144 | 145 |
146 | 147 | 148 | 150 | 163 | 164 | 165 | GitHub: coldchair/CPRet 166 | 167 | 168 | 169 | 170 |
"""Flask front end for CPRet: competitive-programming problem retrieval.

At start-up the module loads a SentenceTransformer model, the pre-computed
problem embeddings and the problem metadata; it then serves a search page,
a problem detail page and a daily search statistics page.
"""
import hashlib
import json
import math
import os
import time
from datetime import datetime
from functools import lru_cache

# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import numpy as np
import torch
from flask import Flask, request, render_template, redirect, url_for
from markupsafe import Markup, escape
from sentence_transformers import SentenceTransformer

# ===== Multilingual dictionary =====
I18N = {
    "zh": {
        "site_name" : "CPRet:编程竞赛题目检索",
        "new_domain_info": "我们的最新域名是 cpret.online,我们的 GitHub 仓库是 CPRet,欢迎收藏或 star!",
        "paper_info": "📰 2025 年 9 月更新:🎉 恭喜!我们的项目论文 CPRet 被 NeurIPS 2025 D&B track 接收!",
        "info": "📢 2025 年 7 月更新:我们已升级模型并同步更新了题目数据库,检索效果更佳!",
        "info2": "📢 2025 年 10 月更新:新增了部分 OJ 的题目,并进一步优化了模型效果!",
        "placeholder": "输入题目描述或简略题意(超过 2048 token 的查询将被截断)…",
        "template_btn": "填入示例查询",
        "search_btn": "搜索",
        "summary" : "共 {total} 条结果,页 {page}/{max_page},耗时 {elapsed:.1f} ms",
        "prev" : "上一页",
        "next" : "下一页",
        "untitled" : "未命名",
        "view_origin": "原站链接",
        "back": "返回搜索",
        "view_stats": "📊 每日搜索统计",
        "date": "日期",
        "search_count": "搜索次数",
        "total_search_count": "总搜索次数",
        "example_report": "使用示例(实测报告)",
        "filter_by_oj": "筛选 OJ",
        "select_all": "全选",
        "deselect_all": "全不选",
        "moving_average": "渐近平均",
    },
    "en": {
        "site_name" : "CPRet: Competitive Programming Problem Retrieval",
        "new_domain_info": "Our new domain is cpret.online. Our GitHub repo is CPRet. Please bookmark or star it!",
        "paper_info": "📰 September 2025 Update: 🎉 Congrats! Our project paper CPRet has been accepted by the NeurIPS 2025 D&B track!",
        "info": "📢 July 2025 Update: We've upgraded our model and synchronized the problem database for better retrieval!",
        "info2": "📢 October 2025 Update: Added new problems from several OJs and further optimized the model performance!",
        "placeholder": "Enter problem description or simplified statement (queries longer than 2048 tokens will be truncated)…",
        "template_btn": "Insert example query",
        "search_btn": "Search",
        "summary" : "{total} results, page {page}/{max_page}, {elapsed:.1f} ms",
        "prev" : "Prev",
        "next" : "Next",
        "untitled" : "Untitled",
        "view_origin": "Original Link",
        "back": "Back to Search",
        "view_stats": "📊 Daily Search Stats",
        "date": "Date",
        "search_count": "Search Count",
        "total_search_count": "Total Search Count",
        "example_report": "Test Cases (Demo Report)",
        "filter_by_oj": "Filter by OJ",
        "select_all": "Select All",
        "deselect_all": "Deselect All",
        "moving_average": "Moving Average",
    },
}


def detect_lang():
    """Language priority: ?lang= -> Accept-Language -> zh"""
    qlang = request.args.get("lang")
    if qlang in ("zh", "en"):
        return qlang
    header = request.headers.get("Accept-Language", "")
    return "en" if header.lower().startswith("en") else "zh"


def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean switch from the environment.

    Environment values are always strings, so they are compared textually:
    an unset variable yields *default*; "", "0", "false", "no" and "off"
    (case-insensitive) yield False; anything else yields True.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"", "0", "false", "no", "off"}


# ---------------- Configuration ---------------- #
SEARCH_STATS_PATH = "search_stats.json"
MODEL_PATH = os.getenv(
    "MODEL_PATH",
    "coldchair16/CPRetriever-Prob-Qwen3-4B-2510"
)
EMB_PATH = os.getenv(
    'EMB_PATH',
    './probs_2511_embs.npy'
)
PROB_PATH = os.getenv(
    'PROB_PATH',
    './probs_2511.jsonl'
)
# bf16 is on by default; `export BF_16=0` turns it off. Parsed as a bool so that
# `export BF_16=1` also means "on" (the raw string "1" never equals the int 1).
BF_16 = _env_flag("BF_16", True)

PAGE_SIZE = 20  # Number of results per page
# ------------------------------------- #

app = Flask(__name__)

# ---------- Load model & data on startup ---------- #
print("Loading SentenceTransformer model …")
if BF_16:
    model = SentenceTransformer(MODEL_PATH, trust_remote_code=True, model_kwargs={"torch_dtype": torch.bfloat16})
else:
    model = SentenceTransformer(MODEL_PATH, trust_remote_code=True)
model.tokenizer.model_max_length = 2048
model.max_seq_length = 2048

print("Loading pre‑computed embeddings …")
embs = np.load(EMB_PATH).astype("float32")
# L2-normalise every row so that a dot product with a unit query vector is the cosine similarity.
embs /= np.linalg.norm(embs, axis=1, keepdims=True)

print("Loading problem metadata …")
# One JSON object per line; the file is closed as soon as it has been read.
with open(PROB_PATH, "r", encoding="utf-8") as prob_file:
    probs = [json.loads(line) for line in prob_file if line.strip()]

# Collect every OJ source that occurs in the data; used by the front-end filter.
available_ojs = sorted({p.get("source") for p in probs if p.get("source")})

# Explicit check instead of `assert`, which is stripped when Python runs with -O.
if len(probs) != embs.shape[0]:
    raise ValueError("Mismatch between vector and problem count!")
print(f"Ready! {len(probs)} problems indexed.\n")


def _hash(text: str) -> str:
    """Simple hash to shorten long queries as dict keys"""
    return hashlib.md5(text.encode("utf-8")).hexdigest()


@lru_cache(maxsize=1024)  # Cache up to 1024 different queries
def search_once(q: str):
    """Return ranked indices and similarity list (numpy array -> Python list)

    The first list holds problem indices ordered by descending similarity;
    the second holds the similarity of every problem, indexed by problem id.
    """
    # -> float32
    q_emb = model.encode(q, convert_to_tensor=True).to(torch.float32).cpu().numpy()
    q_emb = q_emb / np.linalg.norm(q_emb)
    sims = embs.dot(q_emb)
    idx = sims.argsort()[::-1]
    return idx.tolist(), sims.tolist()


def load_search_stats():
    """Load daily search statistics from file."""
    if not os.path.exists(SEARCH_STATS_PATH):
        return {}
    with open(SEARCH_STATS_PATH, "r", encoding="utf-8") as f:
        return json.load(f)


def save_search_stats(stats: dict):
    """Save search statistics to file."""
    with open(SEARCH_STATS_PATH, "w", encoding="utf-8") as f:
        json.dump(stats, f, ensure_ascii=False, indent=2)


def record_search():
    """Update today's search count and save to file."""
    stats = load_search_stats()
    today = datetime.now().strftime("%Y-%m-%d")
    stats[today] = stats.get(today, 0) + 1
    save_search_stats(stats)


@app.route("/", methods=["GET"])
def index():
    """Render the search page and, when a query is given, one page of results."""
    lang = detect_lang()
    t = I18N[lang]

    q = request.args.get("q", "").strip()
    # `type=int` falls back to the default on non-numeric input instead of raising (HTTP 500).
    page = max(request.args.get("page", default=1, type=int), 1)

    if "oj" in request.args:
        selected_ojs = request.args.getlist("oj")
    else:
        selected_ojs = available_ojs

    results, total, elapsed = [], 0, 0.0

    if q:
        record_search()
        tic = time.perf_counter()
        idx, sims = search_once(q)
        elapsed = (time.perf_counter() - tic) * 1_000

        # Keep only problems whose source is among the selected OJs.
        # A set makes each membership test O(1) across the whole ranking.
        wanted = set(selected_ojs)
        filtered_idx = [j for j in idx if probs[j].get("source") in wanted]

        total = len(filtered_idx)

        start, end = (page - 1) * PAGE_SIZE, page * PAGE_SIZE
        for rank, j in enumerate(filtered_idx[start:end], start=start + 1):
            p = probs[j]
            results.append({
                "rank"  : rank,
                "pid"   : j,
                "score" : float(sims[j]),
                "title" : p.get("title") or t["untitled"],
                "url"   : p.get("url", "#"),
                "source": p.get("source", ""),
            })

    return render_template(
        "index.html",
        lang=lang,
        t=t,
        query=q,
        results=results,
        page=page,
        page_size=PAGE_SIZE,
        total=total,
        max_page=max(1, math.ceil(total / PAGE_SIZE)),
        elapsed=elapsed,
        available_ojs=available_ojs,
        selected_ojs=selected_ojs,
    )


# The URL variable is required: without it Flask cannot pass `pid` to the view.
@app.route("/p/<int:pid>")
def problem(pid: int):
    """Render the detail page of problem number *pid* (index into `probs`)."""
    lang = detect_lang()
    t = I18N[lang]

    if pid < 0 or pid >= len(probs):
        return f"Problem #{pid} not found", 404

    p = probs[pid]

    # Problem statements are external data: escape them before marking the result as safe HTML,
    # so that "<" in a statement is shown literally and cannot inject markup. Only the line
    # breaks inserted here are real tags.
    raw = p.get("text") or "(No text)"
    text_html = Markup("<br>").join(escape(line) for line in raw.split("\n"))

    return render_template(
        "problem.html",
        lang=lang,
        t=t,
        pid=pid,
        title=p.get("title") or t["untitled"],
        source=p.get("source", ""),
        url=p.get("url", "#"),
        text_html=text_html,
        query=request.args.get("q", ""),  # Pass original query to return button
        page=request.args.get("page", "1"),
        selected_ojs_str=request.args.get("oj", "")  # Pass selected_ojs_str for back button
    )


@app.route("/stats")
def stats():
    """Render the daily search statistics page."""
    lang = detect_lang()
    t = I18N[lang]

    daily = load_search_stats()
    # Table data: newest day first.
    stats_data = sorted(daily.items(), key=lambda kv: kv[0], reverse=True)

    # Chart data: oldest day first, without today's (still incomplete) count.
    # Filtering by date keeps yesterday's entry on days with no search yet.
    today = datetime.now().strftime("%Y-%m-%d")
    stats_draw = [item for item in reversed(stats_data) if item[0] != today]

    return render_template(
        "stats.html",
        lang=lang,
        t=t,
        stats=stats_data,
        stats_draw=stats_draw,
    )


# -------------- Local run entry -------------- #
if __name__ == "__main__":
    # export FLASK_ENV=development to enable auto-reload/debug
    app.run(host="0.0.0.0", port=5000, debug=False)
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. 17 | """ 18 | 19 | import contextlib 20 | import copy 21 | import functools 22 | import glob 23 | import importlib.metadata 24 | import inspect 25 | import json 26 | import math 27 | import os 28 | import random 29 | import re 30 | import shutil 31 | import sys 32 | import tempfile 33 | import time 34 | import warnings 35 | from collections.abc import Mapping 36 | from pathlib import Path 37 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union 38 | 39 | 40 | # Integrations must be imported before ML frameworks: 41 | # isort: off 42 | from transformers.integrations import ( 43 | get_reporting_integration_callbacks, 44 | ) 45 | 46 | # isort: on 47 | 48 | import huggingface_hub.utils as hf_hub_utils 49 | import numpy as np 50 | import torch 51 | import torch.distributed as dist 52 | from huggingface_hub import ModelCard, create_repo, upload_folder 53 | from packaging import version 54 | from torch import nn 55 | from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler 56 | 57 | from transformers import __version__ 58 | from transformers.configuration_utils import PretrainedConfig 59 | from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator 60 | from transformers.debug_utils import DebugOption, DebugUnderflowOverflow 61 | from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor 62 | from transformers.feature_extraction_utils import FeatureExtractionMixin 63 | from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend 64 | from transformers.image_processing_utils import BaseImageProcessor 65 | from transformers.integrations.deepspeed import 
deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available 66 | from transformers.integrations.tpu import tpu_spmd_dataloader 67 | from transformers.modelcard import TrainingSummary 68 | from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model 69 | from transformers.models.auto.modeling_auto import ( 70 | MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, 71 | MODEL_MAPPING_NAMES, 72 | ) 73 | from transformers.optimization import Adafactor, get_scheduler 74 | from transformers.processing_utils import ProcessorMixin 75 | from transformers.pytorch_utils import ( 76 | ALL_LAYERNORM_LAYERS, 77 | is_torch_greater_or_equal_than_2_3, 78 | ) 79 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 80 | from transformers.trainer_callback import ( 81 | CallbackHandler, 82 | DefaultFlowCallback, 83 | ExportableState, 84 | PrinterCallback, 85 | ProgressCallback, 86 | TrainerCallback, 87 | TrainerControl, 88 | TrainerState, 89 | ) 90 | from transformers.trainer_pt_utils import ( 91 | DistributedTensorGatherer, 92 | EvalLoopContainer, 93 | IterableDatasetShard, 94 | LabelSmoother, 95 | LayerWiseDummyOptimizer, 96 | LengthGroupedSampler, 97 | SequentialDistributedSampler, 98 | distributed_broadcast_scalars, 99 | distributed_concat, 100 | find_batch_size, 101 | get_model_param_count, 102 | get_module_class_from_name, 103 | get_parameter_names, 104 | nested_concat, 105 | nested_detach, 106 | nested_numpify, 107 | nested_xla_mesh_reduce, 108 | reissue_pt_warnings, 109 | remove_dummy_checkpoint, 110 | set_rng_state_for_device, 111 | ) 112 | from transformers.trainer_utils import ( 113 | PREFIX_CHECKPOINT_DIR, 114 | BestRun, 115 | EvalLoopOutput, 116 | EvalPrediction, 117 | HPSearchBackend, 118 | HubStrategy, 119 | PredictionOutput, 120 | RemoveColumnsCollator, 121 | SaveStrategy, 122 | TrainerMemoryTracker, 123 | TrainOutput, 124 | check_target_module_exists, 125 | default_compute_objective, 126 | denumpify_detensorize, 127 | 
enable_full_determinism, 128 | find_executable_batch_size, 129 | get_last_checkpoint, 130 | has_length, 131 | neftune_post_forward_hook, 132 | number_of_arguments, 133 | seed_worker, 134 | set_seed, 135 | speed_metrics, 136 | ) 137 | from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments 138 | from transformers.utils import ( 139 | ADAPTER_CONFIG_NAME, 140 | ADAPTER_SAFE_WEIGHTS_NAME, 141 | ADAPTER_WEIGHTS_NAME, 142 | CONFIG_NAME, 143 | SAFE_WEIGHTS_INDEX_NAME, 144 | SAFE_WEIGHTS_NAME, 145 | WEIGHTS_INDEX_NAME, 146 | WEIGHTS_NAME, 147 | XLA_FSDPV2_MIN_VERSION, 148 | PushInProgress, 149 | PushToHubMixin, 150 | can_return_loss, 151 | find_labels, 152 | is_accelerate_available, 153 | is_apex_available, 154 | is_apollo_torch_available, 155 | is_bitsandbytes_available, 156 | is_datasets_available, 157 | is_galore_torch_available, 158 | is_grokadamw_available, 159 | is_in_notebook, 160 | is_ipex_available, 161 | is_liger_kernel_available, 162 | is_lomo_available, 163 | is_peft_available, 164 | is_safetensors_available, 165 | is_sagemaker_dp_enabled, 166 | is_sagemaker_mp_enabled, 167 | is_schedulefree_available, 168 | is_torch_compile_available, 169 | is_torch_mlu_available, 170 | is_torch_mps_available, 171 | is_torch_musa_available, 172 | is_torch_neuroncore_available, 173 | is_torch_npu_available, 174 | is_torch_xla_available, 175 | is_torch_xpu_available, 176 | is_torchao_available, 177 | logging, 178 | strtobool, 179 | ) 180 | from transformers.utils.deprecation import deprecate_kwarg 181 | from transformers.utils.quantization_config import QuantizationMethod 182 | 183 | 184 | DEFAULT_CALLBACKS = [DefaultFlowCallback] 185 | DEFAULT_PROGRESS_CALLBACK = ProgressCallback 186 | 187 | if is_in_notebook(): 188 | from transformers.utils.notebook import NotebookProgressCallback 189 | 190 | DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback 191 | 192 | if is_apex_available(): 193 | from apex import amp 194 | 195 | if is_datasets_available(): 
196 | import datasets 197 | 198 | if is_torch_xla_available(): 199 | import torch_xla.core.xla_model as xm 200 | import torch_xla.debug.metrics as met 201 | from torch_xla import __version__ as XLA_VERSION 202 | 203 | IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION) 204 | if IS_XLA_FSDPV2_POST_2_2: 205 | import torch_xla.distributed.spmd as xs 206 | import torch_xla.runtime as xr 207 | else: 208 | IS_XLA_FSDPV2_POST_2_2 = False 209 | 210 | 211 | if is_sagemaker_mp_enabled(): 212 | import smdistributed.modelparallel.torch as smp 213 | from smdistributed.modelparallel import __version__ as SMP_VERSION 214 | 215 | IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") 216 | 217 | from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat 218 | else: 219 | IS_SAGEMAKER_MP_POST_1_10 = False 220 | 221 | 222 | if is_safetensors_available(): 223 | import safetensors.torch 224 | 225 | if is_peft_available(): 226 | from peft import PeftModel 227 | 228 | 229 | if is_accelerate_available(): 230 | from accelerate import Accelerator, skip_first_batches 231 | from accelerate import __version__ as accelerate_version 232 | from accelerate.state import AcceleratorState 233 | from accelerate.utils import ( 234 | AutocastKwargs, 235 | DistributedDataParallelKwargs, 236 | DistributedType, 237 | load_fsdp_model, 238 | load_fsdp_optimizer, 239 | save_fsdp_model, 240 | save_fsdp_optimizer, 241 | ) 242 | 243 | DATA_SAMPLERS = [RandomSampler] 244 | if version.parse(accelerate_version) > version.parse("0.23.0"): 245 | from accelerate.data_loader import SeedableRandomSampler 246 | 247 | DATA_SAMPLERS += [SeedableRandomSampler] 248 | 249 | if is_deepspeed_available(): 250 | from accelerate.utils import DeepSpeedSchedulerWrapper 251 | 252 | if is_accelerate_available("0.28.0"): 253 | from accelerate.utils import DataLoaderConfiguration 254 | 255 | 256 | def 
_is_peft_model(model): 257 | if is_peft_available(): 258 | classes_to_check = (PeftModel,) if is_peft_available() else () 259 | # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321 260 | if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"): 261 | from peft import PeftMixedModel 262 | 263 | classes_to_check = (*classes_to_check, PeftMixedModel) 264 | return isinstance(model, classes_to_check) 265 | return False 266 | 267 | 268 | def _get_fsdp_ckpt_kwargs(): 269 | # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release 270 | if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters): 271 | return {"adapter_only": True} 272 | else: 273 | return {} 274 | 275 | 276 | def safe_globals(): 277 | # Starting from version 2.4 PyTorch introduces a check for the objects loaded 278 | # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes 279 | # a default and requires allowlisting of objects being loaded. 
280 | # See: https://github.com/pytorch/pytorch/pull/137602 281 | # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals 282 | # See: https://github.com/huggingface/accelerate/pull/3036 283 | if version.parse(torch.__version__).release < version.parse("2.6").release: 284 | return contextlib.nullcontext() 285 | 286 | np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core 287 | allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype] 288 | # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for 289 | # all versions of numpy 290 | allowlist += [type(np.dtype(np.uint32))] 291 | 292 | return torch.serialization.safe_globals(allowlist) 293 | 294 | 295 | if TYPE_CHECKING: 296 | import optuna 297 | 298 | if is_datasets_available(): 299 | import datasets 300 | 301 | logger = logging.get_logger(__name__) 302 | 303 | 304 | # Name of the files used for checkpointing 305 | TRAINING_ARGS_NAME = "training_args.bin" 306 | TRAINER_STATE_NAME = "trainer_state.json" 307 | OPTIMIZER_NAME = "optimizer.pt" 308 | SCALER_NAME = "scaler.pt" 309 | OPTIMIZER_NAME_BIN = "optimizer.bin" 310 | SCHEDULER_NAME = "scheduler.pt" 311 | FSDP_MODEL_NAME = "pytorch_model_fsdp" 312 | 313 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CPRet: A Dataset, Benchmark, and Model for Retrieval in Competitive Programming 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2505.12925-b31b1b.svg)](https://arxiv.org/abs/2505.12925) 4 | [![🤗 HF Collection](https://img.shields.io/badge/HuggingFace-CPRet-yellow)](https://huggingface.co/collections/coldchair16/cpret-682451276f05c5988fcbdf34) 5 | 6 | Email contact: 2317757009@qq.com 7 | 8 | ## 🌐 Try Online Demo 9 | 10 | We provide an **online demo** of the CPRet retrieval service, available at: 11 | 12 | 👉 
[https://www.cpret.online/](https://www.cpret.online/) 13 | 14 | This demo can assist in **duplicate problem detection** by retrieving potentially similar problems, though final identification still requires manual verification. 15 | 16 | It also supports **similar problem retrieval** to help broaden your problem-solving perspective. 17 | 18 | You can input either a **full problem description** or a **simplified version**, and the system will return the most relevant existing problems. 19 | 20 | You can refer to the usage examples of the retrieval platform at: [https://github.com/coldchair/CPRet/blob/main/TestCases.md](https://github.com/coldchair/CPRet/blob/main/TestCases.md) 21 | 22 | It runs the same codebase and embedding model as the local deployment (see below), so you can preview its capabilities before setting up your own instance. 23 | 24 | ## 🚀 News 25 | 26 | **Oct 2025: CPRetriever-Prob-Qwen3-4B-2510 Released with Enhanced Retrieval Performance!** 27 | 28 | We're excited to announce a major update to the CPRetriever model series! We've trained the new [**CPRetriever-Prob-Qwen3-4B-2510**](https://huggingface.co/coldchair16/CPRetriever-Prob-Qwen3-4B-2510) model based on [**Qwen3-Embedding-4B**](https://huggingface.co/Qwen/Qwen3-Embedding-4B), released in June 2025, and it has achieved **state-of-the-art results** in problem-related retrieval tasks. Concurrently, we've also updated our website's retrieval problem database to the latest Oct 2025 version. 
29 | 30 | Here's a comparison of model performance: 31 | 32 | | model | type | size | Text-to-Code | Code-to-Code | Problem-to-Duplicate | Simplified-to-Full | Avg | 33 | | :------------------------ | :--- | :--- | :----------- | :----------- | :------------------- | :----------------- | :----- | 34 | | CPRetriever-Code | code | 2B | 70.40 | 70.59 | 38.68 | 81.45 | 65.28 | 35 | | CPRetriever-Prob | code | 2B | 56.50 | 70.68 | 60.06 | 90.74 | 69.50 | 36 | | CPRetriever-Prob-Qwen3-4B | code | 4B | 65.85 | 70.19 | 71.45 | 95.03 | 75.63 | 37 | | CPRetriever-Prob-Qwen3-4B-2510 | code | 4B | 80.84 | 87.10 | 74.33 | 96.15 | 84.61 | 38 | 39 | The CPRetriever-Prob-Qwen3-4B-2510 model follows the same training procedure and dataset as CPRetriever-Prob-Qwen3-4B, but was retrained in October 2025 with adjusted data proportions, an extended maximum sequence length of 2048, and optimized hyperparameters for improved performance. 40 | 41 | 42 | **Sept 2025:** 🎉 We’re excited to announce that our paper has been accepted to the **NeurIPS 2025 D&B Track**! 43 | 44 | ## 📌 Overview 45 | 46 | **CPRet** is a comprehensive suite for competitive programming retrieval research, consisting of: 47 | 48 | * A large-scale dataset and benchmark for retrieval tasks in coding contests. 49 | * A dual-stage training pipeline with contrastive pretraining and task-specific fine-tuning. 50 | * A local retrieval server for **simplified description** and **duplicate problem** search, powered by our trained model [**CPRetriever-Prob-Qwen3-4B-2510**](https://huggingface.co/coldchair16/CPRetriever-Prob-Qwen3-4B-2510). 51 | 52 | We define the following **four core retrieval tasks** to support both practical applications and academic benchmarking: 53 | 54 | 1. **Text-to-Code (T2C):** Retrieve relevant code given a natural language problem description. 55 | 2. **Code-to-Code (C2C):** Retrieve other implementations of the same problem based on a given solution. 56 | 3. 
**Problem-to-Duplicate (P2D):** Detect duplicate or near-duplicate problems from existing contest archives. 57 | 4. **Simplified-to-Full (S2F):** Retrieve the original full version of a simplified problem. 58 | 59 | 60 | ## 🧰 Repository Contents 61 | 62 | * `cp-retrieval-server/`: Code for running a local retrieval web service. 63 | * `stage1/`: Code for stage-1 contrastive pretraining. 64 | * `stage2/`: Code for stage-2 problem-level fine-tuning. 65 | 66 | --- 67 | 68 | ## ⚙️ Setup 69 | 70 | ### Environment 71 | 72 | * Recommended: `python >= 3.10` 73 | 74 | * Install dependencies: 75 | 76 | ```bash 77 | pip install -r requirements.txt 78 | ``` 79 | 80 | * Install PyTorch (with CUDA support if needed): 81 | → Refer to: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/) 82 | 83 | * PyTorch ≥ 2.0 is recommended. 84 | 85 | ### 🔁 Accessing Hugging Face from Restricted Regions 86 | 87 | If you're experiencing connectivity issues with Hugging Face, consider using the official mirror: 88 | 89 | ```python 90 | import os 91 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' 92 | ``` 93 | 94 | Or set it as an environment variable: 95 | 96 | ```bash 97 | export HF_ENDPOINT=https://hf-mirror.com 98 | ``` 99 | 100 | --- 101 | 102 | ## 🚀 Run Local Retrieval Service 103 | 104 | 105 | 1. **Download embeddings:** 106 | 107 | * You run `cp-retrieval-server/download.py` to download the problems and embeddings. 
108 | 109 | * **If you are using the new model, `CPRetriever-Prob-Qwen3-4B-2510`:** 110 | * Please download the following files from [HF dataset CPRet-Embeddings](https://huggingface.co/datasets/coldchair16/CPRet-Embeddings) into the `cp-retrieval-server/` directory: 111 | * `probs_2511.jsonl` 112 | * `probs_2511_embs.npy` 113 | 114 | * **If you are using the old model, `CPRetriever-Prob`:** 115 | * Please download the following files from [HF dataset CPRet-Embeddings](https://huggingface.co/datasets/coldchair16/CPRet-Embeddings) into the `cp-retrieval-server/` directory: 116 | * `probs_embs.npy` 117 | * `probs.jsonl` 118 | 119 | 120 | 121 | 2. **Start the service:** 122 | 123 | ```bash 124 | cd cp-retrieval-server 125 | ``` 126 | 127 | If you're using the **old model (`CPRetriever-Prob`)**, set the following environment variables before starting the service: 128 | 129 | ```bash 130 | export MODEL_PATH=coldchair16/CPRetriever-Prob 131 | export EMB_PATH=./probs_embs.npy 132 | export PROB_PATH=./probs.jsonl 133 | ``` 134 | 135 | Note: bf16 is enabled by default. If your device does not support it, set the environment variable BF_16=0: 136 | 137 | ```bash 138 | export BF_16=0 139 | ``` 140 | 141 | Then, run: 142 | 143 | ```bash 144 | python app.py 145 | ``` 146 | 147 | 3. **About the Dataset:** 148 | 149 | The current retrieval problem database (as of Nov 2025) includes problems from the following online judges: 150 | 151 | * [Codeforces](https://codeforces.com/) 152 | * [AtCoder](https://atcoder.jp/) 153 | * [SPOJ](https://www.spoj.com/) 154 | * [Nowcoder](https://ac.nowcoder.com/) 155 | * [Luogu](https://www.luogu.com.cn/) 156 | * [Loj](https://loj.ac/) 157 | * [CodeChef](https://www.codechef.com/dashboard) 158 | * [AIZU](https://judge.u-aizu.ac.jp/onlinejudge/) 159 | * [UOJ](https://uoj.ac/) 160 | * [QOJ](https://qoj.ac/) 161 | 162 | The data is collected up to **Nov 2025**.
163 | You can add your own data source and generate embeddings using [`compute_embs.py`](cp-retrieval-server/compute_embs.py). Running this process for the current database on an H100 GPU takes approximately 4 GPU hours. 164 | 165 | If you have access to a larger or more diverse problem dataset, **we welcome contributions and are happy to update the collection** — feel free to contact us (2317757009@qq.com) or open an issue/pull request. 166 | 167 | 168 | 4. **System Requirements:** 169 | 170 | This service can be run on **CPU** or **GPU**, depending on your environment. 171 | We recommend the following memory for smooth operation: 172 | 173 | * For the **2B old models** (e.g., `CPRetriever-Prob`): at least **16GB of system memory or GPU VRAM**. 174 | * For the **4B new model** (`CPRetriever-Prob-Qwen3-4B-2510`): **32GB or more of system memory or GPU VRAM**. 175 | 176 | The above requirements are for fp32; if the device supports bf16, only half of the memory/VRAM is needed. 177 | 178 | Typical query latency: 179 | 180 | * On **CPU** (8 cores): **10–20 seconds**. 181 | * On **GPU** (e.g., A800): **0.1–1 seconds**. 182 | 183 | Inference time depends on the input length. 184 | 185 | 186 | --- 187 | 188 | ## 🏋️‍♀️ Training Instructions 189 | 190 | > **⚠️ Note**: Recommended GPU memory ≥ **50 GB** to avoid OOM. 191 | 192 | ## 🔧 Stage 1: Contrastive Pretraining 193 | 194 | ```bash 195 | cd stage1 196 | torchrun --nproc_per_node=8 train.py 197 | ``` 198 | 199 | * Change `--nproc_per_node` to match the number of available GPUs. 200 | * Use `--help` to see all configurable hyperparameters. 201 | 202 | ### ⚠️ Note on Using `Salesforce/SFR-Embedding-Code-2B_R` 203 | 204 | If you are using [`Salesforce/SFR-Embedding-Code-2B_R`](https://huggingface.co/Salesforce/SFR-Embedding-Code-2B_R) as your encoder, make sure to **manually disable `device_map="auto"`** when loading the model.
205 | 206 | The original code might look like this: 207 | 208 | ```python 209 | self.model = Gemma2Model.from_pretrained(config._name_or_path, trust_remote_code=True, is_causal=False, device_map="auto") 210 | self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, trust_remote_code=True, device_map="auto") 211 | ``` 212 | 213 | This setting can cause the model to skip training due to automatic device placement. 214 | **Please change it to:** 215 | 216 | ```python 217 | self.model = Gemma2Model.from_pretrained(config._name_or_path, trust_remote_code=True, is_causal=False, device_map=None) 218 | self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, trust_remote_code=True, device_map=None) 219 | ``` 220 | 221 | Alternatively, you can directly copy the patched file from our repo: 222 | 👉 [modeling\_gemma2.py](https://huggingface.co/coldchair16/CPRetriever-Code/blob/main/modeling_gemma2.py) 223 | 224 | 225 | --- 226 | ### Stage 2: Problem-Level Fine-Tuning 227 | 228 | ```bash 229 | cd stage2 230 | torchrun --nproc_per_node=1 train.py 231 | ``` 232 | 233 | * Also supports `--help` to inspect all args. 234 | 235 | --- 236 | 237 | ## 🔧 Notable Hyperparameters 238 | 239 | * `--model_path`: Can be either an HF model repo (e.g. `coldchair16/CPRetriever-Code`) or a local directory supporting SentenceTransformer. 240 | * `--eval_only True`: Run evaluation without training. 
class GradCache:
    """
    Gradient Cache class. Implements input chunking, first graph-less forward pass, Gradient Cache creation, second
    forward & backward gradient computation. Optimizer step is not included. Native torch automatic mixed precision is
    supported. User needs to handle gradient unscaling and scaler update after a gradient cache step.
    """
    def __init__(
            self,
            models: List[nn.Module],
            chunk_sizes: Union[int, List[int]],
            loss_fn: Callable[..., Tensor],
            split_input_fn: Union[Callable[[Any, int], Any], None] = None,
            get_rep_fn: Union[Callable[..., Tensor], None] = None,
            fp16: bool = False,
            scaler: Union[GradScaler, None] = None,
    ):
        """
        Initialize the Gradient Cache class instance.
        :param models: A list of all encoder models to be updated by the current cache.
        :param chunk_sizes: An integer indicating chunk size. Or a list of integers of chunk size for each model.
        :param loss_fn: A loss function that takes arbitrary numbers of representation tensors and
        arbitrary numbers of keyword arguments as input. It should not in any case modify the input tensors' relations
        in the autograd graph, which are later relied upon to create the gradient cache.
        :param split_input_fn: An optional function that split generic model input into chunks. If not provided, this
        class will try its best to split the inputs of supported types. See `split_inputs` function.
        :param get_rep_fn: An optional function that takes generic model output and return representation tensors. If
        not provided, the generic output is assumed to be the representation tensor.
        :param fp16: If True, run mixed precision training, which requires scaler to also be set.
        :param scaler: A GradScaler object for automatic mixed precision training.
        """
        self.models = models

        # A single integer is broadcast so that every model gets its own chunk size entry.
        if isinstance(chunk_sizes, int):
            self.chunk_sizes = [chunk_sizes for _ in range(len(models))]
        else:
            self.chunk_sizes = chunk_sizes

        self.split_input_fn = split_input_fn
        self.get_rep_fn = get_rep_fn
        self.loss_fn = loss_fn

        if fp16:
            assert scaler is not None, "mixed precision training requires a gradient scaler passed in"

        self.fp16 = fp16
        self.scaler = scaler

        # When True, `get_input_tensors` raises on unsupported input types instead of ignoring them.
        self._get_input_tensors_strict = False

    def __call__(self, *args, **kwargs):
        """
        Call the cache_step function.
        :return: Current step loss.
        """
        return self.cache_step(*args, **kwargs)

    @staticmethod
    def _is_args_kwargs_pair(model_input) -> bool:
        """
        Tell whether the input is an `(args, kwargs)` tuple, i.e. a tuple holding exactly a list and a dict.
        Exact types are compared (not `isinstance`), so subclasses of list/dict do not match.
        :param model_input: Generic model input.
        :return: True if the input is a `(list, dict)` tuple.
        """
        return isinstance(model_input, tuple) and [type(x) for x in model_input] == [list, dict]

    def split_inputs(self, model_input, chunk_size: int) -> List:
        """
        Split input into chunks. Will call user provided `split_input_fn` if specified. Otherwise,
        it can handle input types of tensor, list of tensors, dictionary of tensors and `(list, dict)` tuples.
        :param model_input: Generic model input.
        :param chunk_size: Size of each chunk.
        :return: A list of chunked model input.
        :raises NotImplementedError: If the input type is not supported and no `split_input_fn` was given.
        """
        # delegate splitting to user provided function
        if self.split_input_fn is not None:
            return self.split_input_fn(model_input, chunk_size)

        if isinstance(model_input, (dict, UserDict)) and all(isinstance(x, Tensor) for x in model_input.values()):
            keys = list(model_input.keys())
            chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
            # re-assemble one dictionary per chunk, keeping the original keys
            return [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))]

        if isinstance(model_input, list) and all(isinstance(x, Tensor) for x in model_input):
            chunked_x = [t.split(chunk_size, dim=0) for t in model_input]
            return [list(s) for s in zip(*chunked_x)]

        if isinstance(model_input, Tensor):
            return list(model_input.split(chunk_size, dim=0))

        if self._is_args_kwargs_pair(model_input):
            # split positional and keyword parts separately, then pair them up chunk by chunk
            args_chunks = self.split_inputs(model_input[0], chunk_size)
            kwargs_chunks = self.split_inputs(model_input[1], chunk_size)
            return list(zip(args_chunks, kwargs_chunks))

        raise NotImplementedError(f'Model input split not implemented for type {type(model_input)}')

    def get_input_tensors(self, model_input) -> List[Tensor]:
        """
        Recursively go through model input and grab all tensors, which are then used to record current device random
        states. This method will do its best to parse types of Tensor, tuple, list, dict and UserDict. Other types will
        be ignored unless self._get_input_tensors_strict is set to True, in which case an exception will be raised.
        :param model_input: input to model
        :return: all torch tensors in model_input
        :raises NotImplementedError: In strict mode, if an unsupported type is encountered.
        """
        if isinstance(model_input, Tensor):
            return [model_input]

        if isinstance(model_input, (list, tuple)):
            children = model_input
        elif isinstance(model_input, (dict, UserDict)):
            children = model_input.values()
        elif self._get_input_tensors_strict:
            raise NotImplementedError(f'get_input_tensors not implemented for type {type(model_input)}')
        else:
            return []

        # Extend a single list instead of `sum(lists, [])`, which re-copies the accumulator on every addition.
        tensors = []
        for child in children:
            tensors.extend(self.get_input_tensors(child))
        return tensors

    def model_call(self, model: nn.Module, model_input):
        """
        Literally call the model's __call__ method.
        :param model: model to be called
        :param model_input: input to the model call
        :return: model output
        :raises NotImplementedError: If the input type cannot be mapped onto a model call.
        """
        with autocast() if self.fp16 else nullcontext():
            if isinstance(model_input, Tensor):
                return model(model_input)
            if isinstance(model_input, list):
                return model(*model_input)
            if isinstance(model_input, (dict, UserDict)):
                return model(**model_input)
            if self._is_args_kwargs_pair(model_input):
                model_args, model_kwargs = model_input
                return model(*model_args, **model_kwargs)
            raise NotImplementedError(f'Model call not implemented for input type {type(model_input)}')

    def get_reps(self, model_out) -> Tensor:
        """
        Return representation tensor from generic model output
        :param model_out: generic model output
        :return: a single tensor corresponding to the model representation output
        """
        if self.get_rep_fn is not None:
            return self.get_rep_fn(model_out)
        return model_out

    def compute_loss(self, *reps: Tensor, **loss_kwargs) -> Tensor:
        """
        Compute the loss based on the representation tensors. The tensors should be ordered same as the list of models
        registered in this GradCache class instance.
        :param reps: Representations for computing the loss.
        :param loss_kwargs: Keyword arguments input to the loss function.
        :return: the loss tensor.
        """
        return self.loss_fn(*reps, **loss_kwargs)

    def forward_no_grad(
            self,
            model: nn.Module,
            model_inputs,
    ) -> tuple[Tensor, List["RandContext"]]:
        """
        The first forward pass without gradient computation.
        :param model: Encoder model.
        :param model_inputs: Model input already broken into chunks.
        :return: A tuple of a) representations and b) recorded random states.
        """
        rnd_states = []
        model_reps = []

        with torch.no_grad():
            for x in model_inputs:
                # record the random state before the forward so the second pass can replay it
                rnd_states.append(RandContext(*self.get_input_tensors(x)))
                y = self.model_call(model, x)
                model_reps.append(self.get_reps(y))

        # concatenate all sub-batch representations
        model_reps = torch.cat(model_reps, dim=0)
        return model_reps, rnd_states

    def build_cache(self, *reps: Tensor, **loss_kwargs) -> tuple[List[Tensor], Tensor]:
        """
        Compute the gradient cache
        :param reps: Computed representations from all encoder models
        :param loss_kwargs: Extra keyword arguments to the loss function
        :return: A tuple of a) gradient cache for each encoder model, and b) loss tensor
        """
        # detach from any previous graph and make the representations leaf tensors that collect gradients
        reps = [r.detach().requires_grad_() for r in reps]
        with autocast() if self.fp16 else nullcontext():
            loss = self.compute_loss(*reps, **loss_kwargs)

        if self.fp16:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

        cache = [r.grad for r in reps]

        return cache, loss.detach()

    def forward_backward(
            self,
            model: nn.Module,
            model_inputs,
            cached_gradients: List[Tensor],
            random_states: List["RandContext"],
            no_sync_except_last: bool = False
    ):
        """
        Run the second forward and the backward pass to compute gradient for a model.
        :param model: Encoder model.
        :param model_inputs: Chunked input to the encoder model.
        :param cached_gradients: Chunked gradient cache tensor for each input.
        :param random_states: Each input's device random state during the first forward.
        :param no_sync_except_last: If True, under distributed setup, only trigger gradient reduction across processes
        for the last sub-batch's forward-backward pass.
        """
        if no_sync_except_last:
            sync_contexts = [model.no_sync for _ in range(len(model_inputs) - 1)] + [nullcontext]
        else:
            sync_contexts = [nullcontext for _ in range(len(model_inputs))]

        for x, state, gradient, sync_context in zip(model_inputs, random_states, cached_gradients, sync_contexts):
            with sync_context():
                with state:
                    y = self.model_call(model, x)
                reps = self.get_reps(y)

                # Back-propagating <reps, cached gradient> pushes the cached gradient through the encoder.
                surrogate = torch.dot(reps.flatten(), gradient.flatten())
                surrogate.backward()

    def cache_step(
            self,
            *model_inputs,
            no_sync_except_last: bool = False,
            **loss_kwargs
    ) -> Tensor:
        """
        Run a cached step to compute gradient over the inputs.
        :param model_inputs: Input to each encoder model. Should be in similar order as the class's model.
        :param no_sync_except_last: If True, under distributed setup, for each model, only trigger gradient reduction
        across processes for the last sub-batch's forward-backward pass.
        :param loss_kwargs: Additional keyword arguments to the loss function.
        :return: The current's loss.
        """
        all_reps = []
        all_rnd_states = []

        if no_sync_except_last:
            assert all(map(lambda m: isinstance(m, nn.parallel.DistributedDataParallel), self.models)), \
                'Some of models are not wrapped in DistributedDataParallel. Make sure you are running DDP with ' \
                'proper initializations.'

        model_inputs = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(model_inputs, self.chunk_sizes)]

        # first pass: graph-less forward to obtain the representations of the full batch
        for model, x in zip(self.models, model_inputs):
            model_reps, rnd_states = self.forward_no_grad(model, x)
            all_reps.append(model_reps)
            all_rnd_states.append(rnd_states)

        # gradients of the loss w.r.t. the representations, chunked the same way as the inputs
        cache, loss = self.build_cache(*all_reps, **loss_kwargs)
        cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)]

        # second pass: chunk-wise forward with graph, followed by backward with the cached gradients
        for model, x, model_cache, rnd_states in zip(
                self.models, model_inputs, cache, all_rnd_states):
            self.forward_backward(model, x, model_cache, rnd_states, no_sync_except_last=no_sync_except_last)

        return loss