├── stage1
│   ├── grad_cache
│   │   ├── pytorch_lightning
│   │   │   ├── requirements.txt
│   │   │   ├── readme.md
│   │   │   ├── pl_example.py
│   │   │   └── pl_gradcache.py
│   │   ├── __init__.py
│   │   ├── cachex
│   │   │   ├── __init__.py
│   │   │   ├── tree_utils.py
│   │   │   ├── training.py
│   │   │   └── functional.py
│   │   ├── context_managers.py
│   │   ├── loss.py
│   │   ├── functional.py
│   │   └── grad_cache.py
│   ├── dataset_process.py
│   ├── train.py
│   ├── dataset.py
│   ├── grad_cache_custom.py
│   ├── loss.py
│   ├── eval_metric.py
│   ├── custom_trainer.py
│   └── trainer_headers.py
├── requirements.txt
├── cp-retrieval-server
│   ├── templates
│   │   ├── problem.html
│   │   ├── base.html
│   │   ├── stats.html
│   │   └── index.html
│   ├── download.py
│   ├── example.py
│   ├── compute_embs.py
│   └── app.py
├── .gitignore
├── stage2
│   ├── dataset_process.py
│   ├── trainer_custom.py
│   └── train.py
├── TestCases.md
└── README.md
/stage1/grad_cache/pytorch_lightning/requirements.txt:
--------------------------------------------------------------------------------
1 | lightning
2 | pytorch_metric_learning
3 |
--------------------------------------------------------------------------------
/stage1/grad_cache/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from .grad_cache import GradCache
3 | except ModuleNotFoundError:
4 | pass
5 |
--------------------------------------------------------------------------------
/stage1/grad_cache/cachex/__init__.py:
--------------------------------------------------------------------------------
1 | from .functional import chunk_encode, cache_grad, unchunk_args, grad_cached
2 | from .tree_utils import tree_chunk, tree_unchunk
3 |
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | sentence-transformers==3.4.1
2 | transformers==4.52.4
3 | torch==2.6.0
4 | torchaudio==2.6.0
5 | torchvision==0.21.0
6 | pandas
7 | tensorboard
8 | datasets
9 | accelerate
10 | einops
11 | Flask
--------------------------------------------------------------------------------
/stage1/grad_cache/cachex/tree_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import jax
4 |
5 |
6 | def tree_chunk(tree: Any, n_chunk: int, axis: int = 0) -> Any:
7 | return jax.tree_map(
8 | lambda v: v.reshape(v.shape[:axis] + (n_chunk, -1) + v.shape[axis + 1:]),
9 | tree
10 | )
11 |
12 |
13 | def tree_unchunk(tree: Any, axis: int = 0) -> Any:
14 | return jax.tree_map(
15 | lambda x: x.reshape(x.shape[:axis] + (-1,) + x.shape[axis + 2:]),
16 | tree
17 | )
18 |
--------------------------------------------------------------------------------
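
A quick illustration of what these two helpers do (a sketch only; it assumes `stage1/` is on the Python path so the package imports as `grad_cache.cachex`, and the batch fields are made up):

```python
import jax.numpy as jnp

from grad_cache.cachex import tree_chunk, tree_unchunk

# Hypothetical batch of 8 examples with two fields.
batch = {
    "input_ids": jnp.zeros((8, 128), dtype=jnp.int32),
    "attention_mask": jnp.ones((8, 128), dtype=jnp.int32),
}

chunked = tree_chunk(batch, n_chunk=4)   # each leaf becomes shape (4, 2, 128)
restored = tree_unchunk(chunked)         # each leaf is back to shape (8, 128)
```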
/cp-retrieval-server/templates/problem.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block content %}
3 |
4 | ← {{ t.back }}
5 |
6 |
7 |
{{ title }}
8 | {{ source }}
9 |
10 |
11 |
12 |
13 | {{ text_html }}
14 |
15 |
16 |
17 |
18 | {{ t.view_origin }} 🔗
19 |
20 | {% endblock %}
21 |
--------------------------------------------------------------------------------
/stage1/grad_cache/context_managers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.checkpoint import get_device_states, set_device_states
3 |
4 |
5 | class RandContext:
6 | def __init__(self, *tensors):
7 | self.fwd_cpu_state = torch.get_rng_state()
8 | self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
9 |
10 | def __enter__(self):
11 | self._fork = torch.random.fork_rng(
12 | devices=self.fwd_gpu_devices,
13 | enabled=True
14 | )
15 | self._fork.__enter__()
16 | torch.set_rng_state(self.fwd_cpu_state)
17 | set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)
18 |
19 | def __exit__(self, exc_type, exc_val, exc_tb):
20 | self._fork.__exit__(exc_type, exc_val, exc_tb)
21 | self._fork = None
--------------------------------------------------------------------------------
/stage1/grad_cache/cachex/training.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import jax
4 | import jax.numpy as jnp
5 |
6 | from .functional import chunk_encode, cache_grad, unchunk_args
7 |
8 |
9 | def cache_train_step(loss_fn, state, ss, tt, axis='device'):
10 | def encode_with_params(params, **kwargs):
11 | return state.apply_fn(params=params, **kwargs)
12 |
13 | encode_fn = chunk_encode(partial(encode_with_params, state.params))
14 | grad_fn = cache_grad(encode_with_params)
15 |
16 | s_reps = encode_fn(**ss)
17 | t_reps = encode_fn(**tt)
18 |
19 | @unchunk_args(axis=0, argnums=(0, 1))
20 | def grad_cache_fn(xx, yy):
21 | return jnp.mean(loss_fn(xx, yy, axis=axis))
22 | loss, (s_grads, t_grads) = jax.value_and_grad(grad_cache_fn, argnums=(0, 1))(s_reps, t_reps)
23 |
24 | grads = jax.tree_map(lambda v: jnp.zeros_like(v), state.params)
25 | grads = grad_fn(state.params, grads, s_grads, **ss)
26 | grads = grad_fn(state.params, grads, t_grads, **tt)
27 |
28 | loss, grads = jax.lax.pmean([loss, grads], axis)
29 | new_state = state.apply_gradients(grads=grads)
30 | return loss, new_state
31 |
--------------------------------------------------------------------------------
/cp-retrieval-server/templates/base.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | {{ t.site_name }}
6 |
7 |
8 |
9 |
10 |
11 |
23 | {% block content %}{% endblock %}
24 |
25 |
26 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | logs/
2 | models/
3 | results/
4 | dataset/
5 | *.npy
6 | *.jsonl
7 | *.json
8 |
9 | # Python
10 | __pycache__/
11 | *.py[cod]
12 | *.pyo
13 | *.pyd
14 | *.so
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 | .pytest_cache/
33 |
34 | # Virtual environments
35 | venv/
36 | env/
37 | .venv/
38 | ENV/
39 | env.bak/
40 | venv.bak/
41 |
42 | # OS-generated files
43 | .DS_Store
44 | Thumbs.db
45 |
46 | # IDE and editor files
47 | .idea/
48 | .vscode/
49 | *.suo
50 | *.ntvs*
51 | *.njsproj
52 | *.sln
53 | *.sw?
54 |
55 | # Dependency directories
56 | node_modules/
57 | bower_components/
58 | vendor/
59 |
60 | # Build artifacts
61 | *.log
62 | *.out
63 | *.exe
64 | *.dll
65 | *.dylib
66 |
67 | # Debug files
68 | *.debug
69 | *.pdb
70 |
71 | # Archives
72 | *.zip
73 | *.tar.gz
74 | *.rar
75 |
76 | # Environment files
77 | .env
78 | .env.local
79 | .env.development
80 | .env.test
81 | .env.production
82 |
83 | # Log files
84 | logs/
85 | *.log
86 | npm-debug.log*
87 | yarn-debug.log*
88 | yarn-error.log*
89 |
90 | # Cache directories
91 | .cache/
92 | .temp/
93 |
94 | # Test reports
95 | coverage/
96 | *.lcov
97 |
98 | # System files
99 | *.swp
100 | *.swo
--------------------------------------------------------------------------------
/cp-retrieval-server/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | # Use the Hugging Face mirror endpoint
3 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
4 |
5 | from huggingface_hub import snapshot_download
6 |
7 | download_dir = './'
8 |
9 | # download probs and embeddings from Hugging Face
10 | repo_id = 'coldchair16/CPRet-Embeddings'
11 | snapshot_download(
12 | repo_id=repo_id,
13 | repo_type='dataset',
14 | local_dir=os.path.join(download_dir),
15 | allow_patterns=['probs_2511*'],
16 | local_dir_use_symlinks=False,
17 | resume_download=True,
18 | max_workers=4,
19 | )
20 | print(f"Finished downloading {repo_id} to {download_dir}")
21 |
22 |
23 | # download models from Hugging Face
24 | # download_dir = './'
25 | # repo_id_list = [
26 | # 'coldchair16/CPRetriever-Prob-Qwen3-4B-2510',
27 | # ]
28 | # for repo_id in repo_id_list:
29 | # model_name = repo_id.split("/")[-1]
30 | # # replace '.' with 'p' to avoid issues in SentenceTransformer
31 | # model_name = model_name.replace(".", "p")
32 | # print('Begin downloading:', repo_id)
33 | # snapshot_download(
34 | # repo_id=repo_id,
35 | # repo_type="model",
36 | # local_dir=os.path.join(download_dir, model_name),
37 | # allow_patterns=['*'],
38 | # local_dir_use_symlinks=False,
39 | # resume_download=True,
40 | # max_workers=4,
41 | # )
42 | # print(f"Finished downloading {repo_id} to {download_dir}/{model_name}")
--------------------------------------------------------------------------------
/cp-retrieval-server/example.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from datasets import load_dataset
4 | import numpy as np
5 | from sentence_transformers import SentenceTransformer
6 | import random
7 | import json
8 | import tqdm as tqdm
9 |
10 | if __name__ == "__main__":
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--model_path",
13 | type=str,
14 | default='coldchair16/CPRetriever-Prob',
15 | help="Path to the SentenceTransformer model")
16 |
17 | args = parser.parse_args()
18 |
19 | model_path = args.model_path
20 | model = SentenceTransformer(model_path, trust_remote_code=True)
21 | model.tokenizer.model_max_length = 1024
22 | model.max_seq_length = 1024
23 |
24 | embs = np.load('./probs_embs.npy')
25 | embs = embs / np.linalg.norm(embs, axis=1, keepdims=True)
26 | probs_path = './probs.jsonl'
27 | probs = []
28 | with open(probs_path, 'r') as f:
29 | for line in f.readlines():
30 | line = line.strip()
31 | if line:
32 | data = json.loads(line)
33 | probs.append(data)
34 |
35 |
36 | text = '''
37 | You are given a sequence that supports the following operations:
38 | 1. Flatten a specified range.
39 | 2. Calculate the sum of the counts of distinct numbers in all subsegments of
40 | length k.'''
41 |
42 | text_emb = model.encode(text, convert_to_tensor=True)
43 | text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
44 |
45 |     sim_mat = np.dot(embs, text_emb.cpu().numpy())
46 | print(sim_mat.shape)
47 | rank = np.argsort(sim_mat, axis=0)[::-1]
48 |
49 | p = probs[rank[0]]
50 | print(p['title'], p['url'], p['source'], p['text'])
51 |
52 |
53 |
--------------------------------------------------------------------------------
/stage1/dataset_process.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 |
3 | def filted_by_length(data, max_length=1024):
4 | """
5 | Filter out solutions in each sample whose token length exceeds max_length.
6 |
7 | Args:
8 | data (list): A list of data samples, each containing 'solutions' and 'solutions_length'.
9 | max_length (int): Maximum allowed token length for a solution.
10 |
11 | Returns:
12 | list: A filtered list of samples, each retaining only solutions with acceptable length.
13 | """
14 | n = len(data)
15 | new_data = []
16 | for i in range(n):
17 | m = len(data[i]['solutions'])
18 |
19 | # Find indices of solutions within the length constraint
20 | index = [j for j in range(m) if data[i]['solutions_length'][j] <= max_length]
21 |
22 | # Keep only the filtered solutions and lengths
23 | data[i]['solutions'] = [data[i]['solutions'][j] for j in index]
24 | data[i]['solutions_length'] = [data[i]['solutions_length'][j] for j in index]
25 |
26 | if len(index) > 0:
27 | new_data.append(data[i])
28 |
29 | return new_data
30 |
31 | # Filter out solutions whose token count exceeds max_length
32 | def process_dataset(dataset, max_length=1024):
33 | """
34 | Process a HuggingFace dataset dict by filtering out solutions that exceed a length threshold.
35 |
36 | Args:
37 | dataset (DatasetDict): HuggingFace dataset containing 'train' and 'test' splits.
38 | max_length (int): Maximum token length allowed for each solution.
39 |
40 | Returns:
41 | Tuple: Filtered train and test data as lists.
42 | """
43 | traindata = dataset['train'].to_list()
44 | testdata = dataset['test'].to_list()
45 |
46 | traindata = filted_by_length(traindata, max_length)
47 | testdata = filted_by_length(testdata, max_length)
48 |
49 | print(f"Train data size: {len(traindata)}")
50 | print(f"Test data size: {len(testdata)}")
51 |
52 | return traindata, testdata
53 |
--------------------------------------------------------------------------------
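
For reference, `process_dataset` is called from `stage1/train.py` roughly as follows:

```python
from datasets import load_dataset

from dataset_process import process_dataset

# Same dataset and filtering as stage1/train.py.
data = load_dataset('coldchair16/CPRet-data', 'PCPCD')
traindata, testdata = process_dataset(data, max_length=1024)
```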
/cp-retrieval-server/compute_embs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from datasets import load_dataset
4 | import numpy as np
5 | from sentence_transformers import SentenceTransformer
6 | import random
7 | import json
8 | import tqdm as tqdm
9 |
10 | if __name__ == "__main__":
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--model_path",
13 | type=str,
14 | default='coldchair16/CPRetriever-Prob',
15 | help="Path to the SentenceTransformer model")
16 | parser.add_argument("--input_path",
17 | type=str,
18 | default='./probs.jsonl',
19 | help="Path to the input JSONL file containing sentences")
20 | parser.add_argument("--max_length",
21 | type=int,
22 | default=2048,
23 | help="Maximum length of input sentences")
24 |
25 | args = parser.parse_args()
26 |
27 | sentences = []
28 |
29 | input_path = args.input_path
30 | with open(input_path, 'r') as f:
31 | for line in f.readlines():
32 | line = line.strip()
33 | if line:
34 | data = json.loads(line)
35 | sentences.append(data['text'])
36 |
37 | print(f"Number of sentences: {len(sentences)}")
38 |
39 | save_dir = './'
40 | os.makedirs(save_dir, exist_ok=True)
41 |
42 | model_path = args.model_path
43 | model_name = os.path.basename(model_path.rstrip('/'))
44 |
45 | model = SentenceTransformer(model_path, trust_remote_code=True)
46 | model.tokenizer.model_max_length = args.max_length
47 | model.max_seq_length = args.max_length
48 |
49 | pool = model.start_multi_process_pool()
50 | emb = model.encode_multi_process(sentences, pool, show_progress_bar=True, batch_size=8)
51 | model.stop_multi_process_pool(pool)
52 |
53 | print("Embeddings computed. Shape:", emb.shape)
54 |
55 | save_name = f"{input_path.split('/')[-1].replace('.jsonl', '')}_embs.npy"
56 | save_path = os.path.join(save_dir, save_name)
57 | emb = emb.astype('float32')
58 | np.save(save_path, emb)
59 |
60 | print(f"Embeddings saved to {save_path}")
61 |
--------------------------------------------------------------------------------
/stage1/grad_cache/loss.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import functional as F
6 | from torch import distributed as dist
7 |
8 |
9 | class SimpleContrastiveLoss:
10 | def __init__(self, n_hard_negatives: int = 0):
11 | self.target_per_qry = n_hard_negatives + 1
12 |
13 | def __call__(self, x: Tensor, y: Tensor, target: Tensor = None, reduction: str = 'mean'):
14 | if target is None:
15 | assert x.size(0) * self.target_per_qry == y.size(0)
16 | target = torch.arange(0, x.size(0) * self.target_per_qry, self.target_per_qry, device=x.device)
17 |
18 | logits = torch.matmul(x, y.transpose(0, 1))
19 | return F.cross_entropy(logits, target, reduction=reduction)
20 |
21 |
22 | class DistributedContrastiveLoss(SimpleContrastiveLoss):
23 | def __init__(self, n_hard_negatives: int = 0):
24 | assert dist.is_initialized(), "Distributed training has not been properly initialized."
25 |
26 | super().__init__(n_hard_negatives=n_hard_negatives)
27 |         self.world_size = dist.get_world_size()
28 | self.rank = dist.get_rank()
29 |
30 | def __call__(self, x: Tensor, y: Tensor, **kwargs):
31 | dist_x = self.gather_tensor(x)
32 | dist_y = self.gather_tensor(y)
33 |
34 | return super().__call__(dist_x, dist_y, **kwargs)
35 |
36 | def gather_tensor(self, t):
37 |         gathered = [torch.empty_like(t) for _ in range(self.world_size)]
38 | dist.all_gather(gathered, t)
39 | gathered[self.rank] = t
40 | return torch.cat(gathered, dim=0)
41 |
42 |
43 | class ContrastiveLossWithQueryClosure(SimpleContrastiveLoss):
44 | def __call__(
45 | self,
46 | *reps: Tensor,
47 | query_closure: Callable[[], Tensor] = None,
48 | target: Tensor = None,
49 | reduction: str = 'mean'
50 | ):
51 | if len(reps) == 0 or len(reps) > 2:
52 | raise ValueError(f'Expecting 1 or 2 tensor input, got {len(reps)} tensors')
53 |
54 | # no closure evaluation
55 | if len(reps) == 2:
56 | assert query_closure is None, 'received 2 representation tensors while query_closure is also set'
57 | return super().__call__(*reps, target=target, reduction=reduction)
58 |
59 | # run the closure
60 | assert query_closure is not None
61 | x = query_closure()
62 | y = reps[0]
63 | return super().__call__(x, y, target=target, reduction=reduction)
64 |
--------------------------------------------------------------------------------
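
To make the target construction in `SimpleContrastiveLoss` concrete, here is a small sketch (illustrative shapes only): with `n_hard_negatives=1`, each query owns a block of two candidates in `y`, so the positive indices are `0, 2, 4, ...`.

```python
import torch

from grad_cache.loss import SimpleContrastiveLoss

# 3 queries, each paired with 1 positive + 1 hard negative (6 candidates in total).
x = torch.randn(3, 8)   # query representations
y = torch.randn(6, 8)   # candidate representations, grouped per query

loss_fn = SimpleContrastiveLoss(n_hard_negatives=1)
# Internally: target = torch.arange(0, 6, 2) -> tensor([0, 2, 4])
loss = loss_fn(x, y)
```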
/stage2/dataset_process.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 |
3 | def filted_by_length(data, max_length=1024):
4 | """
5 | Filter out solutions in each sample whose token length exceeds max_length.
6 |
7 | Args:
8 | data (list): A list of data samples, each containing 'solutions' and 'solutions_length'.
9 | max_length (int): Maximum allowed token length for a solution.
10 |
11 | Returns:
12 | list: A filtered list of samples, each retaining only solutions with acceptable length.
13 | """
14 | n = len(data)
15 | new_data = []
16 | for i in range(n):
17 | m = len(data[i]['solutions'])
18 |
19 | # Find indices of solutions within the length constraint
20 | index = [j for j in range(m) if data[i]['solutions_length'][j] <= max_length]
21 |
22 | # Keep only the filtered solutions and lengths
23 | data[i]['solutions'] = [data[i]['solutions'][j] for j in index]
24 | data[i]['solutions_length'] = [data[i]['solutions_length'][j] for j in index]
25 |
26 | if len(index) > 0:
27 | new_data.append(data[i])
28 |
29 | return new_data
30 |
31 | # Filter out solutions whose token count exceeds max_length
32 | def process_dataset(dataset, max_length=1024):
33 | """
34 | Process a HuggingFace dataset dict by filtering out solutions that exceed a length threshold.
35 |
36 | Args:
37 | dataset (DatasetDict): HuggingFace dataset containing 'train' and 'test' splits.
38 | max_length (int): Maximum token length allowed for each solution.
39 |
40 | Returns:
41 | Tuple: Filtered train and test data as lists.
42 | """
43 | traindata = dataset['train'].to_list()
44 | testdata = dataset['test'].to_list()
45 |
46 | traindata = filted_by_length(traindata, max_length)
47 | testdata = filted_by_length(testdata, max_length)
48 |
49 | print(f"Train data size: {len(traindata)}")
50 | print(f"Test data size: {len(testdata)}")
51 |
52 | return traindata, testdata
53 |
54 | def process_eval_dataset(queries, corpus, qrels):
55 | queries = dict(zip([str(x['_id']) for x in queries], [x['text'] for x in queries])) # Our queries (qid => question)
56 | corpus = dict(zip([str(x['_id']) for x in corpus], [x['text'] for x in corpus])) # Our corpus (cid => document)
57 |     relevant_docs = {}  # Query ID to relevant documents (qid => set([relevant_cids]))
58 | for qrel in qrels:
59 | qid, corpus_ids = qrel["query-id"], qrel["corpus-id"]
60 | qid = str(qid)
61 | corpus_ids = str(corpus_ids)
62 | if qid not in relevant_docs:
63 | relevant_docs[qid] = set()
64 | relevant_docs[qid].add(corpus_ids)
65 | return queries, corpus, relevant_docs
--------------------------------------------------------------------------------
/stage2/trainer_custom.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sentence_transformers import SentenceTransformerTrainer
3 | from torch.utils.data import ConcatDataset
4 | from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, Value
5 | from torch.utils.data import DataLoader
6 | from sentence_transformers import SentenceTransformerTrainer
7 |
8 | class CustomTrainer(SentenceTransformerTrainer):
9 | def __init__(
10 | self,
11 | batch_sizes=None, # New parameter: dictionary specifying batch sizes for each sub-dataset
12 | *args,
13 | **kwargs
14 | ):
15 | self.batch_sizes = batch_sizes or {}
16 | super().__init__(*args, **kwargs)
17 |
18 | def get_train_dataloader(self):
19 | if isinstance(self.train_dataset, DatasetDict):
20 | generator = torch.Generator()
21 | if self.args.seed:
22 | generator.manual_seed(self.args.seed)
23 |
24 | data_collator = self.data_collator
25 | batch_samplers = []
26 |
27 | # 👇 Key logic: iterate over (name, dataset) pairs, assign batch size based on name
28 | for name, dataset in self.train_dataset.items():
29 | bs = self.batch_sizes.get(name, self.args.train_batch_size)
30 | sampler = self.get_batch_sampler(
31 | dataset,
32 | batch_size=bs,
33 | drop_last=self.args.dataloader_drop_last,
34 | valid_label_columns=data_collator.valid_label_columns,
35 | generator=generator,
36 | )
37 | batch_samplers.append(sampler)
38 |
39 | concat = ConcatDataset(self.train_dataset.values())
40 | batch_sampler = self.get_multi_dataset_batch_sampler(
41 | dataset=concat,
42 | batch_samplers=batch_samplers,
43 | generator=generator,
44 | seed=self.args.seed,
45 | )
46 |
47 | dl_kwargs = dict(
48 | collate_fn=data_collator,
49 | num_workers=self.args.dataloader_num_workers,
50 | pin_memory=self.args.dataloader_pin_memory,
51 | persistent_workers=self.args.dataloader_persistent_workers,
52 | prefetch_factor=self.args.dataloader_prefetch_factor,
53 | batch_sampler=batch_sampler,
54 | )
55 |
56 | self.accelerator.even_batches = False
57 | self._train_dataloader = self.accelerator.prepare(DataLoader(concat, **dl_kwargs))
58 | return self._train_dataloader
59 |
60 | # For single dataset or IterableDataset, fall back to default behavior
61 | return super().get_train_dataloader()
62 |
--------------------------------------------------------------------------------
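
A minimal usage sketch of the per-dataset batch sizes (the dataset names and the `code_ds`, `text_ds`, `model`, `training_args`, and `loss_fn` objects below are placeholders assumed to be prepared elsewhere):

```python
from datasets import DatasetDict

# Hypothetical multi-task training set keyed by task name.
train_dataset = DatasetDict({"code2code": code_ds, "text2code": text_ds})

trainer = CustomTrainer(
    batch_sizes={"code2code": 32, "text2code": 8},  # falls back to args.train_batch_size otherwise
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    loss=loss_fn,
)
trainer.train()
```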
/stage1/grad_cache/cachex/functional.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, Any, Callable
2 | from functools import partial
3 |
4 | import jax
5 | import jax.numpy as jnp
6 |
7 | from .tree_utils import tree_unchunk
8 |
9 | Array = jax.Array
10 |
11 |
12 | def grad_with_cache(f, **grad_kwargs):
13 | def cache_f(params, cache, *args, **kwargs):
14 | return jnp.sum(f(params, *args, **kwargs) * cache)
15 | return jax.grad(cache_f, **grad_kwargs)
16 |
17 |
18 | def encode_scan_fn(f, carry, x):
19 | return carry, f(**x)
20 |
21 |
22 | def cache_grad_scan_fn(f, params, acc, x):
23 | cached_grad, kwargs = x
24 |
25 | def fwd_fn(w):
26 | return f(params=w, **kwargs)
27 |
28 | chunk_grad = grad_with_cache(fwd_fn)(params, cached_grad)
29 |     acc = jax.tree_util.tree_map(lambda u, v: u + v, acc, chunk_grad)
30 | return acc, None
31 |
32 |
33 | def chunk_encode(encode_fn):
34 | def f(**xx):
35 | _, hh = jax.lax.scan(partial(encode_scan_fn, encode_fn), 0, xx)
36 | return hh
37 | return f
38 |
39 |
40 | def cache_grad(encode_fn):
41 | def f(params, grad_accumulator, cached_grad, **xx):
42 | grads, _ = jax.lax.scan(
43 | partial(cache_grad_scan_fn, encode_fn, params), grad_accumulator, [cached_grad, xx]
44 | )
45 | return grads
46 | return f
47 |
48 |
49 | def unchunk_args(axis: int = 0, argnums: Iterable[int] = ()):
50 | def decorator_unchunk(f):
51 | def g(*args, **kwargs):
52 | new_args = list(args)
53 | for i in argnums:
54 | new_args[i] = tree_unchunk(args[i], axis)
55 | return f(*new_args, **kwargs)
56 |
57 | return g
58 |
59 | return decorator_unchunk
60 |
61 | def grad_cached(
62 | f: Callable[..., Array],
63 | policy: Callable[..., bool] = jax.checkpoint_policies.nothing_saveable,
64 | prevent_cse: bool = True
65 | ):
66 | """
67 | Single-decorator version of grad cache that uses XLA to infer backward pass.
68 |
69 | The forward pass is manually split into chunks and performed sequentially with lax.scan.
70 | We rely on XLA to infer the backward pass and run it in a similar fashion.
71 |
72 | Args:
73 | f: Function to be differentiated. It should take in a single argument and return a jax array of representations.
74 | policy: The sub-batch rematerialization policy.
75 | prevent_cse: Whether to prevent common subexpression elimination.
76 |
77 | Returns:
78 | Decorated gradient cached `f` that expects input to have an extra leading sub-batch dimension, potentially produced by `tree_chunk`.
79 |
80 | A example of usage on a apply function that takes multiple arguments:
81 |
82 | >>> @cachex.grad_cached
83 | ... def fwd(params, batch):
84 | ... return apply(params, **batch)
85 |
86 | >>> src = cachex.tree_chunk(src, 8)
87 | >>> tgt = cachex.tree_chunk(tgt, 8)
88 |
89 | >>> def compute_loss(params, src, tgt):
90 | ... h_src = fwd(params, src)
91 | ... h_tgt = fwd(params, tgt)
92 | ... return loss(h_src, h_tgt)
93 |
94 | >>> grads = jax.grad(compute_loss)(params, src, tgt)
95 |
96 | Here the `compute_loss` function can typically be dropped into a larger training step function.
97 | """
98 | def cached_f(params, batch):
99 | def scan_f(_, sub_batch):
100 | return None, f(params, sub_batch)
101 | _, reps = jax.lax.scan(jax.checkpoint(scan_f, policy=policy, prevent_cse=prevent_cse), None, batch)
102 | return jnp.reshape(reps, (-1,) + reps.shape[2:])
103 | return cached_f
--------------------------------------------------------------------------------
/stage1/grad_cache/functional.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | from typing import Callable, Union, Tuple, Any
3 |
4 | import torch
5 | from torch import Tensor
6 | from torch import distributed as dist
7 |
8 | from .context_managers import RandContext
9 |
10 |
11 | def cached(func: Callable[..., Tensor]):
12 | """
13 | A decorator that takes a model call function into a cached compatible version.
14 | :param func: A function that calls the model and return representation tensor.
15 | :return: A function that returns 1) representation leaf tensors for cache construction, 2) a closure function for
16 | the 2nd forward and the cached backward. Call 2) with 1) as argument after calling backward on the loss Tensor.
17 | """
18 | @wraps(func)
19 | def cache_func(*args, **kwargs):
20 | rnd_state = RandContext()
21 | with torch.no_grad():
22 | reps_no_grad = func(*args, **kwargs)
23 | if isinstance(reps_no_grad, Tensor):
24 | reps_no_grad = (reps_no_grad, )
25 | else:
26 | assert all(isinstance(v, Tensor) for v in reps_no_grad)
27 | leaf_reps = tuple(t.detach().requires_grad_() for t in reps_no_grad)
28 |
29 | @wraps(func)
30 | def forward_backward_func(cache_reps: Union[Tensor, Tuple[Tensor]]):
31 | with rnd_state:
32 | reps = func(*args, **kwargs)
33 | if isinstance(reps, Tensor):
34 | reps = (reps,)
35 | if isinstance(cache_reps, Tensor):
36 | cache_reps = (cache_reps,)
37 | assert len(reps) == len(cache_reps)
38 |
39 | surrogate = sum(map(lambda u, v: torch.dot(u.flatten(), v.grad.flatten()), reps, cache_reps), 0)
40 | surrogate.backward()
41 |
42 | return leaf_reps + (forward_backward_func,)
43 | return cache_func
44 |
45 |
46 | def _cat_tensor_list(xx):
47 | if isinstance(xx, list) and len(xx) > 0 and all(isinstance(x, Tensor) for x in xx):
48 | return torch.cat(xx)
49 | else:
50 | return xx
51 |
52 |
53 | def cat_input_tensor(func: Callable[..., Tensor]):
54 | """
55 | A decorator that concatenates positional and keyword arguments of type List[Tensor] into a single Tensor
56 | on the 0 dimension. This can come in handy dealing with results of representation tensors from multiple
57 | cached forward.
58 | :param func: A loss function
59 | :return: Decorated loss function for cached results.
60 | """
61 | @wraps(func)
62 | def cat_f(*args, **kwargs):
63 | args_cat = [_cat_tensor_list(x) for x in args]
64 |         kwargs_cat = dict((k, _cat_tensor_list(v)) for k, v in kwargs.items())
65 | return func(*args_cat, **kwargs_cat)
66 | return cat_f
67 |
68 |
69 | def _maybe_gather_tensor(t: Any, axis: int):
70 | if not isinstance(t, Tensor):
71 | return t
72 | gathered = [torch.empty_like(t) for _ in range(dist.get_world_size())]
73 | dist.all_gather(gathered, t)
74 | gathered[dist.get_rank()] = t
75 | return torch.cat(gathered, dim=axis)
76 |
77 |
78 | def gather_input_tensor(func: Callable[..., Tensor], axis=0):
79 | """
80 | A decorator that all-gather positional and keyword arguments of type Tensor and concatenate them on axis.
81 | Intended to be used with distributed contrastive learning loss.
82 | :param func: A loss function
83 | :param axis: The axis the gathered tensors are concatenated.
84 | :return: Decorated loss function for distributed training.
85 | """
86 | @wraps(func)
87 | def f(*args, **kwargs):
88 | args_gathered = [_maybe_gather_tensor(x, axis=axis) for x in args]
89 |         kwargs_gathered = dict((k, _maybe_gather_tensor(v, axis=axis)) for k, v in kwargs.items())
90 | return func(*args_gathered, **kwargs_gathered)
91 | return f
92 |
--------------------------------------------------------------------------------
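
For reference, a minimal sketch of the two-pass workflow these decorators enable, following the pattern described in the `cached` docstring (`encoder`, `x_chunks`, and `y_chunks` are placeholders for a model and pre-chunked input dicts, and the loss shown is a generic in-batch contrastive loss, not this repo's):

```python
import torch
import torch.nn.functional as F

from grad_cache.functional import cached, cat_input_tensor


@cached
def call_model(model, batch):
    # Placeholder encoder call; returns one representation per example.
    return model(**batch).last_hidden_state[:, 0]


@cat_input_tensor
def contrastive_loss(x, y):
    logits = torch.matmul(x, y.transpose(0, 1))
    target = torch.arange(x.size(0), device=x.device)
    return F.cross_entropy(logits, target)


# First pass (no grad): collect detached leaf representations and replay closures.
cache_x, cache_y, closures_x, closures_y = [], [], [], []
for sub_x, sub_y in zip(x_chunks, y_chunks):        # dicts of tensors per sub-batch
    rx, fx = call_model(encoder, sub_x)
    ry, fy = call_model(encoder, sub_y)
    cache_x.append(rx); closures_x.append(fx)
    cache_y.append(ry); closures_y.append(fy)

# Loss over the concatenated caches; backward fills .grad on the leaf representations.
loss = contrastive_loss(cache_x, cache_y)
loss.backward()

# Second pass: replay each sub-batch forward and backpropagate into the model.
for rx, fx in zip(cache_x, closures_x):
    fx(rx)
for ry, fy in zip(cache_y, closures_y):
    fy(ry)
```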
/cp-retrieval-server/templates/stats.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | {{ t.site_name }} - {{ t.view_stats }}
6 |
7 |
8 |
32 |
33 |
34 | {{ t.site_name }} - {{ t.view_stats }}
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 | |
45 | {{ t.total_search_count or "总搜索次数" }}:
46 | {{ stats | map(attribute=1) | sum }}
47 | |
48 |
49 |
50 |
51 | | {{ t.date }} |
52 | {{ t.search_count }} |
53 |
54 |
55 | {% for date, count in stats %}
56 |
57 | | {{ date }} |
58 | {{ count }} |
59 |
60 | {% endfor %}
61 |
62 |
63 |
64 | {{ t.back }}
65 |
66 |
146 |
147 |
148 |
--------------------------------------------------------------------------------
/stage1/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
3 | import argparse
4 | import torch
5 | from datasets import load_dataset
6 | from sentence_transformers import SentenceTransformer
7 | from transformers import TrainingArguments
8 | from loss import InfoNCELoss_gradcache, InfoNCELoss_gradcache_multipos, GroupInfoNCELoss_gradcache_multipos
9 | from dataset import ContrastiveDataset, ContrastiveDataCollator
10 | from custom_trainer import ContrastiveTrainer
11 | from eval_metric import compute_metrics_custom
12 | from dataset_process import process_dataset
13 |
14 | def str2bool(v):
15 | if isinstance(v, bool):
16 | return v
17 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
18 | return True
19 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
20 | return False
21 | else:
22 | raise argparse.ArgumentTypeError('Boolean value expected.')
23 |
24 | parser = argparse.ArgumentParser()
25 |
26 | parser.add_argument("--model_path", type=str,
27 | default="Salesforce/SFR-Embedding-Code-2B_R")
28 | parser.add_argument("--max_length", type=int, default=1024)
29 | parser.add_argument("--chunk_size", type=int, default=2)
30 | parser.add_argument("--multi_pos", type=int, default=16)
31 | parser.add_argument("--per_device_batch_size", type=int, default=128)
32 | parser.add_argument("--num_train_epochs", type=int, default=20)
33 | parser.add_argument("--lr", type=float, default=3e-8)
34 | parser.add_argument("--temperature", type=float, default=0.07)
35 | parser.add_argument("--loss_type", type=str, choices=["GroupInfoNCE", "InfoNCE", "InfoNCE_multipos"],
36 | default="GroupInfoNCE")
37 | parser.add_argument("--use_data_augmentation", type=str2bool, default=True)
38 | parser.add_argument("--eval_only", type=str2bool, default=False)
39 |
40 | if __name__ == "__main__":
41 | args = parser.parse_args()
42 |
43 | model_name = args.model_path.split('/')[-1]
44 |
45 | # Load dataset from HuggingFace
46 | data = load_dataset('coldchair16/CPRet-data', 'PCPCD')
47 | traindata, testdata = process_dataset(data, args.max_length)
48 |
49 | model_kwargs = {"torch_dtype": torch.bfloat16} # TODO: Make sure your GPU and model support bfloat16
50 |
51 | model = SentenceTransformer(args.model_path, trust_remote_code=True, model_kwargs=model_kwargs, device='cpu')
52 | model.tokenizer.model_max_length = args.max_length
53 | model.max_seq_length = args.max_length
54 |
55 | dataset = ContrastiveDataset(
56 | traindata, model.tokenizer, max_length=args.max_length, multi_pos=args.multi_pos,
57 | use_data_augmentation=args.use_data_augmentation,
58 | )
59 | val_dataset = ContrastiveDataset(
60 | testdata, model.tokenizer, max_length=args.max_length, multi_pos=args.multi_pos,
61 | use_data_augmentation=False,
62 | eval_mode=True,
63 | )
64 |
65 | embedding_dim = model.get_sentence_embedding_dimension()
66 | print("Model embedding dimension: ", embedding_dim)
67 |
68 | if args.loss_type == 'GroupInfoNCE':
69 | loss_fn = GroupInfoNCELoss_gradcache_multipos(temperature=args.temperature)
70 | elif args.loss_type == 'InfoNCE':
71 | loss_fn = InfoNCELoss_gradcache(temperature=args.temperature)
72 | elif args.loss_type == 'InfoNCE_multipos':
73 | loss_fn = InfoNCELoss_gradcache_multipos(temperature=args.temperature)
74 |
75 | data_collator = ContrastiveDataCollator(model.tokenizer)
76 |
77 | save_name = args.model_path.split('/')[-1]
78 | save_name += f"-length{args.max_length}"
79 | save_name += f"-bs_per_device{args.per_device_batch_size}*gpus{os.environ['WORLD_SIZE']}"
80 | save_name += f"-epochs{args.num_train_epochs}"
81 | save_name += f"-temperature{args.temperature}"
82 | save_name += f"-lr{args.lr}"
83 | save_name += f"-multi_pos{args.multi_pos}"
84 | save_name += f"-use_data_augmentation_{args.use_data_augmentation}"
85 | save_name += f"-loss{args.loss_type}"
86 |
87 | print(f"Save name: {save_name}")
88 |
89 | training_args = TrainingArguments(
90 | output_dir="./results/" + save_name,
91 | per_device_train_batch_size=args.per_device_batch_size,
92 | num_train_epochs=args.num_train_epochs,
93 | logging_dir="./logs/" + save_name,
94 | logging_steps=1,
95 | eval_strategy="epoch",
96 | # eval_accumulation_steps=5,
97 | save_strategy="epoch",
98 | # save_steps=1,
99 | bf16=True,
100 | dataloader_num_workers=4,
101 | dataloader_drop_last=True,
102 | save_total_limit=1,
103 | remove_unused_columns=False,
104 | learning_rate=args.lr,
105 | weight_decay=0.01,
106 | max_grad_norm=1.0,
107 | gradient_accumulation_steps=1,
108 | lr_scheduler_type="constant",
109 | )
110 |
111 | trainer = ContrastiveTrainer(
112 | multi_pos=args.multi_pos,
113 | chunk_size=args.chunk_size,
114 | model=model,
115 | args=training_args,
116 | train_dataset=dataset,
117 | eval_dataset=val_dataset,
118 | loss_fn=loss_fn,
119 | data_collator=data_collator,
120 | compute_metrics=compute_metrics_custom,
121 | )
122 |
123 | if args.eval_only:
124 | eval_results = trainer.evaluate()
125 | print(eval_results)
126 | else:
127 | trainer.train()
128 |
--------------------------------------------------------------------------------
/stage1/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset, DataLoader
3 | import random
4 | from datasets import load_from_disk
5 | from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding
6 |
7 |
8 | # Define the ContrastiveDataset
9 | class ContrastiveDataset(Dataset):
10 | def __init__(
11 | self,
12 | data, tokenizer,
13 | max_length = 1024,
14 | eval_mode = False,
15 | prompt_question = "",
16 | prompt_code = "",
17 | eval_max_solutions_per_question = 20,
18 | multi_pos = 1,
19 | use_data_augmentation = True,
20 | ):
21 | self.data = data
22 | self.tokenizer = tokenizer
23 | self.max_length = max_length
24 | self.eval_mode = eval_mode
25 | self.prompt_question = prompt_question
26 | self.prompt_code = prompt_code
27 | self.multi_pos = multi_pos
28 | self.use_data_augmentation = use_data_augmentation
29 |
30 | self.eval_max_solutions_per_question = eval_max_solutions_per_question
31 |
32 | self.len = len(self.data)
33 |
34 | # for debug
35 | # if (eval_mode):
36 | # self.len = 8
37 |
38 | def __len__(self):
39 | return self.len
40 |
41 | def __getitem__(self, idx):
42 | if (self.eval_mode or not self.use_data_augmentation):
43 | query = self.data[idx]['question']
44 | else:
45 | query = random.choice([
46 | self.data[idx]['question'],
47 | self.data[idx]['question-main'],
48 | ])
49 |
50 | query = self.prompt_question + query
51 | query_encoding = self.tokenizer([query], return_tensors="pt", padding='max_length', max_length=self.max_length, truncation=True)
52 |
53 | if (self.eval_mode):
54 | n = min(self.eval_max_solutions_per_question, len(self.data[idx]['solutions']))
55 | doc = self.data[idx]['solutions'][:n]
56 | doc = [self.prompt_code + d for d in doc]
57 | if (n < self.eval_max_solutions_per_question):
58 | doc = doc + [""] * (self.eval_max_solutions_per_question - n)
59 | doc_encoding = self.tokenizer(doc, return_tensors="pt", padding='max_length', max_length=self.max_length, truncation=True)
60 | doc_id = torch.tensor(n)
61 | else:
62 | if (self.multi_pos == 1):
63 | doc = random.choice(self.data[idx]['solutions'])
64 | doc = self.prompt_code + doc
65 | doc_encoding = self.tokenizer([doc], return_tensors="pt", padding='max_length', max_length=self.max_length, truncation=True)
66 | doc_encoding['input_ids'] = doc_encoding['input_ids'].squeeze(0)
67 | doc_encoding['attention_mask'] = doc_encoding['attention_mask'].squeeze(0)
68 | doc_id = torch.tensor(1)
69 | else:
70 | n = len(self.data[idx]['solutions'])
71 | if (n < self.multi_pos):
72 | doc = self.data[idx]['solutions'] * (self.multi_pos // n) + self.data[idx]['solutions'][:self.multi_pos % n]
73 | else:
74 | doc = random.sample(self.data[idx]['solutions'], self.multi_pos)
75 | doc = [self.prompt_code + d for d in doc]
76 | doc_encoding = self.tokenizer(doc, return_tensors="pt", padding='max_length', max_length=self.max_length, truncation=True)
77 | doc_id = torch.tensor(n)
78 |
79 | # Return a dictionary with the required fields
80 | return {
81 | # shape (max_length)
82 | 'input_ids': query_encoding['input_ids'].squeeze(0), # Standard field for Trainer
83 | 'attention_mask': query_encoding['attention_mask'].squeeze(0), # Standard field for Trainer
84 | 'id' : torch.tensor(idx),
85 |
86 | # shape : (1, max_length) or (n, max_length)
87 | 'doc_input_ids': doc_encoding['input_ids'], # Additional field for document
88 | 'doc_attention_mask': doc_encoding['attention_mask'], # Additional field for document
89 | 'doc_id': doc_id, # Additional field for evaluation
90 | }
91 |
92 | # Custom DataCollator to handle additional fields
93 | from transformers import DataCollator
94 |
95 | class ContrastiveDataCollator:
96 | def __init__(self, tokenizer, padding=True, truncation=True, max_length=None):
97 | self.tokenizer = tokenizer
98 | self.padding = padding
99 | self.truncation = truncation
100 | self.max_length = max_length
101 |
102 | def __call__(self, features):
103 | batch = {}
104 | if 'input_ids' in features[0]:
105 | batch['input_ids'] = torch.stack([f['input_ids'] for f in features])
106 | if 'attention_mask' in features[0]:
107 | batch['attention_mask'] = torch.stack([f['attention_mask'] for f in features])
108 | if 'id' in features[0]:
109 | batch['id'] = torch.stack([f['id'] for f in features])
110 |
111 | if 'doc_input_ids' in features[0]:
112 | batch['doc_input_ids'] = torch.stack([f['doc_input_ids'] for f in features])
113 | if 'doc_attention_mask' in features[0]:
114 | batch['doc_attention_mask'] = torch.stack([f['doc_attention_mask'] for f in features])
115 | if 'doc_id' in features[0]:
116 | batch['doc_id'] = torch.stack([f['doc_id'] for f in features])
117 | return batch
--------------------------------------------------------------------------------
/stage1/grad_cache/pytorch_lightning/readme.md:
--------------------------------------------------------------------------------
1 | # PL_GradCache
2 |
3 | This is an experimental folder providing an example of using GradCache with PyTorch Lightning (pl), tested on pl version '2.2.0.post0' with multi-GPU and mixed-precision (fp16) training. PyTorch Metric Learning must also be installed for the contrastive loss calculation.
4 |
5 | - [Wandb Logging Experiments for Sanity Test](https://api.wandb.ai/links/xyznlp/nmf8d551)
6 |
7 | ### Installation
8 |
9 | After GradCache is installed, do
10 |
11 | ```
12 | cd GradCache/src/grad_cache/pytorch_lightning
13 | python -m venv plgc
14 | . ./plgc/bin/activate
15 | pip3 install -U pip
16 | pip3 install -r requirements.txt
17 | ```
18 |
19 | ### Reproducing Wandb Experiments
20 |
21 | ```
22 | # 1-gpu
23 | python pl_example.py --gpus 1 --batch_size 16
24 | # 2-gpus
25 | python pl_example.py --gpus 2 --batch_size 8
26 | # 1-gpu, gradcache
27 | python pl_example.py --gpus 1 --batch_size 16 --use_gc --gc_minibatch_size 2
28 | # 2-gpus, gradcache
29 | python pl_example.py --gpus 2 --batch_size 8 --use_gc --gc_minibatch_size 2
30 | ```
31 |
32 | Optionally, enable mixed-precision training with `--precision 16`, or use a different DDP backend with `--ddp_backend {gloo/nccl/etc.}`.
33 |
34 | ### Example
35 |
36 | Run `python pl_example.py` with the following flags.
37 |
38 | * `--use_gc` activates GradCache.
39 | * `--gc_minibatch_size {minibatch_size}` defines the mini-batch size that each GPU actually holds in memory at a time. For example, with `--gpus 2 --batch_size 8 --gc_minibatch_size 2`, the model is trained with an effective batch size of 8 * 2 = 16, and the trainer splits each per-GPU batch (8 samples) into 4 mini-batches of 2 samples each. Setting this to 1 gives the minimal possible GPU memory usage.
40 |
41 | ### Summary
42 |
43 | - Add `pl_gradcache.py` as customized GradCache on PyTorch Lightning.
44 | - Use manual backward in gradcache by calling `lightning_trainer.manual_backward(loss)` instead of using `loss.backward()` (this requires changing gradcache).
45 | - Set gradcache `no_sync_except_last=True` in multi-GPU case.
46 |
47 | ### Changes to the original GradCache
48 |
49 | #### File Change
50 | - `pl_gradcache.py` is the GradCache we will run on PyTorch Lightning (pl) with Distributed Data Parallel (ddp).
51 |
52 | #### Change in Optimization
53 | - In the pl DDP setting, we first need to set `lightning_trainer.automatic_optimization=False` so that we can customize how and when backward is called.
54 | - See [the pl optimization doc](https://lightning.ai/docs/pytorch/stable/common/optimization.html) for implementation details, and make sure to call `self.optimizers()` instead of creating an optimizer by ourselves: assigning `self.optimizer = optimizer` in `self.configure_optimizers()` is not correct, because that only keeps the base PyTorch optimizer, whereas `self.optimizers()` returns pl's wrapper around it. The base optimizer does not have proper access to DDP and logging.
55 | - Then, replace all `loss.backward()` calls in GradCache with `lightning_trainer.manual_backward(loss)`. A minimal sketch of this pattern is shown below.
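
A minimal sketch of the manual-optimization pattern described above (illustrative only; `GCExample`, `encoder`, `loss_fn`, and the batch fields are placeholders, not the repo's actual `pl_example.py` module):

```python
import torch
import lightning.pytorch as pl


class GCExample(pl.LightningModule):
    def __init__(self, encoder, loss_fn):
        super().__init__()
        self.automatic_optimization = False   # we call backward ourselves
        self.encoder = encoder
        self.loss_fn = loss_fn

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()               # pl's optimizer wrapper, not a raw torch optimizer
        opt.zero_grad()
        q = self.encoder(batch["query"])      # hypothetical batch fields
        d = self.encoder(batch["doc"])
        loss = self.loss_fn(q, d)
        self.manual_backward(loss)            # what GradCache's loss.backward() becomes
        opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-5)
```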
56 |
57 | #### Change in GradCache
58 | - Set `no_sync_except_last=True` in the multi-GPU case to avoid unnecessary gradient reduction on every sub-batch except the last one.
59 |
60 | #### If you want to run GradCache in PyTorch Lightning with Multi-GPUs
61 | - In short, you can ignore this part and things will just work. But here are the key changes to the original gradcache that are necessary for this to work.
62 | - We have two options for using gradcache in the pl DDP setting when calling `self.init_gc(scaler, gc_model)`.
63 | - We can set `gc_model=pytorch lightning trainer`.
64 | - PyTorch Lightning would then wrap the base model (transformer) by their implementation of DDP.
65 | - In this case, just set `no_sync_except_last = False`, because lightning will handle gradient sync before `optimizer.step()`.
66 | - Setting `no_sync_except_last = True` in this case does not work, because the base model held by gradcache is the bare transformer, which triggers gradcache's DDP assertion and a `model.no_sync` not-available error.
67 | - Or, we can just change gradcache instead (remove assert DDP and `model.no_sync`).
68 | - The only downside of this approach is that training may take a little longer, because gradient sync is done on the full batch (as PyTorch Lightning expects) instead of only the last minibatch (as gradcache expects). But based on some sanity runs, it is acceptable (less than a 10% runtime increase).
69 | - We can set `gc_model=pytorch lightning trainer.strategy.model`, i.e. the wrapped base model by PyTorch DDP.
70 | - This is tricky as PyTorch Lightning uses a parameter `require_backward_grad_sync` to determine whether gradients would be synced across GPUs.
71 | - First, PyTorch Lightning overrides PyTorch DDP with its own implementation and sets `require_backward_grad_sync=False` before each training step (when `automatic_optimization=False`). It is then set back to True **after** each training step.
72 | - The issue is that gradcache needs gradients to be synced in the last backward step, which happens inside PyTorch Lightning's training-step hook. Thus, the only thing we can do is set this variable manually before the last backward step inside gradcache - we cannot set it outside of gradcache either, because the earlier backwards that gradcache uses for gradient checkpointing should NOT sync gradients (this is essentially the point of gradcache).
73 | - Thus, we do `model.require_backward_grad_sync=True` at the very end of gradcache - before the backward of the last minibatch surrogate.
74 | - The advantage is that we can use `no_sync_except_last` exactly as gradcache intends (no runtime increase). The downside is that we need to modify gradcache in a fairly hacky way. This is the default setup; a rough sketch of the change is shown below.
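
A rough sketch of where this flag is flipped (a fragment for illustration only; `model`, `model_call`, `get_reps`, and `lightning_module` stand in for the corresponding pieces of `pl_gradcache.py`):

```python
# Second forward-backward loop over the cached sub-batches (sketch).
for i, (x, state, gradient) in enumerate(zip(model_inputs, random_states, cached_gradients)):
    if i == len(model_inputs) - 1:
        # Re-enable DDP gradient sync only for the final sub-batch's surrogate backward.
        model.require_backward_grad_sync = True
    with state:
        reps = get_reps(model_call(model, x))
    surrogate = torch.dot(reps.flatten(), gradient.flatten())
    lightning_module.manual_backward(surrogate)
```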
75 |
--------------------------------------------------------------------------------
/TestCases.md:
--------------------------------------------------------------------------------
1 |
2 | # 🔍 Test Cases of CPRet
3 |
4 | ## 📌 System Launch Overview
5 |
6 | * Our open-source **problem retrieval platform** was launched on **May 21, 2025**: https://cpret.online/ (previously: http://1.94.255.218:5000/)
7 | * In less than a week, the platform has recorded nearly **2,000 search queries**.
8 | * The blog post introducing the system on **Codeforces** has received over **250 upvotes** and positive community feedback: https://codeforces.com/blog/entry/143098
9 | * **Use cases** include:
10 |
11 | * **Similar Problem Retrieval**: Assisting contestants in expanding their problem-solving perspective and addressing knowledge blind spots.
12 | * **Duplicate Problem Retrieval**: Helping problem setters identify previously seen ideas or solutions early on.
13 |
14 | ---
15 |
16 | ## 🧪 Test Case 1: Duplicate Retrieval in a Recent Contest
17 |
18 | **Target Contest**
19 |
20 | * **Name**: The 2025 CCPC National Invitational Contest (Northeast), The 19th Northeast Collegiate Programming Contest
21 | * **Url**: https://codeforces.com/gym/105924
22 | * **Date**: May 25, 2025
23 | * **Total Problems**: 12
24 |
25 | **Key Findings**
26 |
27 | * The system successfully identified **6 problems** with highly similar or identical historical counterparts.
28 | * Only the **top 3 results per query** were manually inspected, indicating a **duplicate rate of at least 50%**—likely higher in practice.
29 | * Under current scoring rules, solving these 6 problems is enough to secure a **silver medal**, and just one additional easy problem would result in a **gold medal**, raising **serious fairness concerns**.
30 |
31 | **Detection Summary**
32 |
33 | | Contest Problem | Matched Historical Problem | Similarity Level | Rank |
34 | | ------------------------------------------------------------------------ | -------------------------------------------------------------------------------- | ---------------- | ---- |
35 | | [A. GD Ultimate Rhythm Lab](https://codeforces.com/gym/105924/problem/A) | [Nowcoder - 小睿睿的数列](https://ac.nowcoder.com/acm/problem/24479) | Same approach | 1 |
36 | | [D. Defend the Carrot](https://codeforces.com/gym/105924/problem/D) | [SPOJ - UOFTBB](https://www.spoj.com/problems/UOFTBB/) | Almost identical | 1 |
37 | | [E. Tree Edge Removal](https://codeforces.com/gym/105924/problem/E) | [Luogu - \[JRKSJ R7\] 茎](https://www.luogu.com.cn/problem/P8935) | Almost identical | 1 |
38 | | [F. Youthful Oath II](https://codeforces.com/gym/105924/problem/F) | [Codeforces - 80B Depression](https://codeforces.com/problemset/problem/80/B) | Almost identical | 1 |
39 | | [J. Kingdom: Memories](https://codeforces.com/gym/105924/problem/J) | [AtCoder - R Walk](https://atcoder.jp/contests/dp/tasks/dp_r) | Almost identical | 3 |
40 | | [L. Bathhouse](https://codeforces.com/gym/105924/problem/L) | [Codeforces - 219E Parking Lot](https://codeforces.com/problemset/problem/219/E) | Same approach | 2 |
41 |
42 | ---
43 |
44 | ## 🧪 Test Case 2: Similar Problem Retrieval – MEX Variants
45 |
46 | We conducted a query with the classic problem "**interval MEX**" to identify its **variants across different contests**, aiming to showcase the system’s utility for **idea expansion and knowledge transfer**.
47 |
48 | ### Query Problem Description
49 |
50 | > Given a sequence of $n$ natural numbers $a[1..n]$, answer $m$ queries.
51 | > Each query specifies a range $[l, r]$ and asks for $\mathrm{mex}(\{a_l, a_{l+1}, \dots, a_r\})$ —
52 | > the **minimum excluded value** in the subarray.
53 | >
54 | > This problem can be solved in **$O((n + m) \log n)$** time using segment trees.
55 |
56 | ### Retrieval Results
57 |
58 | | Rank | Title | Description |
59 | | ---- | ---------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------- |
60 | | 1 | [Luogu P4137: RMQ Problem / mex](https://www.luogu.com.cn/problem/P4137) | Original problem |
61 | | 2 | [LOJ 6908: THUPC 2024 Prelim - "Matryoshka"](https://loj.ac/p/6908) | MEX of all subarrays of length $k$, then take the MEX of those |
62 | | 5 | [AtCoder ABC194E: Mex Min](https://atcoder.jp/contests/abc194/tasks/abc194_e) | MEX of all subarrays of length $k$, then take the **minimum** |
63 | | 6 | [Luogu P10032: Mex of Sequence](https://www.luogu.com.cn/problem/P10032) | Repeated operations: $a'[i] = \mathrm{mex}(a \setminus a[i])$ |
64 | | 11 | [Nowcoder 237670: “经典”问题](https://ac.nowcoder.com/acm/problem/237670) | MEX queries on permutations, optimized to $O(n + m)$ |
65 | | 14 | [Luogu P8087: 『JROI-5』Interval](https://www.luogu.com.cn/problem/P8087) | MEX of all subarrays of length $k$, then take the **maximum** |
66 | | 15 | [AtCoder ABC290C: Max MEX](https://atcoder.jp/contests/abc290/tasks/abc290_c) | MEX of all **subsequences** of length $k$, then take the **minimum** |
67 | | 16 | [Codeforces 1436E: Complicated Computations](https://codeforces.com/problemset/problem/1436/E) | MEX of all subarrays, then take the MEX again |
68 | | 23 | [AtCoder ABC330E: Mex and Update](https://atcoder.jp/contests/abc330/tasks/abc330_e) | Support element modification or querying full array MEX |
69 | | 24 | [Luogu P11837: Making Mexes B](https://www.luogu.com.cn/problem/P11837) | Minimum edits to ensure $\mathrm{mex}(a) = i$ |
70 |
--------------------------------------------------------------------------------
/stage1/grad_cache_custom.py:
--------------------------------------------------------------------------------
1 | from grad_cache import GradCache
2 | from typing import List, Union, Callable, Any
3 | from contextlib import nullcontext
4 | import torch
5 | from torch import nn, Tensor
6 | from torch.cuda.amp import GradScaler, autocast
7 | from grad_cache.context_managers import RandContext
8 |
9 | class GradCacheCustom(GradCache):
10 | def __init__(
11 | self,
12 | multi_pos : int = 1,
13 | backward_fn = None,
14 | compute_loss_context_manager = None,
15 | *args,
16 | **kwargs,
17 | ):
18 | super().__init__(*args, **kwargs)
19 | self.multi_pos = multi_pos
20 | self.backward_fn = backward_fn
21 |         self.compute_loss_context_manager = compute_loss_context_manager
22 | self.get_rep_fn = lambda x: x['sentence_embedding']
23 |
24 | def model_call(self, model: nn.Module, model_input):
25 | """
26 | Literally call the model's __call__ method.
27 | :param model: model to be called
28 | :param model_input: input to the model call
29 | :return: model output
30 | """
31 | with self.compute_loss_context_manager():
32 | return model(model_input)
33 |
34 | def compute_loss(self, *reps: Tensor, **loss_kwargs) -> Tensor:
35 | """
36 | Compute the loss based on the representation tensors. The tensors should be ordered same as the list of models
37 | registered in this GradCache class instance.
38 | :param reps: Representations for computing the loss.
39 | :param loss_kwargs: Keyword arguments input to the loss function.
40 | :return: the loss tensor.
41 | """
42 | with self.compute_loss_context_manager():
43 | loss = self.loss_fn(*reps, **loss_kwargs)
44 | return loss
45 |
46 | def build_cache(self, *reps: Tensor, **loss_kwargs) -> [List[Tensor], Tensor]:
47 | """
48 | Compute the gradient cache
49 | :param reps: Computed representations from all encoder models
50 | :param loss_kwargs: Extra keyword arguments to the loss function
51 | :return: A tuple of a) gradient cache for each encoder model, and b) loss tensor
52 | """
53 | # for r in reps:
54 | # print(f"r.shape : {r.shape}")
55 | reps = [r.detach().requires_grad_() for r in reps]
56 | with autocast() if self.fp16 else nullcontext():
57 | loss = self.compute_loss(*reps, **loss_kwargs)
58 |
59 | if self.fp16:
60 | self.backward_fn(self.scaler.scale(loss))
61 | else:
62 | self.backward_fn(loss)
63 |
64 | cache = [r.grad for r in reps]
65 |
66 | return cache, loss.detach()
67 |
68 | def forward_backward(
69 | self,
70 | model: nn.Module,
71 | model_inputs,
72 | cached_gradients: List[Tensor],
73 | random_states: List[RandContext],
74 | no_sync_except_last: bool = False
75 | ):
76 | """
77 | Run the second forward and the backward pass to compute gradient for a model.
78 | :param model: Encoder model.
79 | :param model_inputs: Chunked input to the encoder model.
80 | :param cached_gradients: Chunked gradient cache tensor for each input.
81 | :param random_states: Each input's device random state during the first forward.
82 | :param no_sync_except_last: If True, under distributed setup, only trigger gradient reduction across processes
83 | for the last sub-batch's forward-backward pass.
84 | """
85 | if no_sync_except_last:
86 | sync_contexts = [model.no_sync for _ in range(len(model_inputs) - 1)] + [nullcontext]
87 | else:
88 | sync_contexts = [nullcontext for _ in range(len(model_inputs))]
89 |
90 | for x, state, gradient, sync_context in zip(model_inputs, random_states, cached_gradients, sync_contexts):
91 | with sync_context():
92 | with state:
93 | y = self.model_call(model, x)
94 | reps = self.get_reps(y)
95 |
96 | surrogate = torch.dot(reps.flatten(), gradient.flatten())
97 | self.backward_fn(surrogate)
98 |
99 | def cache_step(
100 | self,
101 | *model_inputs,
102 | no_sync_except_last: bool = False,
103 | **loss_kwargs
104 | ) -> Tensor:
105 | """
106 | Run a cached step to compute gradient over the inputs.
107 | :param model_inputs: Input to each encoder model. Should be in similar order as the class's model.
108 | :param no_sync_except_last: If True, under distributed setup, for each model, only trigger gradient reduction
109 | across processes for the last sub-batch's forward-backward pass.
110 | :param loss_kwargs: Additional keyword arguments to the loss function.
111 | :return: The current's loss.
112 | """
113 | all_reps = []
114 | all_rnd_states = []
115 |
116 | if no_sync_except_last:
117 | assert all(map(lambda m: isinstance(m, nn.parallel.DistributedDataParallel), self.models)), \
118 | 'Some of models are not wrapped in DistributedDataParallel. Make sure you are running DDP with ' \
119 | 'proper initializations.'
120 |
121 | model_inputs = model_inputs[0]
122 | model_inputs = [
123 | {
124 | 'input_ids': model_inputs['input_ids'],
125 | 'attention_mask': model_inputs['attention_mask'],
126 | },
127 | {
128 | 'input_ids': model_inputs['doc_input_ids'],
129 | 'attention_mask': model_inputs['doc_attention_mask'],
130 | }
131 | ]
132 |
133 | if (self.multi_pos > 1):
134 | for k in ['input_ids', 'attention_mask']:
135 | model_inputs[1][k] = model_inputs[1][k].flatten(0, 1)
136 |
137 | model_inputs = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(model_inputs, self.chunk_sizes)]
138 | # print('model_inputs : ', model_inputs)
139 |
140 | for model, x in zip(self.models, model_inputs):
141 | model_reps, rnd_states = self.forward_no_grad(model, x)
142 | all_reps.append(model_reps)
143 | all_rnd_states.append(rnd_states)
144 |
145 | cache, loss = self.build_cache(*all_reps, **loss_kwargs)
146 | cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)]
147 |
148 | for model, x, model_cache, rnd_states in zip(
149 | self.models, model_inputs, cache, all_rnd_states):
150 | self.forward_backward(model, x, model_cache, rnd_states, no_sync_except_last=no_sync_except_last)
151 |
152 | return loss
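
# Minimal self-contained sketch (illustration only, not used by the trainer): it shows why the
# two-pass gradient-cache procedure above reproduces ordinary full-batch gradients, using a toy
# nn.Linear encoder and a toy squared-norm loss in place of the real models and contrastive loss.
if __name__ == "__main__":
    torch.manual_seed(0)
    enc = nn.Linear(4, 3)
    x = torch.randn(8, 4)

    # Reference: ordinary backprop over the full batch.
    enc(x).pow(2).sum().backward()
    ref_grad = enc.weight.grad.clone()
    enc.zero_grad()

    # Grad-cache style: (1) no-grad forward, (2) gradient w.r.t. the representations,
    # (3) chunked re-forward, backpropagating the surrogate dot product through the encoder.
    with torch.no_grad():
        reps = enc(x)
    reps = reps.detach().requires_grad_()
    reps.pow(2).sum().backward()
    cached = reps.grad.split(4)  # two chunks of size 4
    for xc, gc in zip(x.split(4), cached):
        torch.dot(enc(xc).flatten(), gc.flatten()).backward()
    assert torch.allclose(enc.weight.grad, ref_grad, atol=1e-6)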
--------------------------------------------------------------------------------
/stage1/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import Any, Dict
5 |
6 | class InfoNCELoss_gradcache(nn.Module):
7 | def __init__(self, temperature: float = 0.07):
8 | super(InfoNCELoss_gradcache, self).__init__()
9 | self.temperature = temperature
10 |
11 | def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, **loss_kwargs) -> torch.Tensor:
12 | # Convert bf16 to fp32 for numerical stability
13 | query_embeddings = query_embeddings.float()
14 | doc_embeddings = doc_embeddings.float()
15 |
16 | # Gather embeddings across distributed processes
17 | query_embeddings = self.gather_tensor(query_embeddings)
18 | doc_embeddings = self.gather_tensor(doc_embeddings)
19 |
20 | # Normalize the embeddings to unit vectors
21 | query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
22 | doc_embeddings = F.normalize(doc_embeddings, p=2, dim=1)
23 |
24 | return self._compute_loss(query_embeddings, doc_embeddings)
25 |
26 | def _compute_loss(self, query_embeddings, doc_embeddings):
27 | # Positive similarity (dot product between aligned query-doc pairs)
28 | sim_u = torch.sum(query_embeddings * doc_embeddings, dim=1) / self.temperature
29 | sim_u = torch.exp(sim_u)
30 |
31 | # Similarities with all queries and docs (positive + negatives)
32 | sim_v = torch.sum(torch.exp(torch.matmul(query_embeddings, doc_embeddings.T) / self.temperature), dim=1) + \
33 | torch.sum(torch.exp(torch.matmul(query_embeddings, query_embeddings.T) / self.temperature), dim=1)
34 | # Remove self-similarity term
35 | sim_v = sim_v - torch.exp(torch.tensor(1 / self.temperature, device=sim_v.device))
36 |
37 | return -torch.log(sim_u / sim_v).mean()
38 |
39 | def gather_tensor(self, t):
40 | # All-gather operation for distributed training
41 | t_new = torch.distributed.nn.all_gather(t)
42 | t_new = torch.cat(t_new, dim=0)
43 | return t_new
44 |
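# For reference, with L2-normalized embeddings and temperature T the loss above is, per query i:
#   L_i = -log( exp(q_i·d_i / T) / ( sum_j exp(q_i·d_j / T) + sum_{j!=i} exp(q_i·q_j / T) ) )
# i.e. every other document and every other query in the gathered batch acts as a negative.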
45 |
46 | # Version that uses sum of positive similarities in numerator
47 | class InfoNCELoss_gradcache_multipos(InfoNCELoss_gradcache):
48 | def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, **loss_kwargs) -> torch.Tensor:
49 | query_embeddings = query_embeddings.float()
50 | doc_embeddings = doc_embeddings.float()
51 |
52 | query_embeddings = self.gather_tensor(query_embeddings)
53 | doc_embeddings = self.gather_tensor(doc_embeddings)
54 |
55 | # Reshape to [batch_size, num_pos, dim]
56 | doc_embeddings = doc_embeddings.reshape(query_embeddings.shape[0], -1, query_embeddings.shape[1])
57 |
58 | query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
59 | doc_embeddings = F.normalize(doc_embeddings, p=2, dim=2)
60 |
61 | sim_mat = torch.exp(query_embeddings @ query_embeddings.T / self.temperature)
62 | sim_v = torch.sum(sim_mat, dim=1) - torch.exp(torch.tensor(float(1.0) / self.temperature, device=sim_mat.device))
63 |
64 | sim_w = torch.einsum("ik, jnk -> ijn", query_embeddings, doc_embeddings) / self.temperature
65 | sim_w = torch.exp(sim_w) # shape: [bs, bs, pos]
66 | sim_w = torch.sum(sim_w, dim=2) # shape: [bs, bs]
67 | sim_u = sim_w.diagonal() # positive similarities on the diagonal
68 | sim_w = torch.sum(sim_w, dim=1) # total similarities
69 |
70 | return -torch.log(sim_u / (sim_w + sim_v)).mean()
71 |
72 |
73 | # Version that separates query-doc and query-query terms (reuses the class name above, so this later definition shadows the previous one at import time)
74 | class InfoNCELoss_gradcache_multipos(InfoNCELoss_gradcache):
75 | def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, **loss_kwargs) -> torch.Tensor:
76 | query_embeddings = query_embeddings.float()
77 | doc_embeddings = doc_embeddings.float()
78 |
79 | query_embeddings = self.gather_tensor(query_embeddings)
80 | doc_embeddings = self.gather_tensor(doc_embeddings)
81 |
82 | doc_embeddings = doc_embeddings.reshape(query_embeddings.shape[0], -1, query_embeddings.shape[1])
83 |
84 | query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
85 | doc_embeddings = F.normalize(doc_embeddings, p=2, dim=2)
86 |
87 | sim_mat = torch.exp(query_embeddings @ query_embeddings.T / self.temperature)
88 | sim_v = torch.sum(sim_mat, dim=1) - torch.exp(torch.tensor(float(1.0) / self.temperature, device=sim_mat.device))
89 |
90 | # Positive similarities: shape [bs, pos]
91 | sim_a = torch.exp(torch.sum(query_embeddings.unsqueeze(1) * doc_embeddings, dim=2) / self.temperature)
92 |
93 | sim_w = torch.einsum("ik, jnk -> ijn", query_embeddings, doc_embeddings) / self.temperature
94 | sim_w = torch.exp(sim_w) # shape: [bs, bs, pos]
95 |
96 | sim_p = torch.sum(sim_w, dim=2) # [bs, bs]
97 | sim_u_origin = sim_p.diagonal()
98 | sim_q = torch.sum(sim_p, dim=1)
99 |
100 | return torch.mean(-torch.log(sim_a / ((sim_q + sim_v - sim_u_origin).unsqueeze(1) + sim_a)))
101 |
102 |
103 | # Version that implements Group-InfoNCE with regularization
104 | class GroupInfoNCELoss_gradcache_multipos(InfoNCELoss_gradcache):
105 | def forward(self, query_embeddings: torch.Tensor, doc_embeddings: torch.Tensor, **loss_kwargs) -> torch.Tensor:
106 | query_embeddings = query_embeddings.float()
107 | doc_embeddings = doc_embeddings.float()
108 |
109 | query_embeddings = self.gather_tensor(query_embeddings)
110 | doc_embeddings = self.gather_tensor(doc_embeddings)
111 |
112 | doc_embeddings = doc_embeddings.reshape(query_embeddings.shape[0], -1, query_embeddings.shape[1])
113 |
114 | query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
115 | doc_embeddings = F.normalize(doc_embeddings, p=2, dim=2)
116 |
117 | # sim_a: similarity between each query and its group of positive docs
118 | sim_a = torch.einsum('ik, ijk -> ij', query_embeddings, doc_embeddings)
119 | loss_penalty = torch.var(sim_a, dim=1).mean() # penalty: variance within group
120 |
121 | sim_b = torch.mean(sim_a, dim=1)
122 | sim_b = torch.exp(sim_b / self.temperature)
123 |
124 | # query-to-query similarity matrix (excluding self-similarity)
125 | sim_p = torch.einsum('ik, jk -> ij', query_embeddings, query_embeddings)
126 | sim_p = torch.exp(sim_p / self.temperature)
127 | sim_p = torch.sum(sim_p, dim=1) - torch.exp(torch.tensor(1 / self.temperature, device=sim_p.device))
128 |
129 | # query-to-all-doc-groups similarity
130 | sim_q = torch.einsum('ik, jpk -> ijp', query_embeddings, doc_embeddings)
131 | sim_q = torch.mean(sim_q, dim=2) # [bs, bs]
132 | sim_q = torch.exp(sim_q / self.temperature)
133 | sim_q = torch.sum(sim_q, dim=1)
134 |
135 | loss = -torch.log(sim_b / (sim_p + sim_q)).mean()
136 |
137 | if torch.distributed.get_rank() == 0:
138 | print(f"\nloss_penalty: {loss_penalty} / T^2 = {loss_penalty / self.temperature ** 2}")
139 | print(f"loss: {loss}")
140 |
141 | return loss + loss_penalty / self.temperature ** 2
142 |
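# Minimal sanity-check sketch (illustration only, not used in training): verifies the
# single-positive InfoNCE implemented at the top of this file against a direct per-query
# computation on a tiny random batch. It calls _compute_loss directly so that no
# torch.distributed all_gather is required.
if __name__ == "__main__":
    torch.manual_seed(0)
    q = F.normalize(torch.randn(4, 8), dim=1)
    d = F.normalize(torch.randn(4, 8), dim=1)
    loss_fn = InfoNCELoss_gradcache(temperature=0.07)
    fast = loss_fn._compute_loss(q, d)

    t = loss_fn.temperature
    ref = 0.0
    for i in range(q.shape[0]):
        num = torch.exp(q[i] @ d[i] / t)
        den = (torch.exp(q[i] @ d.T / t).sum()
               + torch.exp(q[i] @ q.T / t).sum()
               - torch.exp(torch.tensor(1.0 / t)))
        ref = ref - torch.log(num / den)
    ref = ref / q.shape[0]
    assert torch.allclose(fast, ref, atol=1e-4)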
--------------------------------------------------------------------------------
/stage1/eval_metric.py:
--------------------------------------------------------------------------------
1 | import heapq
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from sentence_transformers.evaluation import InformationRetrievalEvaluator
6 | from sentence_transformers.similarity_functions import SimilarityFunction
7 | from typing import TYPE_CHECKING, Callable
8 | from torch import Tensor
9 | from tqdm import trange
10 |
11 | # Modified InformationRetrievalEvaluator to support directly passing in embeddings
12 | class CustomRetrievalEvaluator(InformationRetrievalEvaluator):
13 | def __init__(
14 | self,
15 | corpus_chunk_size: int = 50000,
16 | mrr_at_k: list[int] = [10],
17 | ndcg_at_k: list[int] = [10, 100],
18 | accuracy_at_k: list[int] = [1, 3, 5, 10],
19 | precision_recall_at_k: list[int] = [1, 3, 5, 10],
20 | map_at_k: list[int] = [10, 100],
21 | show_progress_bar: bool = False,
22 | batch_size: int = 32,
23 | name: str = "",
24 | write_csv: bool = True,
25 | truncate_dim: int | None = None,
26 | score_functions: dict[str, Callable[[Tensor, Tensor], Tensor]] | None = {
27 | "cos": SimilarityFunction.to_similarity_fn("cosine")
28 | },
29 | main_score_function: str | SimilarityFunction | None = None,
30 | ) -> None:
31 |         # Skip InformationRetrievalEvaluator.__init__ (its signature differs) and call its parent's __init__ instead
32 |         super(InformationRetrievalEvaluator, self).__init__()
33 | self.corpus_chunk_size = corpus_chunk_size
34 | self.mrr_at_k = mrr_at_k
35 | self.ndcg_at_k = ndcg_at_k
36 | self.accuracy_at_k = accuracy_at_k
37 | self.precision_recall_at_k = precision_recall_at_k
38 | self.map_at_k = map_at_k
39 |
40 | self.show_progress_bar = show_progress_bar
41 | self.batch_size = batch_size
42 | self.name = name
43 | self.write_csv = write_csv
44 | self.score_functions = score_functions
45 | self.score_function_names = sorted(list(self.score_functions.keys())) if score_functions else []
46 | self.main_score_function = SimilarityFunction(main_score_function) if main_score_function else None
47 | self.truncate_dim = truncate_dim
48 |
49 | if name:
50 | name = "_" + name
51 |
52 | self.csv_file: str = "Information-Retrieval_evaluation" + name + "_results.csv"
53 | self.csv_headers = ["epoch", "steps"]
54 |
55 | self._append_csv_headers(self.score_function_names)
56 |
57 | def calc_results(self, query_embeddings, doc_embeddings, r_docs):
58 | # Construct internal ID mappings and relevance ground truth
59 | self.queries_ids = [f"q{i}" for i in range(len(query_embeddings))]
60 | self.corpus_ids = [f"c{i}" for i in range(len(doc_embeddings))]
61 | self.relevant_docs = {
62 | f"q{i}": set([f"c{j}" for j in r_docs[i]]) for i in range(len(r_docs))
63 | }
64 | self.queries = [""] * len(query_embeddings)
65 | self.corpus = [""] * len(doc_embeddings)
66 |
67 | # Determine the largest k needed
68 | max_k = max(
69 | max(self.mrr_at_k),
70 | max(self.ndcg_at_k),
71 | max(self.accuracy_at_k),
72 | max(self.precision_recall_at_k),
73 | max(self.map_at_k),
74 | )
75 |
76 | # Initialize result storage
77 | queries_result_list = {
78 | name: [[] for _ in range(len(query_embeddings))]
79 | for name in self.score_functions
80 | }
81 |
82 | # Process corpus in chunks to avoid memory overflow
83 | for corpus_start_idx in trange(
84 | 0, len(self.corpus), self.corpus_chunk_size, desc="Corpus Chunks", disable=not self.show_progress_bar
85 | ):
86 | corpus_end_idx = min(corpus_start_idx + self.corpus_chunk_size, len(self.corpus))
87 | sub_corpus_embeddings = doc_embeddings[corpus_start_idx:corpus_end_idx]
88 |
89 | for name, score_function in self.score_functions.items():
90 | # Compute similarity scores
91 | pair_scores = score_function(query_embeddings, sub_corpus_embeddings)
92 |
93 | # Top-k scores and indices
94 | pair_scores_top_k_values, pair_scores_top_k_idx = torch.topk(
95 | pair_scores, min(max_k, len(pair_scores[0])), dim=1, largest=True, sorted=False
96 | )
97 | pair_scores_top_k_values = pair_scores_top_k_values.cpu().tolist()
98 | pair_scores_top_k_idx = pair_scores_top_k_idx.cpu().tolist()
99 |
100 | for query_itr in range(len(query_embeddings)):
101 | for sub_corpus_id, score in zip(
102 | pair_scores_top_k_idx[query_itr], pair_scores_top_k_values[query_itr]
103 | ):
104 | corpus_id = self.corpus_ids[corpus_start_idx + sub_corpus_id]
105 | # Do not skip query==corpus as some setups use identical indexing
106 | if len(queries_result_list[name][query_itr]) < max_k:
107 | heapq.heappush(queries_result_list[name][query_itr], (score, corpus_id))
108 | else:
109 | heapq.heappushpop(queries_result_list[name][query_itr], (score, corpus_id))
110 |
111 | # Convert heap tuples to result dicts
112 | for name in queries_result_list:
113 | for query_itr in range(len(queries_result_list[name])):
114 | queries_result_list[name][query_itr] = [
115 | {"corpus_id": corpus_id, "score": score}
116 | for score, corpus_id in queries_result_list[name][query_itr]
117 | ]
118 |
119 | # Compute standard IR metrics
120 | scores = {
121 | name: self.compute_metrics(queries_result_list[name])
122 | for name in self.score_functions
123 | }
124 |
125 | # Compute average cosine similarity between each query and its relevant docs
126 | for name, score_func in self.score_functions.items():
127 | v_list = []
128 | for i in range(len(query_embeddings)):
129 | score_v = score_func(query_embeddings[i].reshape(1, -1), doc_embeddings[r_docs[i]])
130 | v_list.append(score_v.mean().item())
131 | v_list = np.array(v_list)
132 |             scores[name][f"{name}_mean"] = np.mean(v_list)
133 |             p_list = [10, 50, 90]
134 |             scores[name][f"{name}_percentile@k"] = {p: np.percentile(v_list, p) for p in p_list}
135 |
136 | return scores
137 |
138 |
139 | def compute_metrics_custom(p):
140 | # p is an EvalPrediction object containing predictions and label_ids
141 | query_embeddings, doc_embeddings, ids, doc_ids = p.predictions
142 |
143 | n = len(ids)
144 |
145 | doc_embeddings_flatten = []
146 | rdocs = []
147 | tot = 0
148 | for i in range(n):
149 | num = doc_ids[i] # number of documents for query i
150 | rdocs.append(list(range(tot, tot + num))) # relevant doc indices for query i
151 | for j in range(num):
152 | doc_embeddings_flatten.append(doc_embeddings[i][j])
153 | tot += num
154 |
155 | doc_embeddings_flatten = np.array(doc_embeddings_flatten)
156 |
157 | evaluator = CustomRetrievalEvaluator()
158 | results = evaluator.calc_results(query_embeddings, doc_embeddings_flatten, rdocs)["cos"]
159 |
160 | # Flatten nested results into a flat dictionary
161 | new_results = {}
162 | for k in results.keys():
163 | if isinstance(results[k], dict):
164 | for kk in results[k].keys():
165 | new_results[f"{k}_{kk}"] = results[k][kk]
166 | else:
167 | new_results[k] = results[k]
168 |
169 | # Log results only on rank 0
170 | if torch.distributed.get_rank() == 0:
171 | print(results)
172 |
173 | return new_results
174 |
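# Usage sketch (illustration with random data, not part of the pipeline). It assumes a
# sentence-transformers version whose InformationRetrievalEvaluator.compute_metrics accepts a
# per-query result list, which is what calc_results relies on above.
if __name__ == "__main__":
    torch.manual_seed(0)
    q = F.normalize(torch.randn(5, 16), dim=1)    # 5 query embeddings
    d = F.normalize(torch.randn(12, 16), dim=1)   # 12 corpus embeddings
    rdocs = [[i] for i in range(5)]               # query i is relevant to corpus doc i
    evaluator = CustomRetrievalEvaluator()
    print(evaluator.calc_results(q, d, rdocs)["cos"])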
--------------------------------------------------------------------------------
/stage1/custom_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from transformers import Trainer, TrainingArguments
6 | import torch.distributed.nn
7 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
8 | from trainer_headers import *
9 | from grad_cache_custom import GradCacheCustom
11 |
12 | # grad_cache version
13 | class ContrastiveTrainer(Trainer):
14 | def __init__(
15 | self,
16 | loss_fn=None,
17 | multi_pos=1,
18 | chunk_size=8,
19 | *args,
20 | **kwargs,
21 | ):
22 | super().__init__(*args, **kwargs)
23 | self.loss_fn = loss_fn
24 | self.grad_cache = GradCacheCustom(
25 | models=[self.model, self.model],
26 | chunk_sizes=[chunk_size] * 2,
27 | loss_fn=self.loss_fn,
28 | multi_pos=multi_pos,
29 | )
30 |
31 | def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
32 | # Compute sentence embeddings for query and document
33 | query_embeddings = model.forward(
34 | {
35 | 'input_ids': inputs['input_ids'],
36 | 'attention_mask': inputs['attention_mask'],
37 | }
38 | )['sentence_embedding']
39 | doc_embeddings = model.forward(
40 | {
41 | 'input_ids': inputs['doc_input_ids'],
42 | 'attention_mask': inputs['doc_attention_mask'],
43 | }
44 | )['sentence_embedding']
45 |
46 | # Gather embeddings across all processes for contrastive loss
47 | query_embeddings_new = torch.distributed.nn.all_gather(query_embeddings)
48 | doc_embeddings_new = torch.distributed.nn.all_gather(doc_embeddings)
49 | query_embeddings_new = torch.cat(query_embeddings_new, dim=0)
50 | doc_embeddings_new = torch.cat(doc_embeddings_new, dim=0)
51 |
52 | # Compute the contrastive loss
53 | loss = self.loss_fn(query_embeddings_new, doc_embeddings_new)
54 |
55 | # Optionally return intermediate outputs
56 | return (loss, (query_embeddings, doc_embeddings)) if return_outputs else loss
57 |
58 | # Evaluation step
59 | def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
60 | # Get the device of the model
61 | device = next(model.parameters()).device
62 |
63 | # Move inputs to the model's device
64 | inputs = {key: value.to(device) for key, value in inputs.items()}
65 |
66 | # Disable gradient computation
67 | with torch.no_grad():
68 | # Compute query embedding
69 | query_embeddings = model.forward(
70 | {
71 | 'input_ids': inputs['input_ids'],
72 | 'attention_mask': inputs['attention_mask'],
73 | }
74 | )['sentence_embedding']
75 |
76 | n = inputs['doc_input_ids'].shape[1]
77 | doc_embeddings = None
78 |
79 | # Loop through each document in the list (dimension 1)
80 | for i in range(0, n):
81 | doc_embedding = model.forward(
82 | {
83 | 'input_ids': inputs['doc_input_ids'][:, i],
84 | 'attention_mask': inputs['doc_attention_mask'][:, i],
85 | }
86 | )['sentence_embedding'].unsqueeze(1)
87 | doc_embeddings = doc_embedding if doc_embeddings is None else torch.cat((doc_embeddings, doc_embedding), dim=1)
88 |
89 | # Return output format: (loss, predictions, labels)
90 | if prediction_loss_only:
91 | return (None, None, None)
92 |
93 | return None, \
94 | (query_embeddings, doc_embeddings, inputs['id'], inputs['doc_id']), \
95 | torch.zeros_like(inputs['id'])
96 |
97 | def training_step(
98 | self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], num_items_in_batch=None
99 | ) -> torch.Tensor:
100 | """
101 | Perform a training step with support for gradient caching.
102 |
103 | Args:
104 | model (nn.Module): The model being trained.
105 | inputs (dict): A batch of input data.
106 | num_items_in_batch (optional): The number of training examples in the batch.
107 |
108 | Returns:
109 | torch.Tensor: The computed training loss.
110 | """
111 | model.train()
112 | if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
113 | self.optimizer.train()
114 |
115 | inputs = self._prepare_inputs(inputs)
116 |
117 | # SageMaker model parallelism path
118 | if is_sagemaker_mp_enabled():
119 | loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps)
120 | return loss_mb.reduce_mean().detach().to(self.args.device)
121 |
122 | self.grad_cache.compute_loss_context_manager = self.compute_loss_context_manager
123 |
124 | # Empty cache conditionally at configured steps
125 | if (
126 | self.args.torch_empty_cache_steps is not None
127 | and self.state.global_step % self.args.torch_empty_cache_steps == 0
128 | ):
129 | if is_torch_xpu_available():
130 | torch.xpu.empty_cache()
131 | elif is_torch_mlu_available():
132 | torch.mlu.empty_cache()
133 | elif is_torch_musa_available():
134 | torch.musa.empty_cache()
135 | elif is_torch_npu_available():
136 | torch.npu.empty_cache()
137 | elif is_torch_mps_available(min_version="2.0"):
138 | torch.mps.empty_cache()
139 | else:
140 | torch.cuda.empty_cache()
141 |
142 | kwargs = {}
143 |
144 | # Handle special optimizer requirements
145 | if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
146 | kwargs["learning_rate"] = self._get_learning_rate()
147 |
148 |         # Unlike the stock Trainer, no loss has been computed at this point; it is produced
149 |         # below by grad_cache.cache_step, so the usual DataParallel (n_gpu > 1) mean-reduction
150 |         # is intentionally omitted here.
151 |
152 | # Setup backward function depending on backend
153 | if self.use_apex:
154 | def apex_backward(loss):
155 | with amp.scale_loss(loss, self.optimizer) as scaled_loss:
156 | scaled_loss.backward()
157 | self.grad_cache.backward_fn = apex_backward
158 | else:
159 | if self.accelerator.distributed_type == DistributedType.DEEPSPEED:
160 | kwargs["scale_wrt_gas"] = False
161 |
162 | def accelerator_backward(loss):
163 | # Normalize loss across accumulation steps unless the model handles this internally
164 | if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
165 | loss = loss / self.args.gradient_accumulation_steps
166 | self.accelerator.backward(loss, **kwargs)
167 | self.grad_cache.backward_fn = accelerator_backward
168 |
169 | # Use gradient caching to compute and apply gradients
170 | return self.grad_cache.cache_step(inputs, num_items_in_batch=num_items_in_batch).detach()
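        # Flow of the cached step above: cache_step splits the query/document inputs into chunks,
        # runs a no-grad forward over all chunks, computes the contrastive loss once on the full
        # gathered batch, backpropagates it only to the representations (the "cache"), and finally
        # re-runs each chunk with gradients enabled, backpropagating the surrogate dot product
        # through the encoder via the backward_fn configured above.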
171 |
172 | def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
173 | # Call the base class's save_model method
174 | super().save_model(output_dir, _internal_call)
175 |
176 | model = self.model
177 |
178 | # Only save model on rank 0 in distributed setup
179 | if torch.distributed.get_rank() == 0:
180 | print(f"Rank {torch.distributed.get_rank()} saving model.")
181 | save_dir = os.path.join(output_dir, "sentence-transformer-checkpoint")
182 | os.makedirs(save_dir, exist_ok=True)
183 | # Save in SentenceTransformer-compatible format
184 | model.save(save_dir, safe_serialization=False)
185 |
186 | # Synchronize all processes to ensure save is complete
187 | torch.distributed.barrier()
188 |
--------------------------------------------------------------------------------
/stage2/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
3 | import torch
4 | from sentence_transformers import (
5 | SentenceTransformer,
6 | SentenceTransformerTrainingArguments,
7 | )
8 | from datasets import load_dataset, Dataset
9 | from transformers.integrations import TensorBoardCallback
10 | from sentence_transformers.evaluation import InformationRetrievalEvaluator
11 | from sentence_transformers.losses import TripletLoss, CachedMultipleNegativesRankingLoss
12 | from sentence_transformers.losses import TripletDistanceMetric
13 | from trainer_custom import CustomTrainer
14 | from sentence_transformers.training_args import BatchSamplers
15 | import argparse
16 | from dataset_process import process_dataset, process_eval_dataset
17 | from sentence_transformers.training_args import MultiDatasetBatchSamplers
18 | import random
19 |
20 | def str2bool(v):
21 | if isinstance(v, bool):
22 | return v
23 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
24 | return True
25 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
26 | return False
27 | else:
28 | raise argparse.ArgumentTypeError('Boolean value expected.')
29 |
30 | # Initialize argument parser
31 | parser = argparse.ArgumentParser(description='Fine-tune a model for sentence similarity')
32 |
33 | # Add arguments
34 | parser.add_argument('--model_path',
35 | default='coldchair16/CPRetriever-Code',
36 | type=str,
37 | help='The name of the model to fine-tune')
38 |
39 | parser.add_argument('--triplet_margin',
40 | default=0.3,
41 | type=float,
42 | help='The margin for triplet loss')
43 |
44 | parser.add_argument('--cut',
45 | default=True,
46 | type=str2bool,
47 |                     help='Whether to subsample the S2Full (simplified-to-full) training set to 1,000 triplets')
48 |
49 | parser.add_argument('--lr',
50 | default=2e-6,
51 | type=float,
52 | help='The learning rate for training')
53 |
54 | parser.add_argument('--epochs',
55 | default=1,
56 | type=int,
57 | help='Number of epochs for training')
58 |
59 | parser.add_argument('--dataset_id',
60 | default='coldchair16/CPRet-data',
61 | type=str,
62 | help='The dataset ID for training')
63 |
64 | parser.add_argument('--eval_task_list',
65 | default=['T2C', 'C2C', 'P2Dup', 'S2Full'],
66 | type=str,
67 | nargs='+',
68 | help='List of evaluation tasks')
69 |
70 | parser.add_argument('--train_task_list',
71 | default=['P2Dup', 'S2Full'],
72 | # default=['PCD', 'P2Dup', 'S2Full'],
73 | type=str,
74 | nargs='+',
75 | help="List of training tasks: ['PCD', 'P2Dup', 'S2Full']")
76 |
77 | parser.add_argument('--max_length',
78 | default=1024,
79 | type=int,
80 | help='Maximum input length for the model')
81 |
82 | parser.add_argument('--eval_only',
83 | default=False,
84 | type=str2bool,
85 | help='Whether to run evaluation only, not training')
86 |
87 |
88 | def main():
89 | # Parse arguments
90 | args = parser.parse_args()
91 | # Use the parsed arguments in your code
92 | dataset_id = args.dataset_id
93 | eval_task_list = args.eval_task_list
94 | train_task_list = args.train_task_list
95 | max_length = args.max_length
96 | eval_only = args.eval_only
97 | model_path = args.model_path
98 | triplet_margin = args.triplet_margin
99 | cut = args.cut
100 | lr = args.lr
101 | epochs = args.epochs
102 |
103 | model_name = model_path.split('/')[-1]
104 |
105 | model_kwargs = {'torch_dtype': torch.bfloat16} # TODO: Make sure your GPU and model support bfloat16
106 | model = SentenceTransformer(model_path, trust_remote_code=True, model_kwargs=model_kwargs)
107 |
108 | model.tokenizer.model_max_length = max_length
109 | model.max_seq_length = max_length
110 |
111 | evaluator_list = []
112 | for task in eval_task_list:
113 | queries = load_dataset(dataset_id, f"{task}-queries", split='test')
114 | corpus = load_dataset(dataset_id, f"{task}-corpus", split='test')
115 | qrels = load_dataset(dataset_id, f"{task}-qrels", split='test')
116 | queries, corpus, relevant_docs = process_eval_dataset(queries, corpus, qrels)
117 | evaluator = InformationRetrievalEvaluator(
118 | queries=queries,
119 | corpus=corpus,
120 | relevant_docs=relevant_docs,
121 | name=task,
122 | )
123 | evaluator_list.append(evaluator)
124 |
125 | train_dataset = {}
126 | for task in train_task_list:
127 | if (task == 'PCD'):
128 | data = load_dataset(dataset_id, f"PCPCD")
129 | traindata_PCPCD, testdata_PCPCD = process_dataset(data, max_length)
130 | anchor = []
131 | positive = []
132 | for d in traindata_PCPCD:
133 | anchor.append(d['question'])
134 | positive.append(random.choice(d['solutions']))
135 | dataset = Dataset.from_dict({
136 | 'anchor' : anchor,
137 | 'positive' : positive,
138 | })
139 | elif (task == 'P2Dup' or task == 'S2Full'):
140 | data = load_dataset(dataset_id, f"{task}-train-pairs", split='train')
141 | dataset = Dataset.from_dict({
142 | 'anchor': data['query'],
143 | 'positive': data['pos'],
144 | 'negative': data['neg'],
145 | })
146 | if (cut and task == 'S2Full'):
147 | dataset = dataset.select(random.sample(range(len(dataset)), 1000))
148 | print(f"train task = {task} len = {len(dataset)}")
149 | train_dataset[task] = dataset
150 |
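    # Column layout note: the sentence-transformers trainer feeds dataset columns to each loss in
    # order, so CachedMultipleNegativesRankingLoss consumes (anchor, positive) pairs with in-batch
    # negatives, while TripletLoss consumes explicit (anchor, positive, negative) triplets.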
151 |
152 | output_dir = f"./results/{model_name}"
153 | output_dir = output_dir + '_' + '-'.join(train_task_list)
154 | output_dir = output_dir + f"_margin{triplet_margin}"
155 | output_dir = output_dir + f"_lr{lr}"
156 | output_dir = output_dir + f'_maxlength{max_length}'
157 | if (cut):
158 | output_dir = output_dir + "_cut"
159 | output_dir = output_dir + f"_epochs{epochs}"
160 | print(f"output_dir : {output_dir}")
161 |
162 |
163 | args = SentenceTransformerTrainingArguments(
164 | output_dir=output_dir,
165 | num_train_epochs=epochs,
166 | per_device_train_batch_size=1,
167 | per_device_eval_batch_size=1,
168 | learning_rate=lr,
169 | lr_scheduler_type='constant',
170 | bf16=True,
171 | eval_strategy="epoch",
172 | save_strategy="epoch",
173 | save_total_limit=1,
174 | logging_steps=10,
175 | run_name='CPRet', # Will be used in W&B if `wandb` is installed
176 | save_on_each_node=False,
177 | max_grad_norm=1.0,
178 | batch_sampler=BatchSamplers.NO_DUPLICATES,
179 | multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
180 | )
181 |
182 |
183 | losses = {
184 | 'PCD' : CachedMultipleNegativesRankingLoss(model, mini_batch_size=2, scale=14.28571429),
185 | 'P2Dup' : TripletLoss(model, TripletDistanceMetric.COSINE, triplet_margin=triplet_margin),
186 | 'S2Full' : TripletLoss(model, TripletDistanceMetric.COSINE, triplet_margin=triplet_margin),
187 | }
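    # Note: scale=14.28571429 is 1/0.07, i.e. the multiple-negatives loss uses a softmax
    # temperature of 0.07, matching the default temperature of the stage-1 InfoNCE losses.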
188 |
189 |     # Set the batch size separately for each training dataset
190 | batch_sizes = {
191 | 'PCD' : 32,
192 | 'P2Dup' : 1,
193 | 'S2Full' : 1,
194 | }
195 |
196 | trainer = CustomTrainer(
197 | batch_sizes=batch_sizes,
198 | model=model,
199 | train_dataset=train_dataset,
200 | args = args,
201 | loss=losses,
202 | callbacks=[TensorBoardCallback()],
203 | evaluator=evaluator_list,
204 | )
205 | if (eval_only):
206 | eval_results = trainer.evaluate()
207 | print(eval_results)
208 | else:
209 | trainer.train()
210 |
211 | if __name__ == '__main__':
212 | main()
213 |
--------------------------------------------------------------------------------
/stage1/grad_cache/pytorch_lightning/pl_example.py:
--------------------------------------------------------------------------------
1 | """
2 | PyTorch Lightning example of using Grad Cache, tested on PyTorch Lightning '2.2.0.post0' with multiple GPUs and mixed precision (fp16).
3 | PyTorch Metric Learning must also be installed for the contrastive loss calculation.
4 | """
5 |
6 | import os
7 | import argparse
8 | import torch
9 | import lightning as pl
10 | from contextlib import nullcontext
11 | from lightning.pytorch.loggers import WandbLogger
12 | from lightning.pytorch.strategies import DDPStrategy
13 | from pytorch_metric_learning.utils import distributed as pml_dist
14 | from pytorch_metric_learning.losses import SupConLoss
15 |
16 | from grad_cache.pytorch_lightning.pl_gradcache import PLGradCache
17 |
18 |
19 | class RandomDataset(torch.utils.data.Dataset):
20 | def __init__(self, params):
21 | self.params = params
22 |
23 | def __len__(self):
24 | return self.params.data_size
25 |
26 | def __getitem__(self, idx):
27 | # Generate random float inputs with shape [2, input_dim] for contrastive learning
28 | input_data = torch.randn(2, self.params.input_dim)
29 | # Generate a random integer label for binary classification (0 or 1), replicate it to have shape [2]
30 | label = torch.randint(0, 2, (1,), dtype=torch.long)
31 | label = torch.tensor([label, label], dtype=torch.long)
32 | return input_data, label
33 |
34 |
35 | class SimpleLitModel(pl.LightningModule):
36 | def __init__(self, params):
37 | super().__init__()
38 | self.params = params
39 | self.loss = SupConLoss(temperature=params.temperature)
40 | if params.gpus > 1:
41 | self.loss = pml_dist.DistributedLossWrapper(self.loss)
42 | self.automatic_optimization = (not self.params.use_gc) # needed when use_gc is on
43 | self.fp16 = (self.params.precision == 16)
44 | self.linear = torch.nn.Linear(params.input_dim, params.embed_dim) # our simple model
45 |
46 | def init_gc(self, scaler, ddp_module):
47 | """Sets up the required components of GradCache. This method is called after the model is initialized."""
48 | assert self.params.use_gc
49 | if self.fp16 and self.params.use_gc:
50 | # pytorch lightning autocast wraps everything in it
51 | # it needs to be disabled in gradcache because we do forward twice, and one with no grad
52 | # then we do autocast manually in gradcache when we need to
53 | # original post: https://discuss.pytorch.org/t/autocast-and-torch-no-grad-unexpected-behaviour/93475/3
54 | # pl source code: your_venv_name/lib/python3.8/site-packages/lightning/pytorch/plugins/precision/amp.py::forward_context
55 | self.trainer.strategy.precision_plugin.forward_context = nullcontext
56 |
57 | print(f"*** initializing gradcache with ddp_module={type(ddp_module)}, minibatch_size={self.params.gc_minibatch_size}")
58 | self.gc = PLGradCache(
59 | models=[ddp_module],
60 | chunk_sizes=self.params.gc_minibatch_size,
61 | loss_fn=self.calculate_loss,
62 | fp16=self.fp16,
63 |             scaler=(scaler if self.fp16 else None), # needed when automatic_optimization is off and fp16 is on
64 | backward_fn=self.manual_backward, # needed when automatic_optimization is off
65 | )
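        # With e.g. batch_size=16 and gc_minibatch_size=2, each optimizer step still optimizes the
        # full 16-sample contrastive loss; grad cache only bounds how many samples hold activations
        # at once, so the loss and gradients match those of plain large-batch training.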
66 |
67 | def train_dataloader(self):
68 | train_dataset = RandomDataset(params)
69 | train_loader = torch.utils.data.DataLoader(
70 | train_dataset,
71 | batch_size=params.batch_size,
72 | num_workers=params.num_workers,
73 | drop_last=True,
74 | )
75 | return train_loader
76 |
77 | def configure_optimizers(self):
78 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
79 | return optimizer
80 |
81 | def calculate_loss(self, embeddings, labels):
82 | # embeddings shape [batch_size, 2, embed_dim]
83 | # labels shape [batch_size, 2]
84 | embeddings = embeddings.flatten(0, 1)
85 | labels = labels.flatten()
86 | return self.loss(embeddings, labels)
87 |
88 | def forward(self, inputs): # needed for grad cache
89 | return self.linear(inputs)
90 |
91 | def on_train_start(self): # initialize grad cache here
92 | if self.params.use_gc:
93 | self.init_gc(self.trainer.scaler, self.trainer.strategy.model)
94 | # self.init_gc(self.trainer.scaler, self.trainer.lightning_module) # we can use this if nccl strategy is available
95 |
96 | def training_step(self, batch, batch_idx):
97 | # inputs shape [batch_size, 2, input_dim]
98 | # labels shape [batch_size, 2]
99 | inputs, labels = batch
100 | if self.params.use_gc:
101 | assert self.gc is not None
102 | optimizer = self.optimizers()
103 | optimizer.zero_grad()
104 | loss = self.gc(
105 | inputs,
106 | no_sync_except_last=(self.params.gpus > 1),
107 | labels=labels.flatten(),
108 | )
109 | loss /= max(1, self.params.gpus) # needed when automatic_optimization is off
110 | log_loss = loss
111 | optimizer.step()
112 | else:
113 | outputs = self.linear(inputs)
114 | loss = self.calculate_loss(outputs, labels)
115 | log_loss = loss / max(1, self.params.gpus)
116 | self.log(
117 | "train_loss",
118 | log_loss,
119 | on_step=True,
120 | on_epoch=True,
121 | sync_dist=self.params.use_gc, # needed when automatic_optimization is off
122 | )
123 | print(f"batch_idx={batch_idx}, loss={loss}")
124 | return loss
125 |
126 |
127 | def get_argument_parser():
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument("--random_seed", type=int, default=42)
130 | parser.add_argument("--num_workers", type=int, default=4)
131 | parser.add_argument("--gpus", type=int, default=0)
132 | parser.add_argument("--precision", type=int, default=32)
133 | parser.add_argument("--ddp_backend", type=str, default="nccl", help="torch distributed backend (Default: nccl), use 'gloo' if nccl doesn't work")
134 | parser.add_argument("--project_name", type=str, default="debug_gradcache")
135 |
136 | # training params
137 | parser.add_argument("--data_size", type=int, default=100)
138 | parser.add_argument("--epochs", type=int, default=10)
139 | parser.add_argument("--batch_size", type=int, default=16)
140 | parser.add_argument("--temperature", type=float, default=0.1)
141 |
142 | # model hyperparams
143 | parser.add_argument("--input_dim", type=int, default=784)
144 | parser.add_argument("--embed_dim", type=int, default=512)
145 |
146 | # grad cache params
147 | parser.add_argument("--use_gc", action="store_true", default=False, help="whether to use grad cache")
148 | parser.add_argument("--gc_minibatch_size", type=int, default=2, help="mini batch size of grad cache, must be provided if use_gc is on")
149 |
150 | return parser
151 |
152 |
153 | def main(params):
154 |     # set random seeds for reproducibility
155 | torch.backends.cudnn.deterministic = True
156 | torch.backends.cudnn.benchmark = False
157 |
158 | # set different random seeds for each worker
159 | pl.seed_everything(seed=params.random_seed, workers=True)
160 |
161 |     # avoid HuggingFace tokenizer parallelism warnings when data loading uses multiple workers
162 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
163 | torch.multiprocessing.set_sharing_strategy("file_system")
164 |
165 | # load model
166 | model = SimpleLitModel(params)
167 |
168 | # load trainer
169 | experiment_id = f"gpus-{params.gpus}_precision-{params.precision}"
170 | if params.use_gc:
171 | experiment_id += "_gc"
172 | experiment_id += "_pl"
173 | wandb_logger = WandbLogger(
174 | project=params.project_name,
175 | name=experiment_id,
176 | )
177 | ddp = DDPStrategy(process_group_backend=params.ddp_backend)
178 | trainer = pl.Trainer(
179 | accelerator="gpu" if params.gpus > 0 else "cpu",
180 | strategy=ddp if params.gpus > 1 else "auto",
181 | devices=params.gpus if params.gpus > 0 else "auto",
182 | precision=params.precision,
183 | logger=wandb_logger,
184 | max_epochs=params.epochs,
185 | log_every_n_steps=1,
186 | )
187 | trainer.fit(model)
188 |
189 |
190 | if __name__ == "__main__":
191 | params = get_argument_parser().parse_args()
192 | main(params)
193 |
--------------------------------------------------------------------------------
/stage1/grad_cache/pytorch_lightning/pl_gradcache.py:
--------------------------------------------------------------------------------
1 | from contextlib import nullcontext
2 | from typing import Any, Callable, List, Tuple, Union
3 |
4 | import torch
5 | from torch import Tensor, nn
6 | from torch.cuda.amp import GradScaler, autocast
7 |
8 | from ..grad_cache import GradCache, RandContext
9 |
10 | class PLGradCache(GradCache):
11 | """
12 | Gradient Cache class with PyTorch Lightning Support.
13 | Implements input chunking, first graph-less forward pass, Gradient Cache creation, second forward & backward gradient computation.
14 | Optimizer step is not included. Native torch automatic mixed precision is supported.
15 | Gradient unscaling and scaler update are handled internally.
16 | """
17 |
18 | def __init__(
19 | self,
20 | models: List[nn.Module],
21 | chunk_sizes: Union[int, List[int]],
22 | loss_fn: Callable[..., Tensor],
23 | split_input_fn: Callable[[Any, int], Any] = None,
24 | get_rep_fn: Callable[..., Tensor] = None,
25 | fp16: bool = False,
26 | scaler: GradScaler = None,
27 | backward_fn=None, # [added]
28 | ):
29 | """
30 | Initialize the Gradient Cache class instance.
31 | :param models: A list of all encoder models to be updated by the current cache.
32 | :param chunk_sizes: An integer indicating chunk size. Or a list of integers of chunk size for each model.
33 | :param loss_fn: A loss function that takes arbitrary numbers of representation tensors and
34 | arbitrary numbers of keyword arguments as input. It should not in any case modify the input tensors' relations
35 | in the autograd graph, which are later relied upon to create the gradient cache.
36 | :param split_input_fn: An optional function that split generic model input into chunks. If not provided, this
37 | class will try its best to split the inputs of supported types. See `split_inputs` function.
38 | :param get_rep_fn: An optional function that takes generic model output and return representation tensors. If
39 | not provided, the generic output is assumed to be the representation tensor.
40 | :param fp16: If True, run mixed precision training, which requires scaler to also be set.
41 | :param scaler: A GradScaler object for automatic mixed precision training.
42 |     :[added] param backward_fn: The LightningModule's `manual_backward` function, used when automatic_optimization is disabled.
43 | """
44 | super().__init__(models, chunk_sizes, loss_fn, split_input_fn, get_rep_fn, fp16, scaler)
45 | self.backward_fn = backward_fn
46 |
47 | def build_cache(self, *reps: Tensor, **loss_kwargs) -> Union[List[Tensor], Tensor]:
48 | """
49 | Compute the gradient cache
50 | :param reps: Computed representations from all encoder models
51 | :param loss_kwargs: Extra keyword arguments to the loss function
52 | :return: A tuple of a) gradient cache for each encoder model, and b) loss tensor
53 | """
54 | reps = [r.detach().requires_grad_() for r in reps]
55 | with autocast() if self.fp16 else nullcontext():
56 | loss = self.compute_loss(*reps, **loss_kwargs)
57 |
58 | self.backward_fn(loss) # [modified]
59 |
60 | cache = [r.grad for r in reps]
61 |
62 | return cache, loss.detach()
63 |
64 | def forward_backward(
65 | self,
66 | model: nn.Module,
67 | model_inputs,
68 | cached_gradients: List[Tensor],
69 | random_states: List[RandContext],
70 | no_sync_except_last: bool = False,
71 | ):
72 | """
73 | Run the second forward and the backward pass to compute gradient for a model.
74 | :param model: Encoder model.
75 | :param model_inputs: Chunked input to the encoder model.
76 | :param cached_gradients: Chunked gradient cache tensor for each input.
77 | :param random_states: Each input's device random state during the first forward.
78 | :param no_sync_except_last: If True, under distributed setup, only trigger gradient reduction across processes
79 | for the last sub-batch's forward-backward pass.
80 | """
81 | if isinstance(
82 | model, nn.parallel.DistributedDataParallel
83 | ): # [use ddp_model]
84 |
85 | if no_sync_except_last:
86 | sync_contexts = [
87 | model.no_sync for _ in range(len(model_inputs) - 1)
88 | ] + [nullcontext]
89 | sync_flags = [True] * (len(model_inputs)) # [added]
90 | else:
91 | sync_contexts = [nullcontext for _ in range(len(model_inputs))]
92 | sync_flags = [False] * (len(model_inputs)) # [added]
93 |
94 | # [modified]
95 | for x, state, gradient, sync_context, sync_flag in zip(
96 | model_inputs, random_states, cached_gradients, sync_contexts, sync_flags
97 | ):
98 | with sync_context():
99 | with state:
100 | y = self.model_call(model, x)
101 | reps = self.get_reps(y)
102 | surrogate = torch.dot(reps.flatten(), gradient.flatten())
103 | if sync_flag:
104 | model.require_backward_grad_sync = True
105 | if self.fp16: # [added]
106 | self.scaler._enabled = False
107 | self.backward_fn(surrogate)
108 | self.scaler._enabled = True
109 | else:
110 | self.backward_fn(surrogate) # [modified]
111 | else: # [use base model (i.e. SimpleLitModel)]
112 |
113 | # [remove no_sync_except_last: pytorch lightning would handle gradient sync automatically]
114 | for x, state, gradient in zip(
115 | model_inputs, random_states, cached_gradients
116 | ):
117 | with state:
118 | y = self.model_call(model, x)
119 | reps = self.get_reps(y)
120 | surrogate = torch.dot(reps.flatten(), gradient.flatten())
121 | if self.fp16: # [added]
122 | self.scaler._enabled = False
123 | self.backward_fn(surrogate)
124 | self.scaler._enabled = True
125 | else:
126 | self.backward_fn(surrogate) # [added]
127 |
128 | def cache_step(
129 | self, *model_inputs, no_sync_except_last: bool = False, **loss_kwargs
130 |     ) -> Tensor:
131 | """
132 | Run a cached step to compute gradient over the inputs.
133 |         :param model_inputs: Input to each encoder model. Should be in the same order as the class's models.
134 | :param no_sync_except_last: If True, under distributed setup, for each model, only trigger gradient reduction
135 | across processes for the last sub-batch's forward-backward pass.
136 | :param loss_kwargs: Additional keyword arguments to the loss function.
137 |         :return: The loss for the current step.
138 | """
139 | all_reps = []
140 | all_rnd_states = []
141 |
142 | # [removed: we check it in forward_backward(.)]
143 | # if no_sync_except_last:
144 | # assert all(map(lambda m: isinstance(m, nn.parallel.DistributedDataParallel), self.models)), \
145 | # 'Some of models are not wrapped in DistributedDataParallel. Make sure you are running DDP with ' \
146 | # 'proper initializations.'
147 |
148 | model_inputs = [
149 | self.split_inputs(x, chunk_size)
150 | for x, chunk_size in zip(model_inputs, self.chunk_sizes)
151 | ]
152 |
153 | for model, x in zip(self.models, model_inputs):
154 | model_reps, rnd_states = self.forward_no_grad(model, x)
155 | all_reps.append(model_reps)
156 | all_rnd_states.append(rnd_states)
157 |
158 | # all_reps: len(self.models) x [batch_size, 2, embed_dim]
159 |         # cache: len(self.models) x num_chunks x [gc_minibatch_size, 2, embed_dim]
160 |
161 | cache, loss = self.build_cache(*all_reps, **loss_kwargs)
162 | cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)]
163 |
164 | for model, x, model_cache, rnd_states in zip(
165 | self.models, model_inputs, cache, all_rnd_states
166 | ):
167 | self.forward_backward(
168 | model,
169 | x,
170 | model_cache,
171 | rnd_states,
172 | no_sync_except_last=no_sync_except_last,
173 | )
174 |
175 | return loss
176 |
--------------------------------------------------------------------------------
/cp-retrieval-server/templates/index.html:
--------------------------------------------------------------------------------
1 | {% extends "base.html" %}
2 | {% block content %}
3 | {{ t.site_name }}
4 |
5 |
6 | {{ t.info2 | safe }}
7 |
8 |
9 |
10 | {{ t.paper_info | safe }}
11 |
12 |
13 |
14 | {{ t.new_domain_info | safe }}
15 |
16 |
17 |
55 |
56 | {% if results %}
57 |
58 | {{ t.summary.format(total=total, page=page, max_page=max_page, elapsed=elapsed) | safe }}
60 |
61 |
62 |
63 | {% for r in results %}
64 | -
65 |
66 |
67 |
#{{ r.rank }}
68 | {# 👉 Main title = link to the original problem #}
69 |
{{ r.title }}
70 | {# 🔍 Small icon = on-site detail page #}
71 |
73 | 📄
74 |
75 |
76 |
{{ r.source }}
77 |
78 |
{{ "%.4f"|format(r.score) }}
79 |
80 |
81 | {% endfor %}
82 |
83 |
84 |
85 |
119 |
120 | {% endif %}
121 |
122 |
144 |
145 |
171 |
241 |
242 |
243 | {% endblock %}
244 |
245 |
246 |
--------------------------------------------------------------------------------
/cp-retrieval-server/app.py:
--------------------------------------------------------------------------------
1 | import os
2 | # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
3 | import torch
4 | import math
5 | import json
6 | import numpy as np
7 | from flask import Flask, request, render_template, redirect, url_for
8 | from sentence_transformers import SentenceTransformer
9 | import time
11 | from markupsafe import Markup, escape
12 | from datetime import datetime
13 |
14 | # ===== Multilingual dictionary =====
15 | I18N = {
16 | "zh": {
17 | "site_name" : "CPRet:编程竞赛题目检索",
18 | "new_domain_info": "我们的最新域名是 cpret.online,我们的 GitHub 仓库是 CPRet,欢迎收藏或 star!",
19 | "paper_info": "📰 2025 年 9 月更新:🎉 恭喜!我们的项目论文 CPRet 被 NeurIPS 2025 D&B track 接收!",
20 | "info": "📢 2025 年 7 月更新:我们已升级模型并同步更新了题目数据库,检索效果更佳!",
21 | "info2": "📢 2025 年 10 月更新:新增了部分 OJ 的题目,并进一步优化了模型效果!",
22 | "placeholder": "输入题目描述或简略题意(超过 2048 token 的查询将被截断)…",
23 | "template_btn": "填入示例查询",
24 | "search_btn": "搜索",
25 | "summary" : "共 {total} 条结果,页 {page}/{max_page},耗时 {elapsed:.1f} ms",
26 | "prev" : "上一页",
27 | "next" : "下一页",
28 | "untitled" : "未命名",
29 | "view_origin": "原站链接",
30 | "back": "返回搜索",
31 | "view_stats": "📊 每日搜索统计",
32 | "date": "日期",
33 | "search_count": "搜索次数",
34 | "total_search_count": "总搜索次数",
35 | "example_report": "使用示例(实测报告)",
36 | "filter_by_oj": "筛选 OJ",
37 | "select_all": "全选",
38 | "deselect_all": "全不选",
39 | "moving_average": "渐近平均",
40 | },
41 | "en": {
42 | "site_name" : "CPRet: Competitive Programming Problem Retrieval",
43 | "new_domain_info": "Our new domain is cpret.online. Our GitHub repo is CPRet. Please bookmark or star it!",
44 | "paper_info": "📰 September 2025 Update: 🎉 Congrats! Our project paper CPRet has been accepted by the NeurIPS 2025 D&B track!",
45 | "info": "📢 July 2025 Update: We've upgraded our model and synchronized the problem database for better retrieval!",
46 | "info2": "📢 October 2025 Update: Added new problems from several OJs and further optimized the model performance!",
47 | "placeholder": "Enter problem description or simplified statement (queries longer than 2048 tokens will be truncated)…",
48 | "template_btn": "Insert example query",
49 | "search_btn": "Search",
50 | "summary" : "{total} results, page {page}/{max_page}, {elapsed:.1f} ms",
51 | "prev" : "Prev",
52 | "next" : "Next",
53 | "untitled" : "Untitled",
54 | "view_origin": "Original Link",
55 | "back": "Back to Search",
56 | "view_stats": "📊 Daily Search Stats",
57 | "date": "Date",
58 | "search_count": "Search Count",
59 | "total_search_count": "Total Search Count",
60 | "example_report": "Test Cases (Demo Report)",
61 | "filter_by_oj": "Filter by OJ",
62 | "select_all": "Select All",
63 | "deselect_all": "Deselect All",
64 | "moving_average": "Moving Average",
65 | },
66 | }
67 |
68 |
69 | def detect_lang():
70 | """Language priority: ?lang= -> Accept-Language -> zh"""
71 | qlang = request.args.get("lang")
72 | if qlang in ("zh", "en"):
73 | return qlang
74 | header = request.headers.get("Accept-Language", "")
75 | return "en" if header.lower().startswith("en") else "zh"
76 |
77 |
78 | # ---------------- Configuration ---------------- #
79 | SEARCH_STATS_PATH = "search_stats.json"
80 | MODEL_PATH = os.getenv(
81 | "MODEL_PATH",
82 | "coldchair16/CPRetriever-Prob-Qwen3-4B-2510"
83 | )
84 | EMB_PATH = os.getenv(
85 | 'EMB_PATH',
86 | './probs_2511_embs.npy'
87 | )
88 | PROB_PATH = os.getenv(
89 | 'PROB_PATH',
90 | './probs_2511.jsonl'
91 | )
92 | # getenv returns strings; cast so the `BF_16 == 1` check below behaves as intended
93 | BF_16 = int(
94 |     os.getenv("BF_16", "1")
95 | )
96 |
97 | PAGE_SIZE = 20 # Number of results per page
98 | # ------------------------------------- #
99 |
100 | app = Flask(__name__)
101 |
102 | # ---------- Load model & data on startup ---------- #
103 | print("Loading SentenceTransformer model …")
104 | if BF_16 == 1:
105 | model = SentenceTransformer(MODEL_PATH, trust_remote_code=True, model_kwargs={"torch_dtype": torch.bfloat16})
106 | else:
107 | model = SentenceTransformer(MODEL_PATH, trust_remote_code=True)
108 | model.tokenizer.model_max_length = 2048
109 | model.max_seq_length = 2048
110 |
111 | print("Loading pre-computed embeddings …")
112 | embs = np.load(EMB_PATH).astype("float32")
113 | embs /= np.linalg.norm(embs, axis=1, keepdims=True)
114 |
115 | print("Loading problem metadata …")
116 | probs = [json.loads(line) for line in open(PROB_PATH, "r", encoding="utf-8")]
117 |
118 | # Collect all available OJ sources for the front-end dropdown menu
119 | available_ojs = sorted(list(set(p.get("source") for p in probs if p.get("source"))))
120 |
121 | assert len(probs) == embs.shape[0], "Mismatch between vector and problem count!"
122 | print(f"Ready! {len(probs)} problems indexed.\n")
123 |
124 |
125 | from functools import lru_cache
126 | import hashlib
127 |
128 | def _hash(text: str) -> str:
129 |     """Hash helper for shortening long queries into dict keys (currently unused; lru_cache keys on the query string itself)."""
130 | return hashlib.md5(text.encode("utf-8")).hexdigest()
131 |
132 | @lru_cache(maxsize=1024) # Cache up to 1024 different queries
133 | def search_once(q: str):
134 | """Return ranked indices and similarity list (numpy array -> Python list)"""
135 | # -> float32
136 | q_emb = model.encode(q, convert_to_tensor=True).to(torch.float32).cpu().numpy()
137 | q_emb = q_emb / np.linalg.norm(q_emb)
138 | sims = embs.dot(q_emb)
139 | idx = sims.argsort()[::-1]
140 | return idx.tolist(), sims.tolist()
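# Because both the corpus embeddings and the query embedding are L2-normalized, the dot product in
# search_once is exactly cosine similarity and the descending argsort is the final ranking;
# lru_cache memoizes repeated queries keyed on the raw query string.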
141 |
142 | def load_search_stats():
143 | """Load daily search statistics from file."""
144 | if not os.path.exists(SEARCH_STATS_PATH):
145 | return {}
146 | with open(SEARCH_STATS_PATH, "r", encoding="utf-8") as f:
147 | return json.load(f)
148 |
149 | def save_search_stats(stats: dict):
150 | """Save search statistics to file."""
151 | with open(SEARCH_STATS_PATH, "w", encoding="utf-8") as f:
152 | json.dump(stats, f, ensure_ascii=False, indent=2)
153 |
154 | def record_search():
155 | """Update today's search count and save to file."""
156 | stats = load_search_stats()
157 | today = datetime.now().strftime("%Y-%m-%d")
158 | stats[today] = stats.get(today, 0) + 1
159 | save_search_stats(stats)
160 |
161 | @app.route("/", methods=["GET"])
162 | def index():
163 | lang = detect_lang()
164 | t = I18N[lang]
165 |
166 | q = request.args.get("q", "").strip()
167 | page = max(int(request.args.get("page", "1")), 1)
168 |
169 | if "oj" in request.args:
170 | selected_ojs = request.args.getlist("oj")
171 | else:
172 | selected_ojs = available_ojs
173 |
174 | results, total, elapsed = [], 0, 0.0
175 |
176 | if q:
177 | record_search()
178 | tic = time.perf_counter()
179 | idx, sims = search_once(q)
180 | elapsed = (time.perf_counter() - tic) * 1_000
181 |
182 |         # Filter results by the selected OJ sources
183 | filtered_idx = []
184 | for j in idx:
185 | p = probs[j]
186 |             # Keep the problem only if its source is among the selected OJs
187 | if p.get("source") in selected_ojs:
188 | filtered_idx.append(j)
189 |
190 | total = len(filtered_idx)
191 |
192 | results = []
193 | start, end = (page - 1) * PAGE_SIZE, page * PAGE_SIZE
194 | for rank, j in enumerate(filtered_idx[start:end], start=start + 1):
195 | p = probs[j]
196 | results.append({
197 | "rank" : rank,
198 | "pid" : j,
199 | "score" : float(sims[j]),
200 | "title" : p.get("title") or t["untitled"],
201 | "url" : p.get("url", "#"),
202 | "source": p.get("source", ""),
203 | })
204 |
205 |
206 | return render_template(
207 | "index.html",
208 | lang=lang,
209 | t=t,
210 | query=q,
211 | results=results,
212 | page=page,
213 | page_size=PAGE_SIZE,
214 | total=total,
215 | max_page=max(1, math.ceil(total / PAGE_SIZE)),
216 | elapsed=elapsed,
217 | available_ojs=available_ojs,
218 | selected_ojs=selected_ojs,
219 | )
220 |
221 | @app.route("/p/<int:pid>")
222 | def problem(pid: int):
223 | lang = detect_lang()
224 | t = I18N[lang]
225 |
226 | if pid < 0 or pid >= len(probs):
227 | return f"Problem #{pid} not found", 404
228 |
229 | p = probs[pid]
230 |
231 |     raw = p.get("text", "(No text)").replace("\n", "<br>")
232 | text_html = Markup(raw)
233 |
234 | return render_template(
235 | "problem.html",
236 | lang=lang,
237 | t=t,
238 | pid=pid,
239 | title=p.get("title") or t["untitled"],
240 | source=p.get("source", ""),
241 | url=p.get("url", "#"),
242 | text_html=text_html,
243 | query=request.args.get("q", ""), # Pass original query to return button
244 | page=request.args.get("page", "1"),
245 | selected_ojs_str=request.args.get("oj", "") # Pass selected_ojs_str for back button
246 | )
247 |
248 | @app.route("/stats")
249 | def stats():
250 | lang = detect_lang()
251 | t = I18N[lang]
252 |
253 | stats = load_search_stats()
254 | stats_data = sorted(stats.items(), key=lambda x: x[0], reverse=True)
255 |
256 | stats_draw = list(reversed(stats_data))
257 |     stats_draw = stats_draw[:-1] # Drop the most recent (typically today's, still partial) entry from the chart
258 |
259 | return render_template(
260 | "stats.html",
261 | lang=lang,
262 | t=t,
263 | stats=stats_data,
264 | stats_draw=stats_draw,
265 | )
266 |
267 | # -------------- Local run entry -------------- #
268 | if __name__ == "__main__":
269 | # export FLASK_ENV=development to enable auto-reload/debug
270 | app.run(host="0.0.0.0", port=5000, debug=False)
--------------------------------------------------------------------------------
/stage1/trainer_headers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
17 | """
18 |
19 | import contextlib
20 | import copy
21 | import functools
22 | import glob
23 | import importlib.metadata
24 | import inspect
25 | import json
26 | import math
27 | import os
28 | import random
29 | import re
30 | import shutil
31 | import sys
32 | import tempfile
33 | import time
34 | import warnings
35 | from collections.abc import Mapping
36 | from pathlib import Path
37 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
38 |
39 |
40 | # Integrations must be imported before ML frameworks:
41 | # isort: off
42 | from transformers.integrations import (
43 | get_reporting_integration_callbacks,
44 | )
45 |
46 | # isort: on
47 |
48 | import huggingface_hub.utils as hf_hub_utils
49 | import numpy as np
50 | import torch
51 | import torch.distributed as dist
52 | from huggingface_hub import ModelCard, create_repo, upload_folder
53 | from packaging import version
54 | from torch import nn
55 | from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
56 |
57 | from transformers import __version__
58 | from transformers.configuration_utils import PretrainedConfig
59 | from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
60 | from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
61 | from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
62 | from transformers.feature_extraction_utils import FeatureExtractionMixin
63 | from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
64 | from transformers.image_processing_utils import BaseImageProcessor
65 | from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
66 | from transformers.integrations.tpu import tpu_spmd_dataloader
67 | from transformers.modelcard import TrainingSummary
68 | from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
69 | from transformers.models.auto.modeling_auto import (
70 | MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
71 | MODEL_MAPPING_NAMES,
72 | )
73 | from transformers.optimization import Adafactor, get_scheduler
74 | from transformers.processing_utils import ProcessorMixin
75 | from transformers.pytorch_utils import (
76 | ALL_LAYERNORM_LAYERS,
77 | is_torch_greater_or_equal_than_2_3,
78 | )
79 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase
80 | from transformers.trainer_callback import (
81 | CallbackHandler,
82 | DefaultFlowCallback,
83 | ExportableState,
84 | PrinterCallback,
85 | ProgressCallback,
86 | TrainerCallback,
87 | TrainerControl,
88 | TrainerState,
89 | )
90 | from transformers.trainer_pt_utils import (
91 | DistributedTensorGatherer,
92 | EvalLoopContainer,
93 | IterableDatasetShard,
94 | LabelSmoother,
95 | LayerWiseDummyOptimizer,
96 | LengthGroupedSampler,
97 | SequentialDistributedSampler,
98 | distributed_broadcast_scalars,
99 | distributed_concat,
100 | find_batch_size,
101 | get_model_param_count,
102 | get_module_class_from_name,
103 | get_parameter_names,
104 | nested_concat,
105 | nested_detach,
106 | nested_numpify,
107 | nested_xla_mesh_reduce,
108 | reissue_pt_warnings,
109 | remove_dummy_checkpoint,
110 | set_rng_state_for_device,
111 | )
112 | from transformers.trainer_utils import (
113 | PREFIX_CHECKPOINT_DIR,
114 | BestRun,
115 | EvalLoopOutput,
116 | EvalPrediction,
117 | HPSearchBackend,
118 | HubStrategy,
119 | PredictionOutput,
120 | RemoveColumnsCollator,
121 | SaveStrategy,
122 | TrainerMemoryTracker,
123 | TrainOutput,
124 | check_target_module_exists,
125 | default_compute_objective,
126 | denumpify_detensorize,
127 | enable_full_determinism,
128 | find_executable_batch_size,
129 | get_last_checkpoint,
130 | has_length,
131 | neftune_post_forward_hook,
132 | number_of_arguments,
133 | seed_worker,
134 | set_seed,
135 | speed_metrics,
136 | )
137 | from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
138 | from transformers.utils import (
139 | ADAPTER_CONFIG_NAME,
140 | ADAPTER_SAFE_WEIGHTS_NAME,
141 | ADAPTER_WEIGHTS_NAME,
142 | CONFIG_NAME,
143 | SAFE_WEIGHTS_INDEX_NAME,
144 | SAFE_WEIGHTS_NAME,
145 | WEIGHTS_INDEX_NAME,
146 | WEIGHTS_NAME,
147 | XLA_FSDPV2_MIN_VERSION,
148 | PushInProgress,
149 | PushToHubMixin,
150 | can_return_loss,
151 | find_labels,
152 | is_accelerate_available,
153 | is_apex_available,
154 | is_apollo_torch_available,
155 | is_bitsandbytes_available,
156 | is_datasets_available,
157 | is_galore_torch_available,
158 | is_grokadamw_available,
159 | is_in_notebook,
160 | is_ipex_available,
161 | is_liger_kernel_available,
162 | is_lomo_available,
163 | is_peft_available,
164 | is_safetensors_available,
165 | is_sagemaker_dp_enabled,
166 | is_sagemaker_mp_enabled,
167 | is_schedulefree_available,
168 | is_torch_compile_available,
169 | is_torch_mlu_available,
170 | is_torch_mps_available,
171 | is_torch_musa_available,
172 | is_torch_neuroncore_available,
173 | is_torch_npu_available,
174 | is_torch_xla_available,
175 | is_torch_xpu_available,
176 | is_torchao_available,
177 | logging,
178 | strtobool,
179 | )
180 | from transformers.utils.deprecation import deprecate_kwarg
181 | from transformers.utils.quantization_config import QuantizationMethod
182 |
183 |
184 | DEFAULT_CALLBACKS = [DefaultFlowCallback]
185 | DEFAULT_PROGRESS_CALLBACK = ProgressCallback
186 |
187 | if is_in_notebook():
188 | from transformers.utils.notebook import NotebookProgressCallback
189 |
190 | DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
191 |
192 | if is_apex_available():
193 | from apex import amp
194 |
195 | if is_datasets_available():
196 | import datasets
197 |
198 | if is_torch_xla_available():
199 | import torch_xla.core.xla_model as xm
200 | import torch_xla.debug.metrics as met
201 | from torch_xla import __version__ as XLA_VERSION
202 |
203 | IS_XLA_FSDPV2_POST_2_2 = version.parse(XLA_VERSION) >= version.parse(XLA_FSDPV2_MIN_VERSION)
204 | if IS_XLA_FSDPV2_POST_2_2:
205 | import torch_xla.distributed.spmd as xs
206 | import torch_xla.runtime as xr
207 | else:
208 | IS_XLA_FSDPV2_POST_2_2 = False
209 |
210 |
211 | if is_sagemaker_mp_enabled():
212 | import smdistributed.modelparallel.torch as smp
213 | from smdistributed.modelparallel import __version__ as SMP_VERSION
214 |
215 | IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
216 |
217 | from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
218 | else:
219 | IS_SAGEMAKER_MP_POST_1_10 = False
220 |
221 |
222 | if is_safetensors_available():
223 | import safetensors.torch
224 |
225 | if is_peft_available():
226 | from peft import PeftModel
227 |
228 |
229 | if is_accelerate_available():
230 | from accelerate import Accelerator, skip_first_batches
231 | from accelerate import __version__ as accelerate_version
232 | from accelerate.state import AcceleratorState
233 | from accelerate.utils import (
234 | AutocastKwargs,
235 | DistributedDataParallelKwargs,
236 | DistributedType,
237 | load_fsdp_model,
238 | load_fsdp_optimizer,
239 | save_fsdp_model,
240 | save_fsdp_optimizer,
241 | )
242 |
243 | DATA_SAMPLERS = [RandomSampler]
244 | if version.parse(accelerate_version) > version.parse("0.23.0"):
245 | from accelerate.data_loader import SeedableRandomSampler
246 |
247 | DATA_SAMPLERS += [SeedableRandomSampler]
248 |
249 | if is_deepspeed_available():
250 | from accelerate.utils import DeepSpeedSchedulerWrapper
251 |
252 | if is_accelerate_available("0.28.0"):
253 | from accelerate.utils import DataLoaderConfiguration
254 |
255 |
256 | def _is_peft_model(model):
257 | if is_peft_available():
258 | classes_to_check = (PeftModel,) if is_peft_available() else ()
259 | # Here we also check if the model is an instance of `PeftMixedModel` introduced in peft>=0.7.0: https://github.com/huggingface/transformers/pull/28321
260 | if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
261 | from peft import PeftMixedModel
262 |
263 | classes_to_check = (*classes_to_check, PeftMixedModel)
264 | return isinstance(model, classes_to_check)
265 | return False
266 |
267 |
268 | def _get_fsdp_ckpt_kwargs():
269 | # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
270 | if is_accelerate_available() and "adapter_only" in list(inspect.signature(save_fsdp_model).parameters):
271 | return {"adapter_only": True}
272 | else:
273 | return {}
274 |
275 |
276 | def safe_globals():
277 | # Starting from version 2.4 PyTorch introduces a check for the objects loaded
278 | # with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes
279 | # a default and requires allowlisting of objects being loaded.
280 | # See: https://github.com/pytorch/pytorch/pull/137602
281 | # See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
282 | # See: https://github.com/huggingface/accelerate/pull/3036
283 | if version.parse(torch.__version__).release < version.parse("2.6").release:
284 | return contextlib.nullcontext()
285 |
286 | np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
287 | allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
288 | # numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
289 | # all versions of numpy
290 | allowlist += [type(np.dtype(np.uint32))]
291 |
292 | return torch.serialization.safe_globals(allowlist)
293 |
294 |
295 | if TYPE_CHECKING:
296 | import optuna
297 |
298 | if is_datasets_available():
299 | import datasets
300 |
301 | logger = logging.get_logger(__name__)
302 |
303 |
304 | # Name of the files used for checkpointing
305 | TRAINING_ARGS_NAME = "training_args.bin"
306 | TRAINER_STATE_NAME = "trainer_state.json"
307 | OPTIMIZER_NAME = "optimizer.pt"
308 | SCALER_NAME = "scaler.pt"
309 | OPTIMIZER_NAME_BIN = "optimizer.bin"
310 | SCHEDULER_NAME = "scheduler.pt"
311 | FSDP_MODEL_NAME = "pytorch_model_fsdp"
312 |
313 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CPRet: A Dataset, Benchmark, and Model for Retrieval in Competitive Programming
2 |
3 | [](https://arxiv.org/abs/2505.12925)
4 | [](https://huggingface.co/collections/coldchair16/cpret-682451276f05c5988fcbdf34)
5 |
6 | Email contact: 2317757009@qq.com
7 |
8 | ## 🌐 Try Online Demo
9 |
10 | We provide an **online demo** of the CPRet retrieval service, available at:
11 |
12 | 👉 [https://www.cpret.online/](https://www.cpret.online/)
13 |
14 | This demo can assist in **duplicate problem detection** by retrieving potentially similar problems, though final identification still requires manual verification.
15 |
16 | It also supports **similar problem retrieval** to help broaden your problem-solving perspective.
17 |
18 | You can input either a **full problem description** or a **simplified version**, and the system will return the most relevant existing problems.
19 |
20 | You can refer to the usage examples of the retrieval platform at: [https://github.com/coldchair/CPRet/blob/main/TestCases.md](https://github.com/coldchair/CPRet/blob/main/TestCases.md)
21 |
22 | It runs the same codebase and embedding model as the local deployment (see below), so you can preview its capabilities before setting up your own instance.
23 |
24 | ## 🚀 News
25 |
26 | **Oct 2025: CPRetriever-Prob-Qwen3-4B-2510 Released with Enhanced Retrieval Performance!**
27 |
28 | We're excited to announce a major update to the CPRetriever model series! We've trained the new [**CPRetriever-Prob-Qwen3-4B-2510**](https://huggingface.co/coldchair16/CPRetriever-Prob-Qwen3-4B-2510) model on top of [**Qwen3-Embedding-4B**](https://huggingface.co/Qwen/Qwen3-Embedding-4B) (released in June 2025), and it achieves **state-of-the-art results** on problem-related retrieval tasks. We have also updated the website's retrieval problem database to the latest Oct 2025 version.
29 |
30 | Here's a comparison of model performance:
31 |
32 | | model | type | size | Text-to-Code | Code-to-Code | Problem-to-Duplicate | Simplified-to-Full | Avg |
33 | | :------------------------ | :--- | :--- | :----------- | :----------- | :------------------- | :----------------- | :----- |
34 | | CPRetriever-Code | code | 2B | 70.40 | 70.59 | 38.68 | 81.45 | 65.28 |
35 | | CPRetriever-Prob | code | 2B | 56.50 | 70.68 | 60.06 | 90.74 | 69.50 |
36 | | CPRetriever-Prob-Qwen3-4B | code | 4B | 65.85 | 70.19 | 71.45 | 95.03 | 75.63 |
37 | | CPRetriever-Prob-Qwen3-4B-2510 | code | 4B | 80.84 | 87.10 | 74.33 | 96.15 | 84.61 |
38 |
39 | The CPRetriever-Prob-Qwen3-4B-2510 model follows the same training procedure and dataset as CPRetriever-Prob-Qwen3-4B, but was retrained in October 2025 with adjusted data proportions, an extended maximum sequence length of 2048, and optimized hyperparameters for improved performance.
40 |
41 |
42 | **Sept 2025:** 🎉 We’re excited to announce that our paper has been accepted to the **NeurIPS 2025 D&B Track**!
43 |
44 | ## 📌 Overview
45 |
46 | **CPRet** is a comprehensive suite for competitive programming retrieval research, consisting of:
47 |
48 | * A large-scale dataset and benchmark for retrieval tasks in coding contests.
49 | * A dual-stage training pipeline with contrastive pretraining and task-specific fine-tuning.
50 | * A local retrieval server for **simplified description** and **duplicate problem** search, powered by our trained model [**CPRetriever-Prob-Qwen3-4B-2510**](https://huggingface.co/coldchair16/CPRetriever-Prob-Qwen3-4B-2510).
51 |
52 | We define the following **four core retrieval tasks** to support both practical applications and academic benchmarking:
53 |
54 | 1. **Text-to-Code (T2C):** Retrieve relevant code given a natural language problem description.
55 | 2. **Code-to-Code (C2C):** Retrieve other implementations of the same problem based on a given solution.
56 | 3. **Problem-to-Duplicate (P2D):** Detect duplicate or near-duplicate problems from existing contest archives.
57 | 4. **Simplified-to-Full (S2F):** Retrieve the original full version of a simplified problem.
58 |
59 |
60 | ## 🧰 Repository Contents
61 |
62 | * `cp-retrieval-server/`: Code for running a local retrieval web service.
63 | * `stage1/`: Code for stage-1 contrastive pretraining.
64 | * `stage2/`: Code for stage-2 problem-level fine-tuning.
65 |
66 | ---
67 |
68 | ## ⚙️ Setup
69 |
70 | ### Environment
71 |
72 | * Recommended: `python >= 3.10`
73 |
74 | * Install dependencies:
75 |
76 | ```bash
77 | pip install -r requirements.txt
78 | ```
79 |
80 | * Install PyTorch (with CUDA support if needed):
81 | → Refer to: [https://pytorch.org/get-started/locally/](https://pytorch.org/get-started/locally/)
82 |
83 | * PyTorch ≥ 2.0 is recommended.
84 |
85 | ### 🔁 Accessing Hugging Face from Restricted Regions
86 |
87 | If you're experiencing connectivity issues with Hugging Face, consider using the official mirror:
88 |
89 | ```python
90 | import os
91 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
92 | ```
93 |
94 | Or set it as an environment variable:
95 |
96 | ```bash
97 | export HF_ENDPOINT=https://hf-mirror.com
98 | ```
99 |
100 | ---
101 |
102 | ## 🚀 Run Local Retrieval Service
103 |
104 |
105 | 1. **Download embeddings:**
106 |
107 | * Run `cp-retrieval-server/download.py` to download the problems and embeddings (an optional manual-download sketch is shown at the end of this step).
108 |
109 | * **If you are using the new model, `CPRetriever-Prob-Qwen3-4B-2510`:**
110 | * Please download the following files from [HF dataset CPRet-Embeddings](https://huggingface.co/datasets/coldchair16/CPRet-Embeddings) into the `cp-retrieval-server/` directory:
111 | * `probs_2511.jsonl`
112 | * `probs_2511_embs.npy`
113 |
114 | * **If you are using the old model, `CPRetriever-Prob`:**
115 | * Please download the following files from [HF dataset CPRet-Embeddings](https://huggingface.co/datasets/coldchair16/CPRet-Embeddings) into the `cp-retrieval-server/` directory:
116 | * `probs_embs.npy`
117 | * `probs.jsonl`
118 |
119 |
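 | If you prefer to fetch these files manually, here is a minimal sketch using `huggingface_hub` (the file names are those listed above for the new model; everything else is illustrative):
 |
 | ```python
 | from huggingface_hub import hf_hub_download
 |
 | # Download the problem dump and its precomputed embeddings into cp-retrieval-server/
 | for fname in ["probs_2511.jsonl", "probs_2511_embs.npy"]:
 |     hf_hub_download(
 |         repo_id="coldchair16/CPRet-Embeddings",
 |         repo_type="dataset",
 |         filename=fname,
 |         local_dir="cp-retrieval-server",
 |     )
 | ```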
120 |
121 | 2. **Start the service:**
122 |
123 | ```bash
124 | cd cp-retrieval-server
125 | ```
126 |
127 | If you're using the **old model (`CPRetriever-Prob`)**, set the following environment variables before starting the service:
128 |
129 | ```bash
130 | export MODEL_PATH=coldchair16/CPRetriever-Prob
131 | export EMB_PATH=./probs_embs.npy
132 | export PROB_PATH=./probs.jsonl
133 | ```
134 |
135 | Note: bf16 is enabled by default. If your device does not support it, set the environment variable `BF_16=0`:
136 |
137 | ```bash
138 | export BF_16=0
139 | ```
140 |
141 | Then, run:
142 |
143 | ```bash
144 | python app.py
145 | ```
146 |
147 | 3. **About the Dataset:**
148 |
149 | The current retrieval problem database (as of Nov 2025) includes problems from the following online judges:
150 |
151 | * [Codeforces](https://codeforces.com/)
152 | * [AtCoder](https://atcoder.jp/)
153 | * [SPOJ](https://www.spoj.com/)
154 | * [Nowcoder](https://ac.nowcoder.com/)
155 | * [Luogu](https://www.luogu.com.cn/)
156 | * [Loj](https://loj.ac/)
157 | * [CodeChef](https://www.codechef.com/dashboard)
158 | * [AIZU](https://judge.u-aizu.ac.jp/onlinejudge/)
159 | * [UOJ](https://uoj.ac/)
160 | * [QOJ](https://qoj.ac/)
161 |
162 | The data is collected up to **Nov 2025**.
163 | You can add your own data source and generate embeddings using [`compute_embs.py`](cp-retrieval-server/compute_embs.py); a minimal illustrative sketch follows below. Running this process for the current database on an H100 GPU takes approximately 4 GPU hours.
164 |
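 | A minimal illustrative sketch of what embedding generation looks like with sentence-transformers (`compute_embs.py` remains the supported path; the input texts and output file below are placeholders, and only the public model name is taken from above):
 |
 | ```python
 | import numpy as np
 | from sentence_transformers import SentenceTransformer
 |
 | model = SentenceTransformer("coldchair16/CPRetriever-Prob-Qwen3-4B-2510")
 | texts = ["<problem statement 1>", "<problem statement 2>"]  # your own problem texts
 | embs = model.encode(texts, batch_size=2)  # (num_texts, hidden_dim) numpy array
 | np.save("my_probs_embs.npy", embs)
 | ```
 |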
165 | If you have access to a larger or more diverse problem dataset, **we welcome contributions and are happy to update the collection** — feel free to contact us (2317757009@qq.com) or open an issue/pull request.
166 |
167 |
168 | 4. **System Requirements:**
169 |
170 | This service can be run on **CPU** or **GPU**, depending on your environment.
171 | We recommend the following memory for smooth operation:
172 |
173 | * For the **2B old models** (e.g., `CPRetriever-Prob`): at least **16GB of system memory or GPU VRAM**.
174 | * For the **4B new model** (`CPRetriever-Prob-Qwen3-4B-2510`): **32GB or more of system memory or GPU VRAM**.
175 |
176 | The above requirements assume fp32; if your device supports bf16, roughly half the memory/VRAM is needed.
177 |
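 | To check whether your device supports bf16 (a minimal sketch, assuming only that PyTorch is installed):
 |
 | ```python
 | import torch
 |
 | # True on most Ampere-or-newer NVIDIA GPUs; if False, start the service with BF_16=0
 | print(torch.cuda.is_available() and torch.cuda.is_bf16_supported())
 | ```
 |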
178 | Typical query latency:
179 |
180 | * On **CPU** (8 cores): **10–20 seconds**.
181 | * On **GPU** (e.g., A800): **0.1–1 seconds**.
182 |
183 | Inference time depends on the input length.
184 |
185 |
186 | ---
187 |
188 | ## 🏋️♀️ Training Instructions
189 |
190 | > **⚠️ Note**: Recommended GPU memory ≥ **50 GB** to avoid OOM.
191 |
192 | ### 🔧 Stage 1: Contrastive Pretraining
193 |
194 | ```bash
195 | cd stage1
196 | torchrun --nproc_per_node=8 train.py
197 | ```
198 |
199 | * Change `--nproc_per_node` to match the number of available GPUs.
200 | * Use `--help` to see all configurable hyperparameters.
201 |
202 | ### ⚠️ Note on Using `Salesforce/SFR-Embedding-Code-2B_R`
203 |
204 | If you are using [`Salesforce/SFR-Embedding-Code-2B_R`](https://huggingface.co/Salesforce/SFR-Embedding-Code-2B_R) as your encoder, make sure to **manually disable `device_map="auto"`** when loading the model.
205 |
206 | The original code might look like this:
207 |
208 | ```python
209 | self.model = Gemma2Model.from_pretrained(config._name_or_path, trust_remote_code=True, is_causal=False, device_map="auto")
210 | self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, trust_remote_code=True, device_map="auto")
211 | ```
212 |
213 | This setting can cause the model to skip training due to automatic device placement.
214 | **Please change it to:**
215 |
216 | ```python
217 | self.model = Gemma2Model.from_pretrained(config._name_or_path, trust_remote_code=True, is_causal=False, device_map=None)
218 | self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path, trust_remote_code=True, device_map=None)
219 | ```
220 |
221 | Alternatively, you can directly copy the patched file from our repo:
222 | 👉 [modeling\_gemma2.py](https://huggingface.co/coldchair16/CPRetriever-Code/blob/main/modeling_gemma2.py)
223 |
224 |
225 | ---
226 | ### Stage 2: Problem-Level Fine-Tuning
227 |
228 | ```bash
229 | cd stage2
230 | torchrun --nproc_per_node=1 train.py
231 | ```
232 |
233 | * Also supports `--help` to inspect all args.
234 |
235 | ---
236 |
237 | ## 🔧 Notable Hyperparameters
238 |
239 | * `--model_path`: Can be either an HF model repo (e.g. `coldchair16/CPRetriever-Code`) or a local directory supporting SentenceTransformer.
240 | * `--eval_only True`: Run evaluation without training.
241 |
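 | For example, an evaluation-only run for stage 2 (illustrative; the repo passed to `--model_path` is just one valid choice):
 |
 | ```bash
 | cd stage2
 | torchrun --nproc_per_node=1 train.py --model_path coldchair16/CPRetriever-Prob --eval_only True
 | ```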
242 |
243 | ## 📫 Citation & License
244 |
245 | If you find **CPRet** useful in your research or applications, please consider citing our paper:
246 |
247 | ```bibtex
248 | @misc{deng2025cpretdatasetbenchmarkmodel,
249 | title = {CPRet: A Dataset, Benchmark, and Model for Retrieval in Competitive Programming},
250 | author = {Han Deng and Yuan Meng and Shixiang Tang and Wanli Ouyang and Xinzhu Ma},
251 | year = {2025},
252 | eprint = {2505.12925},
253 | archivePrefix = {arXiv},
254 | primaryClass = {cs.SE},
255 | url = {https://arxiv.org/abs/2505.12925}
256 | }
257 | ```
258 |
259 | ### 📄 License
260 |
261 | This work is released for **research and non-commercial use only**.
262 |
263 | **License:** [CC BY-NC 4.0 (Attribution–NonCommercial)](https://creativecommons.org/licenses/by-nc/4.0/)
264 |
265 |
266 |
267 |
--------------------------------------------------------------------------------
/stage1/grad_cache/grad_cache.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Union, Callable, Any
2 | from contextlib import nullcontext
3 | from itertools import repeat
4 | from collections import UserDict
5 | import logging
6 |
7 | import torch
8 | from torch import nn, Tensor
9 | from torch.cuda.amp import GradScaler, autocast
10 |
11 | from grad_cache.context_managers import RandContext
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | class GradCache:
17 | """
18 |     Gradient Cache class. Implements input chunking, a first graph-less forward pass, gradient cache creation, and a second
19 |     forward & backward pass for gradient computation. The optimizer step is not included. Native torch automatic mixed precision is
20 |     supported. The user needs to handle gradient unscaling and the scaler update after a gradient cache step.
21 | """
22 | def __init__(
23 | self,
24 | models: List[nn.Module],
25 | chunk_sizes: Union[int, List[int]],
26 | loss_fn: Callable[..., Tensor],
27 | split_input_fn: Callable[[Any, int], Any] = None,
28 | get_rep_fn: Callable[..., Tensor] = None,
29 | fp16: bool = False,
30 | scaler: GradScaler = None,
31 | ):
32 | """
33 | Initialize the Gradient Cache class instance.
34 | :param models: A list of all encoder models to be updated by the current cache.
35 | :param chunk_sizes: An integer indicating chunk size. Or a list of integers of chunk size for each model.
36 | :param loss_fn: A loss function that takes arbitrary numbers of representation tensors and
37 | arbitrary numbers of keyword arguments as input. It should not in any case modify the input tensors' relations
38 | in the autograd graph, which are later relied upon to create the gradient cache.
39 | :param split_input_fn: An optional function that split generic model input into chunks. If not provided, this
40 | class will try its best to split the inputs of supported types. See `split_inputs` function.
41 | :param get_rep_fn: An optional function that takes generic model output and return representation tensors. If
42 | not provided, the generic output is assumed to be the representation tensor.
43 | :param fp16: If True, run mixed precision training, which requires scaler to also be set.
44 | :param scaler: A GradScaler object for automatic mixed precision training.
45 | """
46 | self.models = models
47 |
48 | if isinstance(chunk_sizes, int):
49 | self.chunk_sizes = [chunk_sizes for _ in range(len(models))]
50 | else:
51 | self.chunk_sizes = chunk_sizes
52 |
53 | self.split_input_fn = split_input_fn
54 | self.get_rep_fn = get_rep_fn
55 | self.loss_fn = loss_fn
56 |
57 | if fp16:
58 | assert scaler is not None, "mixed precision training requires a gradient scaler passed in"
59 |
60 | self.fp16 = fp16
61 | self.scaler = scaler
62 |
63 | self._get_input_tensors_strict = False
64 |
65 | def __call__(self, *args, **kwargs):
66 | """
67 | Call the cache_step function.
68 | :return: Current step loss.
69 | """
70 | return self.cache_step(*args, **kwargs)
71 |
72 | def split_inputs(self, model_input, chunk_size: int) -> List:
73 | """
74 | Split input into chunks. Will call user provided `split_input_fn` if specified. Otherwise,
75 | it can handle input types of tensor, list of tensors and dictionary of tensors.
76 | :param model_input: Generic model input.
77 | :param chunk_size: Size of each chunk.
78 | :return: A list of chunked model input.
79 | """
80 | # delegate splitting to user provided function
81 | if self.split_input_fn is not None:
82 | return self.split_input_fn(model_input, chunk_size)
83 |
84 | if isinstance(model_input, (dict, UserDict)) and all(isinstance(x, Tensor) for x in model_input.values()):
85 | keys = list(model_input.keys())
86 | chunked_tensors = [model_input[k].split(chunk_size, dim=0) for k in keys]
87 | return [dict(zip(kk, tt)) for kk, tt in zip(repeat(keys), zip(*chunked_tensors))]
88 |
89 | elif isinstance(model_input, list) and all(isinstance(x, Tensor) for x in model_input):
90 | chunked_x = [t.split(chunk_size, dim=0) for t in model_input]
91 | return [list(s) for s in zip(*chunked_x)]
92 |
93 | elif isinstance(model_input, Tensor):
94 | return list(model_input.split(chunk_size, dim=0))
95 |
96 | elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]:
97 | args_chunks = self.split_inputs(model_input[0], chunk_size)
98 | kwargs_chunks = self.split_inputs(model_input[1], chunk_size)
99 | return list(zip(args_chunks, kwargs_chunks))
100 |
101 | else:
102 | raise NotImplementedError(f'Model input split not implemented for type {type(model_input)}')
103 |
104 | def get_input_tensors(self, model_input) -> List[Tensor]:
105 | """
106 | Recursively go through model input and grab all tensors, which are then used to record current device random
107 | states. This method will do its best to parse types of Tensor, tuple, list, dict and UserDict. Other types will
108 | be ignored unless self._get_input_tensors_strict is set to True, in which case an exception will be raised.
109 | :param model_input: input to model
110 | :return: all torch tensors in model_input
111 | """
112 | if isinstance(model_input, Tensor):
113 | return [model_input]
114 |
115 | elif isinstance(model_input, (list, tuple)):
116 | return sum((self.get_input_tensors(x) for x in model_input), [])
117 |
118 | elif isinstance(model_input, (dict, UserDict)):
119 | return sum((self.get_input_tensors(x) for x in model_input.values()), [])
120 |
121 | elif self._get_input_tensors_strict:
122 | raise NotImplementedError(f'get_input_tensors not implemented for type {type(model_input)}')
123 |
124 | else:
125 | return []
126 |
127 | def model_call(self, model: nn.Module, model_input):
128 | """
129 | Literally call the model's __call__ method.
130 | :param model: model to be called
131 | :param model_input: input to the model call
132 | :return: model output
133 | """
134 | with autocast() if self.fp16 else nullcontext():
135 | if isinstance(model_input, Tensor):
136 | return model(model_input)
137 | elif isinstance(model_input, list):
138 | return model(*model_input)
139 | elif isinstance(model_input, (dict, UserDict)):
140 | return model(**model_input)
141 | elif isinstance(model_input, tuple) and list(map(type, model_input)) == [list, dict]:
142 | model_args, model_kwargs = model_input
143 | return model(*model_args, **model_kwargs)
144 | else:
145 | raise NotImplementedError
146 |
147 | def get_reps(self, model_out) -> Tensor:
148 | """
149 | Return representation tensor from generic model output
150 | :param model_out: generic model output
151 | :return: a single tensor corresponding to the model representation output
152 | """
153 | if self.get_rep_fn is not None:
154 | return self.get_rep_fn(model_out)
155 | else:
156 | return model_out
157 |
158 | def compute_loss(self, *reps: Tensor, **loss_kwargs) -> Tensor:
159 | """
160 |         Compute the loss based on the representation tensors. The tensors should be in the same order as the list of models
161 | registered in this GradCache class instance.
162 | :param reps: Representations for computing the loss.
163 | :param loss_kwargs: Keyword arguments input to the loss function.
164 | :return: the loss tensor.
165 | """
166 | loss = self.loss_fn(*reps, **loss_kwargs)
167 | return loss
168 |
169 | def forward_no_grad(
170 | self,
171 | model: nn.Module,
172 | model_inputs,
173 |     ) -> Tuple[Tensor, List[RandContext]]:
174 | """
175 | The first forward pass without gradient computation.
176 | :param model: Encoder model.
177 | :param model_inputs: Model input already broken into chunks.
178 | :return: A tuple of a) representations and b) recorded random states.
179 | """
180 | rnd_states = []
181 | model_reps = []
182 |
183 | with torch.no_grad():
184 | for x in model_inputs:
185 | rnd_states.append(RandContext(*self.get_input_tensors(x)))
186 | y = self.model_call(model, x)
187 | model_reps.append(self.get_reps(y))
188 |
189 | # concatenate all sub-batch representations
190 | model_reps = torch.cat(model_reps, dim=0)
191 | return model_reps, rnd_states
192 |
193 |     def build_cache(self, *reps: Tensor, **loss_kwargs) -> Tuple[List[Tensor], Tensor]:
194 | """
195 | Compute the gradient cache
196 | :param reps: Computed representations from all encoder models
197 | :param loss_kwargs: Extra keyword arguments to the loss function
198 | :return: A tuple of a) gradient cache for each encoder model, and b) loss tensor
199 | """
200 | reps = [r.detach().requires_grad_() for r in reps]
201 | with autocast() if self.fp16 else nullcontext():
202 | loss = self.compute_loss(*reps, **loss_kwargs)
203 |
204 | if self.fp16:
205 | self.scaler.scale(loss).backward()
206 | else:
207 | loss.backward()
208 |
209 | cache = [r.grad for r in reps]
210 |
211 | return cache, loss.detach()
212 |
213 | def forward_backward(
214 | self,
215 | model: nn.Module,
216 | model_inputs,
217 | cached_gradients: List[Tensor],
218 | random_states: List[RandContext],
219 | no_sync_except_last: bool = False
220 | ):
221 | """
222 | Run the second forward and the backward pass to compute gradient for a model.
223 | :param model: Encoder model.
224 | :param model_inputs: Chunked input to the encoder model.
225 | :param cached_gradients: Chunked gradient cache tensor for each input.
226 | :param random_states: Each input's device random state during the first forward.
227 | :param no_sync_except_last: If True, under distributed setup, only trigger gradient reduction across processes
228 | for the last sub-batch's forward-backward pass.
229 | """
230 | if no_sync_except_last:
231 | sync_contexts = [model.no_sync for _ in range(len(model_inputs) - 1)] + [nullcontext]
232 | else:
233 | sync_contexts = [nullcontext for _ in range(len(model_inputs))]
234 |
235 | for x, state, gradient, sync_context in zip(model_inputs, random_states, cached_gradients, sync_contexts):
236 | with sync_context():
237 | with state:
238 | y = self.model_call(model, x)
239 | reps = self.get_reps(y)
240 |
241 | surrogate = torch.dot(reps.flatten(), gradient.flatten())
242 | surrogate.backward()
243 |
244 | def cache_step(
245 | self,
246 | *model_inputs,
247 | no_sync_except_last: bool = False,
248 | **loss_kwargs
249 | ) -> Tensor:
250 | """
251 | Run a cached step to compute gradient over the inputs.
252 |         :param model_inputs: Input to each encoder model. Should be in the same order as the class's list of models.
253 | :param no_sync_except_last: If True, under distributed setup, for each model, only trigger gradient reduction
254 | across processes for the last sub-batch's forward-backward pass.
255 | :param loss_kwargs: Additional keyword arguments to the loss function.
256 |         :return: The current step's loss.
257 | """
258 | all_reps = []
259 | all_rnd_states = []
260 |
261 | if no_sync_except_last:
262 | assert all(map(lambda m: isinstance(m, nn.parallel.DistributedDataParallel), self.models)), \
263 |                 'Some of the models are not wrapped in DistributedDataParallel. Make sure you are running DDP with ' \
264 | 'proper initializations.'
265 |
266 | model_inputs = [self.split_inputs(x, chunk_size) for x, chunk_size in zip(model_inputs, self.chunk_sizes)]
267 |
268 | for model, x in zip(self.models, model_inputs):
269 | model_reps, rnd_states = self.forward_no_grad(model, x)
270 | all_reps.append(model_reps)
271 | all_rnd_states.append(rnd_states)
272 |
273 | cache, loss = self.build_cache(*all_reps, **loss_kwargs)
274 | cache = [c.split(chunk_size) for c, chunk_size in zip(cache, self.chunk_sizes)]
275 |
276 | for model, x, model_cache, rnd_states in zip(
277 | self.models, model_inputs, cache, all_rnd_states):
278 | self.forward_backward(model, x, model_cache, rnd_states, no_sync_except_last=no_sync_except_last)
279 |
280 | return loss
281 |
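282 | # Illustrative usage sketch (not part of the original library; the encoder models, loss_fn and optimizer below are placeholders):
283 | #   gc = GradCache(models=[query_encoder, passage_encoder], chunk_sizes=8, loss_fn=contrastive_loss)
284 | #   loss = gc(query_inputs, passage_inputs)  # chunked no-grad forward, cache build, then forward-backward
285 | #   optimizer.step(); optimizer.zero_grad()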
--------------------------------------------------------------------------------