├── src └── star_align │ ├── __init__.py │ ├── prompt_template.py │ ├── decontamination │ ├── utils.py │ ├── benchmark_data.py │ └── find_substrings.py │ ├── clean_data.py │ ├── utils.py │ ├── train.py │ ├── execution_filter.py │ ├── minhash_dedup.py │ ├── sanitize_data.py │ └── llm_wrapper.py ├── seed_gathering ├── requirements.txt ├── tree_sitter_parser.py ├── README.md ├── benchmark_data.py ├── generate_from_the_stack.py ├── high_quality_subset.py └── filter_dataset.py ├── .gitmodules ├── requirements.txt ├── pyproject.toml ├── sanitize.sh ├── self_ossinstruct_sc2.sh ├── self_ossinstruct_sc2_parallel.sh ├── evaluation ├── README.md ├── text2code.py ├── text2code_vllm.py └── ds_1000.py ├── .gitignore ├── README.md ├── LICENSE └── README-SC2INST.md /src/star_align/__init__.py: -------------------------------------------------------------------------------- 1 | from . import utils 2 | -------------------------------------------------------------------------------- /seed_gathering/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.18.0 2 | tree-sitter==0.20.4 3 | tqdm==4.65.0 4 | torch==2.1.2 5 | vllm==0.4.1 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "seed_gathering/tree-sitter-python"] 2 | path = seed_gathering/tree-sitter-python 3 | url = https://github.com/tree-sitter/tree-sitter-python 4 | [submodule "src/star_align/code_exec_server"] 5 | path = src/star_align/code_exec_server 6 | url = git@github.com:cassanof/code_exec_server 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.40.1 2 | vllm==0.4.1 3 | torch>=2.0.1 4 | openai>=1.3.7 5 | tenacity~=8.2.3 6 | tiktoken~=0.6.0 7 | accelerate>=0.27.2 8 | datasets>=2.17.1 9 | evalplus @ git+https://github.com/evalplus/evalplus.git@25e195e024b614f2671ad9ac5b8fdcd9b95a2b24#egg=evalplus 10 | evoeval~=0.1.0 11 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | description = "StarCoder2-Instruct: Fully Transparent and Permissive Self-Alignment for Code Generation" 3 | dynamic = ["dependencies"] 4 | license = {text = "Apache-2.0"} 5 | name = "star_align" 6 | readme = "README.md" 7 | requires-python = ">=3.10" 8 | version = "0.1.0" 9 | 10 | [tool.setuptools.packages.find] 11 | include = ["star_align*"] 12 | where = ["src"] 13 | 14 | [tool.black] 15 | include = '\.pyi?$' 16 | line-length = 88 17 | target-version = ["py310"] 18 | 19 | [tool.isort] 20 | line_length = 88 21 | profile = "black" 22 | skip_gitignore = true 23 | 24 | [tool.setuptools.dynamic] 25 | dependencies = {file = ["requirements.txt"]} 26 | 27 | [tool.mypy] 28 | check_untyped_defs = true 29 | follow_imports = "silent" 30 | ignore_missing_imports = true 31 | mypy_path = "src" 32 | packages = ["star_align"] 33 | python_version = "3.10" 34 | -------------------------------------------------------------------------------- /src/star_align/prompt_template.py: -------------------------------------------------------------------------------- 1 | SC2_INSTRUCT_PROMPT = """You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to 
user instructions. 2 | 3 | ### Instruction 4 | {instruction} 5 | 6 | ### Response 7 | {response}""" 8 | 9 | CHAT_TEMPLATE = """{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}} 10 | {%- for message in messages %} 11 | {%- if message['role'] == 'system' %} 12 | {{ raise_exception('System messages are not allowed in this template.') }} 13 | {%- else %} 14 | {%- if message['role'] == 'user' %} 15 | {{'### Instruction\n' + message['content'] + '\n\n'}} 16 | {%- else %} 17 | {{'### Response\n' + message['content'] + eos_token + '\n\n'}} 18 | {%- endif %} 19 | {%- endif %} 20 | {%- endfor %} 21 | {{'### Response\n'}}""" -------------------------------------------------------------------------------- /sanitize.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | SOURCE=$1 5 | TARGET=$2 6 | 7 | echo "Sanitizing.." 8 | python -m star_align.sanitize_data \ 9 | --data_files $SOURCE \ 10 | --output_file $TARGET \ 11 | --parse_raw_response True \ 12 | --exact_match_dedup True \ 13 | --passing_only True \ 14 | --include_left_failed False 15 | 16 | if [[ -n $DECONTAMINATION ]]; then 17 | echo "Decontaminating.. (saving to decontamination-output)" 18 | python -m star_align.decontamination.find_substrings \ 19 | --dataset_name "json" \ 20 | --output_file $TARGET \ 21 | --output_dir decontamination-output \ 22 | --columns instruction response \ 23 | --data_files $TARGET 24 | fi 25 | 26 | echo "Minihash dedup.." 27 | python -m star_align.minhash_dedup \ 28 | --data_files $TARGET \ 29 | --column instruction \ 30 | --output $TARGET 31 | 32 | python -m star_align.minhash_dedup \ 33 | --data_files $TARGET \ 34 | --column response \ 35 | --output $TARGET 36 | 37 | python -m star_align.minhash_dedup \ 38 | --data_files $TARGET \ 39 | --column code_representation \ 40 | --ignore_empty True \ 41 | --output $TARGET 42 | -------------------------------------------------------------------------------- /self_ossinstruct_sc2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "MODE: $MODE" 4 | echo "SEED_DATA_FILE: $SEED_DATA_FILE" 5 | echo "INDEX: $INDEX" 6 | echo "MAX_NEW_DATA: $MAX_NEW_DATA" 7 | echo "DIR: $1" 8 | 9 | # if mode is "I->R", num of samples is 10, otherwise 1 10 | if [[ "$MODE" == "I->R" ]]; then 11 | N_SAMPLES=1 12 | NUM_FEWSHOTS=1 13 | NUM_BATCHED_REQUESTS=4096 14 | ASYNC_MICRO_BATCH_SIZE=16 15 | else 16 | N_SAMPLES=1 17 | NUM_FEWSHOTS=8 18 | NUM_BATCHED_REQUESTS=4096 19 | ASYNC_MICRO_BATCH_SIZE=8 20 | fi 21 | 22 | echo "N_SAMPLES: $N_SAMPLES" 23 | echo "NUM_FEWSHOTS: $NUM_FEWSHOTS" 24 | echo "NUM_BATCHED_REQUESTS: $NUM_BATCHED_REQUESTS" 25 | echo "ASYNC_MICRO_BATCH_SIZE: $ASYNC_MICRO_BATCH_SIZE" 26 | 27 | COMMAND="python -m star_align.self_ossinstruct \ 28 | --async_micro_batch_size $ASYNC_MICRO_BATCH_SIZE \ 29 | --use_vllm_server True \ 30 | --instruct_mode '$MODE' \ 31 | --seed_data_files $SEED_DATA_FILE \ 32 | --max_new_data $MAX_NEW_DATA \ 33 | --tag sc2-${NUM_FEWSHOTS}shot \ 34 | --temperature 0.7 \ 35 | --seed_code_start_index $INDEX \ 36 | --model bigcode/starcoder2-15b \ 37 | --num_fewshots $NUM_FEWSHOTS \ 38 | --num_batched_requests $NUM_BATCHED_REQUESTS \ 39 | --num_sample_per_request $N_SAMPLES \ 40 | --save_dir $1" 41 | 42 | if [[ -n "$2" ]]; then 43 | COMMAND="$COMMAND --continue_from $2" 44 | fi 45 | 46 | echo "Running command: $COMMAND" 47 | eval $COMMAND 
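As a quick illustration of the prompt format defined in `src/star_align/prompt_template.py` earlier, here is a minimal, hypothetical sketch of rendering `SC2_INSTRUCT_PROMPT` (assuming the package has been installed with `pip install -e .`; the example instruction is a placeholder, not part of the repository):

```python
# Minimal sketch: render the SC2_INSTRUCT_PROMPT template from
# src/star_align/prompt_template.py. The instruction below is a placeholder.
from star_align.prompt_template import SC2_INSTRUCT_PROMPT

prompt = SC2_INSTRUCT_PROMPT.format(
    instruction="Write a Python function that returns the sum of two integers.",
    response="",  # left empty at inference time so the model completes it
)
print(prompt)  # system preamble, then "### Instruction" and "### Response" sections
```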
48 | -------------------------------------------------------------------------------- /seed_gathering/tree_sitter_parser.py: -------------------------------------------------------------------------------- 1 | from tree_sitter import Language, Parser 2 | 3 | Language.build_library( 4 | 'build/lang.so', 5 | [ 6 | './tree-sitter-python' 7 | ] 8 | ) 9 | LANGUAGE = Language('build/lang.so', 'python') 10 | 11 | 12 | QUERY = LANGUAGE.query(""" 13 | (function_definition name: (identifier) @fn-name) 14 | """) 15 | 16 | 17 | global_parser = Parser() 18 | global_parser.set_language(LANGUAGE) 19 | 20 | 21 | def get_fn_name(code, parser=global_parser): 22 | src = bytes(code, "utf8") 23 | tree = parser.parse(src) 24 | node = tree.root_node 25 | for cap, typ in QUERY.captures(node): 26 | if typ == "fn-name": 27 | return node_to_string(src, cap) 28 | return None 29 | 30 | 31 | def node_to_string(src: bytes, node): 32 | return src[node.start_byte:node.end_byte].decode("utf8") 33 | 34 | 35 | def make_parser(): 36 | _parser = Parser() 37 | _parser.set_language(LANGUAGE) 38 | return _parser 39 | 40 | 41 | RETURN_QUERY = LANGUAGE.query(""" 42 | (return_statement) @return 43 | """) 44 | 45 | 46 | def does_have_return(src, parser=global_parser): 47 | tree = parser.parse(bytes(src, "utf8")) 48 | root = tree.root_node 49 | captures = RETURN_QUERY.captures(root) 50 | for node, _ in captures: 51 | # if it doesn't have an argument, it's not a return with a value 52 | if len(node.children) <= 1: # includes "return" itself 53 | continue 54 | else: 55 | return True 56 | 57 | return False 58 | 59 | 60 | if __name__ == "__main__": 61 | code = """ 62 | import ble 63 | from a import b 64 | """ 65 | print(global_parser.parse(bytes(code, "utf8")).root_node.sexp()) 66 | -------------------------------------------------------------------------------- /src/star_align/decontamination/utils.py: -------------------------------------------------------------------------------- 1 | """Migrated from: https://github.com/bigcode-project/bigcode-dataset. License: Apache 2.0""" 2 | 3 | import time 4 | from multiprocessing import Pool 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def save_shard(shard_tuple): 10 | """Save shard""" 11 | filename, shard = shard_tuple 12 | # use to_json instead to save as json file 13 | shard.to_parquet(filename) 14 | 15 | 16 | def shard_dataset(ds, shard_size, output_dir, num_proc): 17 | if ds._indices is not None: 18 | dataset_nbytes = ds.data.nbytes * len(ds._indices) / len(ds.data) 19 | else: 20 | dataset_nbytes = ds.data.nbytes 21 | num_shards = int(dataset_nbytes / shard_size) + 1 22 | print(f"Number of shards: {num_shards}") 23 | 24 | print("sharding the dataset") 25 | t_start = time.time() 26 | shards = ( 27 | ds.shard(num_shards=num_shards, index=i, contiguous=True) 28 | for i in range(num_shards) 29 | ) 30 | # use f"{OUT_PATH}/data/train-{index:05d}-of-{num_shards:05d}.json" instead for json files 31 | filenames = ( 32 | f"{output_dir}/train-{index:05d}-of-{num_shards:05d}.parquet" 33 | for index in range(num_shards) 34 | ) 35 | 36 | with Pool(num_proc) as p: 37 | list( 38 | tqdm( 39 | p.imap_unordered(save_shard, zip(filenames, shards), chunksize=4), 40 | total=num_shards, 41 | ) 42 | ) 43 | print(f"Time to save dataset: {time.time()-t_start:.2f}") 44 | 45 | 46 | def add_dict(dict1: dict, dict2: dict) -> None: 47 | """ 48 | Add the values of dict2 to dict1. All values must be int, float or dictionaries that also verify this condition. 
49 | Will modify dict1 and return None 50 | """ 51 | for key, value in dict2.items(): 52 | if isinstance(value, (int, float)): 53 | if key not in dict1: 54 | dict1[key] = 0 55 | dict1[key] += value 56 | elif isinstance(value, dict): 57 | if key not in dict1: 58 | dict1[key] = {} 59 | assert isinstance(dict1[key], dict) 60 | add_dict(dict1[key], value) 61 | else: 62 | raise ValueError(f"Invalid type for key/value {key}: {value}") 63 | -------------------------------------------------------------------------------- /self_ossinstruct_sc2_parallel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo "MODE: $MODE" 4 | echo "SEED_DATA_FILE: $SEED_DATA_FILE" 5 | echo "INDEX: $INDEX" 6 | echo "MAX_NEW_DATA: $MAX_NEW_DATA" 7 | echo "DIR: $1" 8 | 9 | NUM_GPUS=$(nvidia-smi --query-gpu=count --format=csv,noheader,nounits | head -n 1) 10 | 11 | DATA_CHUNK_SIZE=$(($MAX_NEW_DATA / $NUM_GPUS)) 12 | REMAINDER=$(($MAX_NEW_DATA % $NUM_GPUS)) 13 | 14 | if [[ "$MODE" == "I->R" ]]; then 15 | N_SAMPLES=1 16 | NUM_FEWSHOTS=1 17 | NUM_BATCHED_REQUESTS=4096 18 | ASYNC_MICRO_BATCH_SIZE=16 19 | else 20 | N_SAMPLES=1 21 | NUM_FEWSHOTS=8 22 | NUM_BATCHED_REQUESTS=4096 23 | ASYNC_MICRO_BATCH_SIZE=8 24 | fi 25 | 26 | echo "N_SAMPLES: $N_SAMPLES" 27 | echo "NUM_FEWSHOTS: $NUM_FEWSHOTS" 28 | echo "NUM_BATCHED_REQUESTS: $NUM_BATCHED_REQUESTS" 29 | echo "ASYNC_MICRO_BATCH_SIZE: $ASYNC_MICRO_BATCH_SIZE" 30 | 31 | PIDS=() 32 | function killall_pids { 33 | for pid in ${PIDS[@]}; do 34 | kill $pid 35 | done 36 | } 37 | trap killall_pids SIGINT SIGTERM 38 | 39 | for (( GPU_ID=0; GPU_ID<$NUM_GPUS; GPU_ID++ )) 40 | do 41 | START_INDEX=$(($INDEX + $GPU_ID * $DATA_CHUNK_SIZE)) 42 | if [[ $GPU_ID -lt $REMAINDER ]]; then 43 | CHUNK_SIZE=$(($DATA_CHUNK_SIZE + 1)) 44 | else 45 | CHUNK_SIZE=$DATA_CHUNK_SIZE 46 | fi 47 | END_INDEX=$(($START_INDEX + $CHUNK_SIZE - 1)) 48 | 49 | echo "Starting process for GPU $GPU_ID with data from $START_INDEX to $END_INDEX..." 50 | 51 | OUTDIR="$1/$GPU_ID" 52 | mkdir -p $OUTDIR 53 | 54 | CUDA_VISIBLE_DEVICES=$GPU_ID python -m star_align.self_ossinstruct \ 55 | --async_micro_batch_size $ASYNC_MICRO_BATCH_SIZE \ 56 | --use_vllm_server False \ 57 | --instruct_mode "$MODE" \ 58 | --seed_data_files $SEED_DATA_FILE \ 59 | --max_new_data $CHUNK_SIZE \ 60 | --tag sc2-${NUM_FEWSHOTS}shot \ 61 | --temperature 0.7 \ 62 | --seed_code_start_index $START_INDEX \ 63 | --model bigcode/starcoder2-15b \ 64 | --num_fewshots $NUM_FEWSHOTS \ 65 | --num_batched_requests $NUM_BATCHED_REQUESTS \ 66 | --num_sample_per_request $N_SAMPLES \ 67 | --save_dir $OUTDIR & 68 | PIDS+=($!) 69 | done 70 | 71 | wait 72 | 73 | if [[ $? -ne 0 ]]; then 74 | echo "Error in one of the processes. Exiting... Check logs for more details." 75 | exit 1 76 | fi 77 | 78 | # dir for final res 79 | FINAL="$1/final" 80 | FINAL_FILE="$FINAL/aggregated-${MODE}.jsonl" 81 | 82 | echo "All processes finished. Aggregating results... to $FINAL_FILE" 83 | 84 | # aggregate! 85 | mkdir -p $FINAL 86 | touch $FINAL_FILE 87 | 88 | for (( GPU_ID=0; GPU_ID<$NUM_GPUS; GPU_ID++ )) 89 | do 90 | # get first file for dir 91 | FILE=$(ls $1/$GPU_ID | head -n 1) 92 | cat $1/$GPU_ID/$FILE >> $FINAL_FILE 93 | done 94 | 95 | echo "Done!" 
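As an aside on the aggregation step at the end of the script above: a hypothetical, roughly equivalent merge could also be written with the JSONL helpers from `src/star_align/utils.py`. The directory layout and file names below are placeholders rather than the script's actual defaults.

```python
# Hypothetical sketch: merge per-GPU JSONL outputs with the repo's own helpers
# from src/star_align/utils.py. "outputs" and the file layout are placeholders.
from pathlib import Path

from star_align.utils import read_jsonl, write_jsonl

merged: list[dict] = []
for gpu_dir in sorted(p for p in Path("outputs").iterdir() if p.is_dir()):
    for jsonl_file in sorted(gpu_dir.glob("*.jsonl")):
        merged.extend(read_jsonl(jsonl_file))  # one record per generated sample
write_jsonl(Path("outputs/aggregated.jsonl"), merged)
```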
96 | -------------------------------------------------------------------------------- /seed_gathering/README.md: -------------------------------------------------------------------------------- 1 | # Code for gathering seed functions 2 | 3 | The pipeline for gathering seed functions is composed of three scripts, to be run in the following order: 4 | 5 | 1. `./generate_from_the_stack.py`: Gathers unfiltered seed functions with docstrings from The Stack v1. 6 | 2. `./high_quality_subset.py`: Transforms the seed functions using `autoimport`, and filters them by checking for a return statement and type-checking them with Pyright. 7 | 3. `./filter_dataset.py`: Further filters the seed functions by decontaminating the dataset, using StarCoder2-15B as a judge to remove bad examples, and using a set of static heuristics. 8 | 9 | ## 1. Generate from the Stack 10 | 11 | In this step, we simply extract all functions from The Stack v1 (dedup) that have docstrings using tree-sitter. This is done by running the following command: 12 | 13 | ```bash 14 | python3 generate_from_the_stack.py --push "/" 15 | ``` 16 | 17 | This step may take a couple of hours depending on your hardware. The resulting dataset will be pushed to the specified Hugging Face repository. 18 | 19 | After running the command, we also run near-deduplication with MinHash, LSH, and Jaccard Similarity of 0.5. 20 | We utilize the following repository for this step: https://github.com/ChenghaoMou/text-dedup 21 | 22 | The dataset resulting from this step can be found here: https://huggingface.co/datasets/bigcode/stack-dedup-python-fns 23 | 24 | ## 2. High-quality subset 25 | 26 | Here, we take the previously generated dataset and filter and transform it using a set of heuristics. 27 | We run the following steps: 28 | 29 | 1. We filter out all functions that do not have a "return statement". This is done so that our execution 30 | filtering step in the instruction-generation pipeline does not have to deal with functions that do not return anything, which 31 | are hard to test. Another benefit is that this strengthens the type-checking step, as we can now validate the return type for all functions. 32 | 2. We infer imports for the functions using `autoimport`, so that our standalone functions correctly import any required modules (a short sketch of this step is shown below, after the Section 3 list). 33 | 3. We type-check each function using Pyright. This is done to filter out any functions that may reference undefined identifiers or have static type errors (as detected by 34 | Pyright). 35 | 36 | The dataset resulting from this step can be found here: https://huggingface.co/datasets/bigcode/python-stack-v1-functions-filtered 37 | 38 | ## 3. Filter dataset 39 | 40 | Now, we further filter the dataset generated in the previous step with different methods: 41 | 42 | 1. We remove functions that contain certain words indicating they are likely to be bad examples (e.g. "TODO"). 43 | 2. We also remove functions that import problematic packages, which can lead to issues in execution filtering (e.g. `os` or `sys`). 44 | 3. We remove functions which contain either solutions or prompts of benchmarks on which we evaluated the models. 45 | 4. We filter out functions that do not have any arguments, as these are likely to be bad examples in this constrained setting. 46 | 5. Finally, we utilize the base model StarCoder2-15B as a classifier to remove any examples that have bad documentation or low-quality code.
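As referenced in step 2 of the "High-quality subset" section above, the following is a minimal sketch of the `autoimport` transformation; the example function is illustrative only, and the exact output depends on the installed `autoimport` version.

```python
# Minimal sketch of the import-inference step described in Section 2, item 2:
# autoimport.fix_code adds the imports a standalone function needs (here,
# `import math`). The exact output depends on autoimport's heuristics/version.
import autoimport

code = '''def safe_mean(xs) -> float:
    """Return the arithmetic mean of xs, or NaN for an empty sequence."""
    return sum(xs) / len(xs) if xs else math.nan
'''
print(autoimport.fix_code(code))
```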
47 | 48 | The dataset resulting from this step can be found here: https://huggingface.co/datasets/bigcode/python-stack-v1-functions-filtered-sc2 49 | -------------------------------------------------------------------------------- /evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | > [!IMPORTANT] 4 | > **General requirements** 5 | > 6 | > Before you start, make sure you have cloned the repository and are in the **root directory of the project**. Also make sure you have installed the required packages with `pip install -e .`. Different package versions may impact the reproducibility of the results. 7 | 8 | ## Running EvalPlus with vLLM 9 | 10 | We implemented batched inference in `evaluation/text2code_vllm.py` using [vLLM](https://docs.vllm.ai/en/latest/). This speeds up the evaluation significantly: **a greedy decoding run can be finished within 20 seconds**. Here is the command: 11 | 12 | ```bash 13 | MODEL=/path/to/your/model 14 | DATASET=humaneval # or mbpp 15 | SAVE_PATH=evalplus-$(basename $MODEL)-$DATASET.jsonl 16 | CUDA_VISIBLE_DEVICES=0 python -m evaluation.text2code_vllm \ 17 | --model_key $MODEL \ 18 | --dataset $DATASET \ 19 | --save_path $SAVE_PATH 20 | 21 | python -m evalplus.evaluate --dataset $DATASET --samples $SAVE_PATH 22 | ``` 23 | 24 | ## Reproduce StarCoder2-Instruct 25 | 26 | > [!NOTE] 27 | > 28 | > We obtained the results with the following hardware and environment: 29 | > 30 | > - One NVIDIA A100 80G GPU 31 | > - Python 3.10.0 32 | > 33 | > In case you face issues, we provide the raw outputs we generated in the [evalplus_results](evalplus_results) directory. 34 | 35 | ### Reproduce HumanEval(+) and MBPP(+) 36 | 37 | We pack multiple problems into one batch to speed up the inference. A different batch size may lead to slightly worse/better results due to floating-point round-off resulting from the underlying [cuBLAS](https://docs.nvidia.com/cuda/cublas/index.html) optimizations. 38 | 39 | Make sure you set `CUDA_VISIBLE_DEVICES` to the GPU you want to use and that you have `cd`ed to the root directory of the repo. We assume you use device 0 in the following commands.
40 | 41 | #### HumanEval(+) 42 | 43 | ```bash 44 | MODEL_KEY=bigcode/starcoder2-15b-instruct-v0.1 45 | MODEL=bigcode/starcoder2-15b-instruct-v0.1 46 | DATASET=humaneval 47 | SAVE_PATH=evalplus-$(basename $MODEL)-$DATASET.jsonl 48 | CUDA_VISIBLE_DEVICES=0 49 | 50 | CUDA_VISIBLE_DEVICES=0 python -m evaluation.text2code \ 51 | --model_key $MODEL_KEY \ 52 | --model_name_or_path $MODEL \ 53 | --save_path $SAVE_PATH \ 54 | --dataset $DATASET \ 55 | --temperature 0.0 \ 56 | --top_p 1.0 \ 57 | --max_new_tokens 512 \ 58 | --n_problems_per_batch 16 \ 59 | --n_samples_per_problem 1 \ 60 | --n_batches 1 61 | 62 | python -m evalplus.evaluate --dataset $DATASET --samples $SAVE_PATH 63 | # humaneval (base tests) 64 | # pass@1: 0.726 65 | # humaneval+ (base + extra tests) 66 | # pass@1: 0.634 67 | ``` 68 | 69 | #### MBPP(+) 70 | 71 | ```bash 72 | MODEL_KEY=bigcode/starcoder2-15b-instruct-v0.1 73 | MODEL=bigcode/starcoder2-15b-instruct-v0.1 74 | DATASET=mbpp 75 | SAVE_PATH=evalplus-$(basename $MODEL)-$DATASET.jsonl 76 | 77 | CUDA_VISIBLE_DEVICES=0 python -m evaluation.text2code \ 78 | --model_key $MODEL_KEY \ 79 | --model_name_or_path $MODEL \ 80 | --save_path $SAVE_PATH \ 81 | --dataset $DATASET \ 82 | --temperature 0.0 \ 83 | --top_p 1.0 \ 84 | --max_new_tokens 512 \ 85 | --n_problems_per_batch 16 \ 86 | --n_samples_per_problem 1 \ 87 | --n_batches 1 88 | 89 | python -m evalplus.evaluate --dataset $DATASET --samples $SAVE_PATH 90 | # mbpp (base tests) 91 | # pass@1: 0.642 92 | # mbpp+ (base + extra tests) 93 | # pass@1: 0.526 94 | ``` 95 | -------------------------------------------------------------------------------- /src/star_align/clean_data.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import random 3 | from dataclasses import dataclass, field 4 | from pathlib import Path 5 | from typing import cast 6 | 7 | from tqdm.auto import tqdm 8 | from transformers import HfArgumentParser 9 | 10 | from star_align.utils import find_code_blocks, read_jsonl, write_jsonl 11 | 12 | 13 | @dataclass(frozen=True) 14 | class Args: 15 | data_files: list[str] 16 | output_file: str 17 | diversify_func_names: bool = field(default=False) 18 | 19 | 20 | def extract_and_concat_function_names(python_content): 21 | """ 22 | Extracts all function names from a given Python content string and concatenates them into a single string. 23 | 24 | Parameters: 25 | - python_content: A string containing the Python code to analyze. 26 | 27 | Returns: 28 | - A string containing all function names defined in the content, concatenated. 
29 | """ 30 | tree = ast.parse(python_content) 31 | function_names = [] 32 | 33 | # Define a node visitor that adds the name of each function definition it visits 34 | class FunctionDefVisitor(ast.NodeVisitor): 35 | def visit_FunctionDef(self, node): 36 | function_names.append(node.name) 37 | # Process the subtree for this node 38 | self.generic_visit(node) 39 | 40 | def visit_AsyncFunctionDef(self, node): 41 | function_names.append(node.name) 42 | self.generic_visit(node) 43 | 44 | # Create a node visitor and walk through the AST 45 | visitor = FunctionDefVisitor() 46 | visitor.visit(tree) 47 | 48 | # Concatenate all function names into a single string 49 | return " ".join(function_names) 50 | 51 | 52 | def main(): 53 | args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) 54 | raw_data: list[dict] = [] 55 | for data_file in args.data_files: 56 | data = read_jsonl(Path(data_file)) 57 | # language = data_file.split("-")[1] 58 | # assert language in ALL_LANGS, f"Unknown language {language}" 59 | # raw_data.extend(dict(lang=language, **d) for d in data) 60 | raw_data.extend(data) 61 | # common keys for all d in data 62 | common_keys = set.intersection(*(set(d.keys()) for d in raw_data)) 63 | raw_data = [{k: d[k] for k in common_keys} for d in raw_data] 64 | print(f"Common keys: {common_keys}") 65 | # counter = defaultdict[str, int](int) 66 | 67 | def mk_key(instruction: str) -> str: 68 | return "".join(instruction.split()) 69 | 70 | random.seed(0) 71 | random.shuffle(raw_data) 72 | 73 | seen_keys = set[str]() 74 | new_data = list[dict]() 75 | for d in tqdm(raw_data): 76 | key_i, key_r = mk_key(d["instruction"]), mk_key(d["response"]) 77 | if key_i in seen_keys or key_r in seen_keys: 78 | continue 79 | if args.diversify_func_names: 80 | code_block = find_code_blocks(d["response"])[0] 81 | try: 82 | fn_names = extract_and_concat_function_names(code_block) 83 | except SyntaxError: 84 | continue 85 | if fn_names in seen_keys: 86 | continue 87 | seen_keys.add(fn_names) 88 | new_data.append(d) 89 | seen_keys.add(key_i) 90 | seen_keys.add(key_r) 91 | 92 | print(f"Chose {len(new_data)} out of {len(raw_data)}") 93 | write_jsonl(Path(args.output_file), new_data) 94 | 95 | 96 | if __name__ == "__main__": 97 | main() 98 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | .vscode 3 | /*.jsonl 4 | /*.json 5 | /*.bak 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | 170 | # Demo file 171 | flagged/ 172 | *.sh 173 | !sanitize.sh 174 | logs/ 175 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SelfCodeAlign: Self-Alignment for Code Generation 2 | 3 |

4 | Paper 5 | 6 | 7 | 8 | 9 |

10 | 11 |

12 | 🧐 About 13 | | ⭐️ StarCoder2-Instruct 14 | | 📝 Citation 15 | 16 | 17 | 18 | 19 | 20 |

21 | 22 | ## About 23 | 24 | **SelfCodeAlign** is the first fully open and transparent pipeline that enhances a code language model without relying on human annotations or distilled data from large, proprietary models. This approach led to the creation of [StarCoder2-Instruct](README-SC2INST.md), a fully transparent, permissively licensed, self-aligned code model that achieves state-of-the-art performance in coding tasks. 25 | 26 | **Authors:** 27 | [Yuxiang Wei](https://yuxiang.cs.illinois.edu), 28 | [Federico Cassano](https://federico.codes/), 29 | [Jiawei Liu](https://jw-liu.xyz), 30 | [Yifeng Ding](https://yifeng-ding.com), 31 | [Naman Jain](https://naman-ntc.github.io), 32 | [Zachary Mueller](https://muellerzr.github.io), 33 | [Harm de Vries](https://www.harmdevries.com), 34 | [Leandro von Werra](https://twitter.com/lvwerra), 35 | [Arjun Guha](https://www.khoury.northeastern.edu/home/arjunguha/main/home/), 36 | [Lingming Zhang](https://lingming.cs.illinois.edu). 37 | 38 | ![self-alignment pipeline](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/SelfCodeAlign.png) 39 | 40 | ## StarCoder2-Instruct 41 | 42 | ![Banner](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/banner.png) 43 | 44 | StarCoder2-Instruct was created with an [earlier version](https://github.com/bigcode-project/selfcodealign/tree/starcoder2-instruct) of SelfCodeAlign. It is the very first entirely self-aligned code Large Language Model (LLM) trained with a fully permissive and transparent pipeline. Our open-source pipeline uses StarCoder2-15B to generate thousands of instruction-response pairs, which are then used to fine-tune StarCoder2-15B itself without any human annotations or distilled data from huge and proprietary LLMs. 45 | 46 | - **Model:** [bigcode/starcoder2-15b-instruct-v0.1](https://huggingface.co/bigcode/starcoder2-15b-instruct-v0.1) 47 | - **Code:** [bigcode-project/starcoder2-self-align](https://github.com/bigcode-project/starcoder2-self-align) 48 | - **Dataset:** [bigcode/self-oss-instruct-sc2-exec-filter-50k](https://huggingface.co/datasets/bigcode/self-oss-instruct-sc2-exec-filter-50k/) 49 | 50 | For more details, check [README-SC2INST.md](README-SC2INST.md).
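For a quick sanity check of the released model, a rough sketch (not taken from this repository) of prompting it with the same `### Instruction`/`### Response` format defined in `src/star_align/prompt_template.py` could look like the following; the example instruction and generation settings are placeholders, and running the 15B model requires a suitably large GPU:

```python
# Rough sketch: query bigcode/starcoder2-15b-instruct-v0.1 with the project's
# "### Instruction / ### Response" format. Settings here are placeholders.
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="bigcode/starcoder2-15b-instruct-v0.1",
    device_map="auto",
)

prompt = (
    "You are an exceptionally intelligent coding assistant that consistently "
    "delivers accurate and reliable responses to user instructions.\n\n"
    "### Instruction\nWrite a Python function that reverses a string.\n\n"
    "### Response\n"
)
output = generator(prompt, max_new_tokens=256, do_sample=False)
print(output[0]["generated_text"])
```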
51 | 52 | ## Citation 53 | 54 | ```bibtex 55 | @article{wei2024selfcodealign, 56 | title={SelfCodeAlign: Self-Alignment for Code Generation}, 57 | author={Yuxiang Wei and Federico Cassano and Jiawei Liu and Yifeng Ding and Naman Jain and Zachary Mueller and Harm de Vries and Leandro von Werra and Arjun Guha and Lingming Zhang}, 58 | year={2024}, 59 | journal={arXiv preprint arXiv:2410.24198} 60 | } 61 | ``` 62 | -------------------------------------------------------------------------------- /seed_gathering/benchmark_data.py: -------------------------------------------------------------------------------- 1 | """data to filter out of the dataset""" 2 | import json 3 | import itertools 4 | from pathlib import Path 5 | 6 | from datasets import load_dataset 7 | 8 | 9 | TEST_IDS = list(range(11, 511)) 10 | 11 | # HumanEval solutions that are considered simple/generic enough to be kept in the training dataset 12 | HUMAN_EVAL_STRINGS_OK = [ 13 | 'return x + y', 'return len(string)', 'return n**2', 'return ''.join(strings)'] 14 | 15 | DS_1000_PATH = Path("/data/ds-1000/ds1000_data/") 16 | 17 | 18 | def extract_ds_1000_prompt(prompt: str): 19 | if "SOLUTION START" in prompt: 20 | assert prompt.count("SOLUTION START") == 1 21 | return prompt.split("SOLUTION START")[0] 22 | elif "BEGIN SOLUTION" in prompt: 23 | assert prompt.count("BEGIN SOLUTION") == 1 24 | return prompt.split("BEGIN SOLUTION")[0] 25 | else: 26 | raise ValueError() 27 | 28 | 29 | def load_ds_1000(): 30 | data = [] 31 | for prompt_file in DS_1000_PATH.glob("*/Insertion/q*/prompt.txt"): 32 | with open(prompt_file) as f: 33 | data.append(extract_ds_1000_prompt(f.read())) 34 | return data 35 | 36 | 37 | def load_mbpp(): 38 | dataset = load_dataset("mbpp", "sanitized", split="train") 39 | return dataset 40 | 41 | 42 | def mbpp_docstrings(): 43 | data = load_mbpp() 44 | return [sample["prompt"] for sample in data] 45 | 46 | 47 | def mbpp_solutions(): 48 | data = load_mbpp() 49 | return [sample["code"] for sample in data] 50 | 51 | 52 | def extract_docstring(prompt: str) -> str: 53 | if '"""' in prompt: 54 | if prompt.count('"""') == 2: 55 | return prompt.split('"""')[1].strip() 56 | elif prompt.count('"""') == 4: 57 | return prompt.split('"""')[3].strip() 58 | else: 59 | raise ValueError() 60 | elif '\'\'\'' in prompt: 61 | assert prompt.count('\'\'\'') == 2 62 | return prompt.split('\'\'\'')[1].strip() 63 | else: 64 | raise ValueError() 65 | 66 | 67 | def human_eval_docstrings(): 68 | ds = load_dataset("openai_humaneval", split="test") 69 | docstrings = [extract_docstring(v['prompt']) for v in ds] 70 | return docstrings 71 | 72 | 73 | def apps_solutions(): 74 | """ 75 | Solutions column contains a list of strings 76 | """ 77 | ds = load_dataset("codeparrot/apps", split="test") 78 | solutions = [sample["solutions"] 79 | for sample in ds if len(sample["solutions"]) > 0] 80 | res = itertools.chain.from_iterable( 81 | json.loads(sample) for sample in solutions) 82 | return list(res) 83 | 84 | 85 | def multipl_e_docstrings(): 86 | languages = [ 87 | "cpp", "cs", "d", "go", "java", "jl", "js", "lua", "php", "pl", "py", "r", 88 | "rb", "rkt", "rs", "scala", "sh", "swift", "ts" 89 | ] 90 | # languages = ["py", "java", "js"] 91 | src_datas = ["humaneval", "mbpp"] 92 | variations = ["", "-remove"] 93 | data = [] 94 | for lang in languages: 95 | for src_data in src_datas: 96 | for variation in variations: 97 | if src_data == "mbpp" and variation == "-remove": 98 | continue 99 | ds = load_dataset( 100 | "nuprl/MultiPL-E", 
f"{src_data}-{lang}{variation}", split="test") 101 | data += [sample["prompt"].strip() for sample in ds] 102 | return data 103 | 104 | 105 | def load_dataset_column(dataset: str, column: str, split: str, name=None): 106 | ds = load_dataset(dataset, split=split, name=name) 107 | res = [sample[column].strip() for sample in ds] 108 | # Only return non-empty strings 109 | return [sample for sample in res if len(sample) > 0] 110 | 111 | 112 | def filter_out(): 113 | FILTER_OUT = { 114 | "mbpp_docstrings": mbpp_docstrings(), 115 | "mbpp_solutions": mbpp_solutions(), 116 | "human_eval_docstrings": human_eval_docstrings(), 117 | "human_eval_solutions": [ 118 | s for s in load_dataset_column("openai_humaneval", "canonical_solution", "test") 119 | if s not in HUMAN_EVAL_STRINGS_OK 120 | ], 121 | } 122 | 123 | for benchmark, values in FILTER_OUT.items(): 124 | print(f"num strings from {benchmark}: {len(values)}") 125 | 126 | return FILTER_OUT 127 | -------------------------------------------------------------------------------- /seed_gathering/generate_from_the_stack.py: -------------------------------------------------------------------------------- 1 | from tree_sitter_parser import LANGUAGE, make_parser, node_to_string 2 | import datasets 3 | import os 4 | import signal 5 | from multiprocessing import Pool 6 | 7 | 8 | TOPLEVEL_DOCSTRING_QUERY = LANGUAGE.query(""" 9 | ( 10 | (function_definition 11 | name: (identifier) 12 | body: (block . 13 | (expression_statement 14 | (string 15 | (string_start) @docstring.start 16 | (string_content) 17 | (string_end) @docstring.end)))) @function.def 18 | (#eq? @docstring.start "\\\"\\\"\\\"") 19 | (#eq? @docstring.end "\\\"\\\"\\\"") 20 | ) 21 | """) 22 | 23 | 24 | def get_fns_with_docstrings(src, tree): 25 | captures = TOPLEVEL_DOCSTRING_QUERY.captures(tree.root_node) 26 | res = [] 27 | for capture in captures: 28 | node, ty = capture 29 | if ty != "function.def": 30 | continue 31 | # if the starting col is not 0, then it's not a top-level fn 32 | _, col = node.start_point 33 | if col != 0: 34 | continue 35 | res.append(node_to_string(src, node)) 36 | return res 37 | 38 | 39 | def parse_ex(parser, ex): 40 | ex = ex["content"] 41 | try: 42 | buf = bytes(ex, "utf8") 43 | tree = parser.parse(buf) 44 | return get_fns_with_docstrings(buf, tree) 45 | except: 46 | return [] 47 | 48 | 49 | # if one parser segfaults, we can just make a new one and other parsers will still be fine 50 | # WE LOVE TREE SITTER! 
51 | PARSERS = None 52 | 53 | 54 | def process_chunk(idx_and_chunk): 55 | assert PARSERS is not None 56 | idx, chunk = idx_and_chunk 57 | parser = PARSERS[idx] 58 | chunk_new_funs = set() 59 | for ex in chunk: 60 | chunk_new_funs.update(parse_ex(parser, ex)) 61 | return chunk_new_funs 62 | 63 | 64 | def main(args): 65 | global PARSERS 66 | ds = datasets.load_dataset( 67 | args.dataset, 68 | data_dir=args.data_dir, 69 | split="train", 70 | ) 71 | funs = set() 72 | PARSERS = [make_parser() for _ in range(args.num_workers)] 73 | total_len = len(ds) 74 | CHUNK_SIZE = 1000 * args.num_workers 75 | 76 | print(f"Total length: {total_len}") 77 | print(f"Chunk size: {CHUNK_SIZE}") 78 | 79 | chunk = [] 80 | p = Pool(args.num_workers) 81 | for i, ex in enumerate(ds): 82 | if i % (total_len // 100) == 0: 83 | print(f"{i}/{total_len}") 84 | try: 85 | chunk.append(ex) 86 | if len(chunk) == CHUNK_SIZE or i == total_len - 1: 87 | print(f"Processing chunk {i // CHUNK_SIZE}") 88 | # divide the chunk into NUM_WORKERS chunks 89 | subchunk_size = len(chunk) // args.num_workers 90 | subchunks = [chunk[i:i + subchunk_size] 91 | for i in range(0, len(chunk), subchunk_size)] 92 | new_funs_iter = p.imap( 93 | process_chunk, [(i, subchunk) for i, subchunk in enumerate(subchunks)]) 94 | print("Getting new functions") 95 | len_before = len(funs) 96 | while True: 97 | try: 98 | def timeout_handler(_, __): 99 | raise KeyboardInterrupt # it's fineeeeeee 100 | signal.signal(signal.SIGALRM, timeout_handler) 101 | signal.alarm(60) 102 | funs.update(next(new_funs_iter)) 103 | signal.alarm(0) 104 | except KeyboardInterrupt: 105 | signal.alarm(0) 106 | print("Keyboard interrupt. Terminating pool") 107 | p.terminate() 108 | p = Pool(args.num_workers) 109 | break 110 | except StopIteration: 111 | break 112 | except Exception as e: 113 | print(e) 114 | 115 | signal.alarm(0) 116 | 117 | PARSERS = [make_parser() for _ in range(args.num_workers)] 118 | 119 | print( 120 | f"Done processing chunk {i // CHUNK_SIZE}. Got {len(funs) - len_before} new functions") 121 | 122 | chunk = [] 123 | except Exception as e: 124 | print(e) 125 | chunk = [] 126 | 127 | if i == total_len - 1: 128 | break 129 | 130 | p.close() 131 | 132 | new_ds_dict = { 133 | "content": list(funs), 134 | "id": list(range(len(funs))) 135 | } 136 | 137 | new_ds = datasets.Dataset.from_dict(new_ds_dict) 138 | new_ds.push_to_hub(args.push, private=True) 139 | 140 | 141 | if __name__ == "__main__": 142 | import argparse 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--num_workers", type=int, default=os.cpu_count()) 145 | parser.add_argument("--dataset", type=str, 146 | default="bigcode/the-stack-dedup") 147 | parser.add_argument("--data_dir", type=str, default="data/python") 148 | parser.add_argument("--push", type=str, required=True) 149 | args = parser.parse_args() 150 | main(args) 151 | -------------------------------------------------------------------------------- /evaluation/text2code.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | if __name__ == "__main__": 4 | # Deprecate warning 5 | warnings.warn( 6 | "This module is deprecated. Use `evaluation.text2code_vllm` instead.", 7 | DeprecationWarning, 8 | ) 9 | # Press y to continue 10 | if input("Do you want to continue? 
[y/N]: ").lower() != "y": 11 | exit() 12 | 13 | import itertools 14 | from dataclasses import dataclass 15 | from pathlib import Path 16 | from typing import Literal, TypedDict, cast 17 | from evalplus.data import get_human_eval_plus, get_mbpp_plus, write_jsonl 18 | from tqdm.auto import tqdm 19 | from transformers import HfArgumentParser 20 | 21 | from star_align.llm_wrapper import GenerationConfig, get_model_context 22 | from star_align.prompt_template import SC2_INSTRUCT_PROMPT as PROMPT_TEMPLATE 23 | from star_align.utils import chunked 24 | 25 | 26 | class Text2CodeProblem(TypedDict): 27 | id: str 28 | instruction: str 29 | response_prefix: str 30 | 31 | 32 | def get_mbpp_raw_problems() -> list[dict]: 33 | problems = get_mbpp_plus() 34 | return list(problems.values()) 35 | 36 | 37 | def get_humaneval_raw_problems() -> list[dict]: 38 | problems = get_human_eval_plus() 39 | return list(problems.values()) 40 | 41 | 42 | def map_mbpp_problem(p: dict) -> Text2CodeProblem: 43 | id = p["task_id"] 44 | prompt = p["prompt"] 45 | start_index = prompt.index('"""') 46 | end_index = prompt.rindex('"""') 47 | prompt = prompt[start_index + 3 : end_index] 48 | assert_index = prompt.index("assert") 49 | instruction = prompt[:assert_index].strip() 50 | if not instruction.endswith("."): 51 | instruction += "." 52 | assertion = prompt[assert_index:].strip() 53 | instruction = f"""{instruction} 54 | 55 | Your code should pass the following assertion: 56 | ```python 57 | {assertion} 58 | ```""" 59 | prefix = "" if PROMPT_TEMPLATE.endswith("\n") else "\n" 60 | response_prefix = f"""{prefix}```python""" 61 | return Text2CodeProblem( 62 | id=str(id), instruction=instruction, response_prefix=response_prefix 63 | ) 64 | 65 | 66 | def map_humaneval_problem(p: dict) -> Text2CodeProblem: 67 | id = p["task_id"] 68 | prompt = p["prompt"] 69 | prompt = prompt.strip() 70 | prompt_header = "Write a Python function to solve the given task:" 71 | instruction = f"""{prompt_header} 72 | ```python 73 | {prompt} 74 | ```""" 75 | prefix = "" if PROMPT_TEMPLATE.endswith("\n") else "\n" 76 | prefix_template = "```python\n{prompt}" 77 | response_prefix = prefix + ( 78 | prefix_template.replace("{prompt}", prompt) 79 | if "{prompt}" in prefix_template 80 | else prefix_template 81 | ) 82 | return Text2CodeProblem( 83 | id=id, instruction=instruction, response_prefix=response_prefix 84 | ) 85 | 86 | 87 | @dataclass(frozen=True) 88 | class Args: 89 | model_key: str 90 | dataset: Literal["humaneval", "mbpp"] 91 | save_path: str 92 | 93 | n_batches: int 94 | n_problems_per_batch: int 95 | n_samples_per_problem: int 96 | 97 | model_name_or_path: str | None = None 98 | 99 | 100 | def main(): 101 | parser = HfArgumentParser((Args, GenerationConfig)) 102 | args, generation_config = cast( 103 | tuple[Args, GenerationConfig], 104 | parser.parse_args_into_dataclasses(), 105 | ) 106 | raw_problem_fn, map_problem_fn = ( 107 | (get_humaneval_raw_problems, map_humaneval_problem) 108 | if args.dataset == "humaneval" 109 | else (get_mbpp_raw_problems, map_mbpp_problem) 110 | ) 111 | raw_problems = raw_problem_fn() 112 | problems = list(map(map_problem_fn, raw_problems)) 113 | 114 | state = get_model_context(args.model_key, args.model_name_or_path) 115 | 116 | problems_chunked = list(chunked(list(problems), args.n_problems_per_batch)) 117 | iter = itertools.product(problems_chunked, range(args.n_batches)) 118 | n_total = len(problems_chunked) * args.n_batches 119 | 120 | Path(args.save_path).write_text("") 121 | for problems, batch_idx in 
tqdm(iter, total=n_total): 122 | task_ids = [problem["id"] for problem in problems] 123 | prompts = [ 124 | # TODO: make it generic for all models 125 | PROMPT_TEMPLATE.format( 126 | instruction=problem["instruction"], response=problem["response_prefix"] 127 | ) 128 | for problem in problems 129 | ] 130 | print("PROMPT") 131 | print(prompts[-1]) 132 | all_prompts = prompts * args.n_samples_per_problem 133 | all_task_ids = task_ids * args.n_samples_per_problem 134 | response = state.complete(generation_config, all_prompts, stop_tokens=["\n```"]) 135 | completions = response.decoded_outputs 136 | assert len(problems) <= args.n_problems_per_batch 137 | assert len(completions) == len(problems) * args.n_samples_per_problem 138 | print("COMPLETION") 139 | print(completions[-1]) 140 | samples = [ 141 | dict( 142 | task_id=task_id, 143 | completion=completion[ 144 | : ( 145 | index 146 | if (index := completion.find("```")) != -1 147 | else len(completion) 148 | ) 149 | ], 150 | ) 151 | for task_id, completion in zip(all_task_ids, completions) 152 | ] 153 | write_jsonl(args.save_path, samples, append=True) 154 | 155 | 156 | if __name__ == "__main__": 157 | main() 158 | -------------------------------------------------------------------------------- /seed_gathering/high_quality_subset.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import subprocess 3 | import tempfile 4 | import signal 5 | import hashlib 6 | import os 7 | import argparse 8 | from typing import List, Dict 9 | from tqdm import tqdm 10 | 11 | from tree_sitter_parser import LANGUAGE, global_parser 12 | 13 | RETURN_QUERY = LANGUAGE.query(""" 14 | (return_statement) @return 15 | """) 16 | 17 | 18 | def does_have_return(src): 19 | tree = global_parser.parse(bytes(src, "utf8")) 20 | root = tree.root_node 21 | captures = RETURN_QUERY.captures(root) 22 | for node, _ in captures: 23 | # if it doesn't have an argument, it's not a return with a value 24 | if len(node.children) <= 1: # includes "return" itself 25 | continue 26 | else: 27 | return True 28 | 29 | return False 30 | 31 | 32 | if __name__ == "__main__": 33 | print(does_have_return("def foo():\n return")) 34 | 35 | 36 | # runs pyright in the given directory, returns stdout 37 | # then, it logs the number of errors for each file 38 | def run_pyright(d): 39 | try: 40 | outs = subprocess.run( 41 | ["pyright", "*"], 42 | cwd=d, 43 | capture_output=True, 44 | timeout=120, 45 | text=True, 46 | ).stdout 47 | except Exception as e: 48 | print(e) 49 | return None 50 | 51 | cur_file = "" 52 | filemap = {} 53 | lines = outs.split("\n") 54 | for i, line in enumerate(lines): 55 | if i == len(lines) - 2: 56 | break 57 | 58 | if line.startswith(" "): 59 | if "- error:" in line: 60 | filemap[cur_file] += 1 61 | else: 62 | file = line.split("/")[-1] 63 | filemap[file] = 0 64 | cur_file = file 65 | 66 | return filemap 67 | 68 | 69 | def typecheck_batch(files: List[str]) -> Dict[str, str]: 70 | # Create a temporary directory using the tempfile module 71 | filemap: Dict[str, str] = {} 72 | with tempfile.TemporaryDirectory() as tempdir: 73 | for contents in files: 74 | hash_object = hashlib.sha1(bytes(contents, "utf8")) 75 | hex_dig = hash_object.hexdigest() 76 | filemap[hex_dig] = contents 77 | name = os.path.join(tempdir, hex_dig + ".py") 78 | with open(name, "w") as f: 79 | f.write(contents) 80 | 81 | # Run pyright in the temporary directory 82 | typecheck_map = run_pyright(tempdir) 83 | if typecheck_map is None: 84 | return {} 85 | 86 | for 
contents, errors in typecheck_map.items(): 87 | no_py = contents.replace(".py", "") 88 | if errors == 0: 89 | continue 90 | 91 | if no_py in filemap: 92 | del filemap[no_py] 93 | 94 | print(f"Pass rate: {len(filemap)}/{len(files)}") 95 | 96 | return filemap 97 | 98 | 99 | def infer_imports(code: str) -> str: 100 | import autoimport 101 | 102 | try: 103 | def handler(signum, frame): 104 | raise Exception("Timeout") 105 | signal.signal(signal.SIGALRM, handler) 106 | signal.alarm(10) 107 | inferred = autoimport.fix_code(code) 108 | signal.alarm(0) 109 | return inferred 110 | except Exception as e: 111 | signal.alarm(0) 112 | print(f"Error while inferring imports: {e}") 113 | return code 114 | 115 | 116 | def main(args): 117 | ds = datasets.load_dataset(args.dataset, 118 | data_dir="data", split="train") 119 | 120 | print("Filtering to only functions with return statements") 121 | ds = ds.filter(lambda ex: does_have_return( 122 | ex["content"]), num_proc=os.cpu_count()) 123 | 124 | if args.infer_imports: 125 | print("Inferring imports for functions") 126 | ds = ds.map(lambda ex: {"content": infer_imports( 127 | ex["content"])}, num_proc=os.cpu_count()) 128 | 129 | batch = [] 130 | max_i = len(ds) - 1 131 | 132 | new_ds = { 133 | "content": [], 134 | "sha1": [], 135 | "id": [], 136 | } 137 | 138 | e_id = 0 139 | 140 | for i, ex in enumerate(tqdm(ds, total=len(ds))): 141 | try: 142 | code = ex["content"] 143 | 144 | batch.append(code) 145 | 146 | if len(batch) == args.batch_size or i == max_i: 147 | filemap = typecheck_batch(batch) 148 | for sha1, contents in filemap.items(): 149 | new_ds["content"].append(contents) 150 | new_ds["sha1"].append(sha1) 151 | new_ds["id"].append(e_id) 152 | e_id += 1 153 | 154 | batch = [] 155 | except Exception as e: 156 | print(f"There was an error: {e}") 157 | continue 158 | 159 | new_ds_hf = datasets.Dataset.from_dict(new_ds) 160 | new_ds_hf.push_to_hub(args.push, private=True) 161 | 162 | 163 | if __name__ == "__main__": 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument("--dataset", type=str, 166 | help="Points to dataset of python functions with docstrings. Columns: 'content'", 167 | required=True) 168 | parser.add_argument( 169 | "--push", type=str, required=True, help="Push to this dataset to which repo") 170 | parser.add_argument( 171 | "--infer-imports", action="store_true", help="Infer imports for functions") 172 | parser.add_argument( 173 | "--batch-size", type=int, default=250, help="Batch size for typechecking") 174 | args = parser.parse_args() 175 | main(args) 176 | -------------------------------------------------------------------------------- /src/star_align/decontamination/benchmark_data.py: -------------------------------------------------------------------------------- 1 | """Migrated from: https://github.com/bigcode-project/bigcode-dataset. 
License: Apache 2.0""" 2 | 3 | """data to filter out of the dataset""" 4 | import itertools 5 | import json 6 | import os 7 | from pathlib import Path 8 | 9 | from datasets import load_dataset 10 | 11 | # HumanEval solutions that are considered simple/generic enough to be kept in the training dataset 12 | HUMAN_EVAL_STRINGS_OK = [ 13 | "return x + y", 14 | "return len(string)", 15 | "return n**2", 16 | "return " ".join(strings)", 17 | ] 18 | 19 | 20 | def extract_ds_1000_prompt(prompt: str): 21 | if "SOLUTION START" in prompt: 22 | assert prompt.count("SOLUTION START") == 1 23 | return prompt.split("SOLUTION START")[0] 24 | elif "BEGIN SOLUTION" in prompt: 25 | assert prompt.count("BEGIN SOLUTION") == 1 26 | return prompt.split("BEGIN SOLUTION")[0] 27 | else: 28 | raise ValueError() 29 | 30 | 31 | def load_ds_1000(): 32 | DS1000_PATH_NAME = os.getenv("DS1000_PATH", None) 33 | assert ( 34 | DS1000_PATH_NAME is not None 35 | ), "Please set the environment variable DS1000_PATH to the path of `ds1000_data`" 36 | DS1000_PATH = Path(DS1000_PATH_NAME) # type: ignore 37 | data: dict = {} 38 | for prompt_file in DS1000_PATH.glob("*/Insertion/q*/prompt.txt"): 39 | with open(prompt_file) as f: 40 | data[extract_ds_1000_prompt(f.read())] = prompt_file.as_posix() 41 | return data 42 | 43 | 44 | def load_mbpp(): 45 | MBPP_PATH_NAME = os.getenv("MBPP_PATH", None) 46 | assert ( 47 | MBPP_PATH_NAME is not None 48 | ), "Please set the environment variable MBPP_PATH to the path of `mbpp.jsonl`" 49 | MBPP_PATH = Path(MBPP_PATH_NAME) 50 | TEST_IDS = list(range(11, 511)) 51 | data = [] 52 | with open(MBPP_PATH) as f: 53 | for line in f: 54 | data.append(json.loads(line)) 55 | 56 | data = [sample for sample in data if sample["task_id"] in TEST_IDS] 57 | 58 | assert len(data) == 500 59 | 60 | # Checksum / version issues here 61 | # dataset = load_dataset("mbpp", split="test") 62 | return data 63 | 64 | 65 | def mbpp_docstrings(): 66 | data = load_mbpp() 67 | return {sample["text"]: str(sample["task_id"]) for sample in data} 68 | 69 | 70 | def mbpp_solutions(): 71 | data = load_mbpp() 72 | return {sample["code"]: str(sample["task_id"]) for sample in data} 73 | 74 | 75 | def extract_docstring(prompt: str) -> str: 76 | if '"""' in prompt: 77 | if prompt.count('"""') == 2: 78 | return prompt.split('"""')[1].strip() 79 | elif prompt.count('"""') == 4: 80 | return prompt.split('"""')[3].strip() 81 | else: 82 | raise ValueError() 83 | elif "'''" in prompt: 84 | assert prompt.count("'''") == 2 85 | return prompt.split("'''")[1].strip() 86 | else: 87 | raise ValueError() 88 | 89 | 90 | def human_eval_docstrings(): 91 | ds = load_dataset("openai_humaneval", split="test") 92 | docstrings = {extract_docstring(v["prompt"]): str(v["task_id"]) for v in ds} 93 | return docstrings 94 | 95 | 96 | def apps_solutions(): 97 | """ 98 | Solutions column contains a list of strings 99 | """ 100 | ds = load_dataset("codeparrot/apps", split="test") 101 | solutions = [sample["solutions"] for sample in ds if len(sample["solutions"]) > 0] 102 | res = itertools.chain.from_iterable(json.loads(sample) for sample in solutions) 103 | return list(res) 104 | 105 | 106 | def multipl_e_docstrings(): 107 | languages = [ 108 | "cpp", 109 | "cs", 110 | "d", 111 | "go", 112 | "java", 113 | "jl", 114 | "js", 115 | "lua", 116 | "php", 117 | "pl", 118 | "py", 119 | "r", 120 | "rb", 121 | "rkt", 122 | "rs", 123 | "scala", 124 | "sh", 125 | "swift", 126 | "ts", 127 | ] 128 | # languages = ["py", "java", "js"] 129 | src_datas = ["humaneval", "mbpp"] 130 | 
variations = ["", "-remove"] 131 | data = [] 132 | for lang in languages: 133 | for src_data in src_datas: 134 | for variation in variations: 135 | if src_data == "mbpp" and variation == "-remove": 136 | continue 137 | ds = load_dataset( 138 | "nuprl/MultiPL-E", f"{src_data}-{lang}{variation}", split="test" 139 | ) 140 | data += [sample["prompt"].strip() for sample in ds] 141 | return data 142 | 143 | 144 | def load_dataset_column(dataset: str, column: str, split: str, name=None): 145 | ds = load_dataset(dataset, split=split, name=name) 146 | # res = [sample[column].strip() for sample in ds] 147 | # Only return non-empty strings 148 | return { 149 | sample_col_stripped: str( 150 | sample["task_id"] if "task_id" in sample else f"{dataset}/{idx}" 151 | ) 152 | for idx, sample in enumerate(ds) 153 | if len(sample_col_stripped := sample[column].strip()) > 0 154 | } 155 | 156 | 157 | LAZY_FILTER_OUT = { 158 | "human_eval_docstrings": lambda: human_eval_docstrings(), 159 | "human_eval_solutions": lambda: { 160 | s: v 161 | for s, v in load_dataset_column( 162 | "openai_humaneval", "canonical_solution", "test" 163 | ).items() 164 | if s not in HUMAN_EVAL_STRINGS_OK 165 | }, 166 | # "apps_docstrings": lambda: load_dataset_column( 167 | # "codeparrot/apps", "question", "test" 168 | # ), 169 | # 115212 examples to filter-out in apps-solutions, which would take way too much time without any hashing trick 170 | # "apps_solutions": apps_solutions(), 171 | # MultiPL-E samples are from HumanEval and MBPP: we are already looking for them 172 | # "multipl-e_docstrings": multipl_e_docstrings(), 173 | # There is no solution provided with multipl-e 174 | "gsm8k_questions": lambda: load_dataset_column("gsm8k", "question", "test", "main"), 175 | "ds_1000_prompts": lambda: load_ds_1000(), 176 | "mbpp_docstrings": lambda: mbpp_docstrings(), 177 | "mbpp_solutions": lambda: mbpp_solutions(), 178 | } 179 | 180 | IGNORED = os.getenv("IGNORED", "").split(":") 181 | print("Ignoring:", IGNORED) 182 | for ignored in IGNORED: 183 | if ignored != "" and ignored in LAZY_FILTER_OUT: 184 | del LAZY_FILTER_OUT[ignored] 185 | FILTER_OUT = {k: v() for k, v in LAZY_FILTER_OUT.items()} 186 | 187 | 188 | for benchmark, values in FILTER_OUT.items(): 189 | print(f"num strings from {benchmark}: {len(values)}") 190 | -------------------------------------------------------------------------------- /evaluation/text2code_vllm.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from pathlib import Path 4 | from typing import Literal, TypedDict, cast 5 | from evalplus.data import get_human_eval_plus, get_mbpp_plus, write_jsonl 6 | 7 | from evoeval.data import get_evo_eval 8 | from transformers import HfArgumentParser 9 | 10 | from star_align.utils import infer_prompt_template, is_base_model 11 | 12 | from vllm import LLM, SamplingParams 13 | 14 | 15 | class Text2CodeProblem(TypedDict): 16 | id: str 17 | prompt: str 18 | instruction: str 19 | response_prefix: str 20 | 21 | 22 | # MBPP_INSTRUCTION = """{nl_description} Your code should satisfy the following assertion: 23 | # ```python 24 | # {assertions} 25 | # ``` 26 | # Enclose your solution in ```python and ```""" 27 | 28 | 29 | def get_mbpp_raw_problems() -> list[dict]: 30 | problems = get_mbpp_plus() 31 | return list(problems.values()) 32 | 33 | 34 | def get_humaneval_raw_problems() -> list[dict]: 35 | problems = get_human_eval_plus() 36 | return list(problems.values()) 37 | 38 | 39 | def 
get_evoeval_raw_problems(dataset: str): 40 | def get_raw_problems() -> list[dict]: 41 | problems = get_evo_eval(dataset) 42 | return list(problems.values()) 43 | 44 | return get_raw_problems 45 | 46 | 47 | def map_mbpp_problem(p: dict) -> Text2CodeProblem: 48 | id = p["task_id"] 49 | prompt = p["prompt"] 50 | start_index = prompt.index('"""') 51 | end_index = prompt.rindex('"""') 52 | prompt = prompt[start_index + 3 : end_index] 53 | assert_index = prompt.index("assert") 54 | instruction = prompt[:assert_index].strip() 55 | if not instruction.endswith("."): 56 | instruction += "." 57 | assertion = prompt[assert_index:].strip() 58 | instruction = f"""{instruction} 59 | 60 | ```python 61 | {assertion} 62 | ```""" 63 | prefix = "" 64 | response_prefix = f"""{prefix}```python""" 65 | return Text2CodeProblem( 66 | id=str(id), 67 | prompt=prompt, 68 | instruction=instruction, 69 | response_prefix=response_prefix, 70 | ) 71 | 72 | 73 | def map_humaneval_problem(p: dict) -> Text2CodeProblem: 74 | id = p["task_id"] 75 | prompt = p["prompt"] 76 | prompt = prompt.strip() 77 | # try: 78 | # docstring_index = prompt.index('"""') 79 | # except ValueError: 80 | # docstring_index = prompt.index("'''") 81 | # signature = prompt[:docstring_index].strip() 82 | # Instruction 83 | # instruction = f"""Complete the implementation of the following function: 84 | prompt_header = os.getenv( 85 | "PROMPT_HEADER", "Write a Python function to solve the following task:" 86 | ) 87 | instruction = f"""{prompt_header} 88 | ```python 89 | {prompt} 90 | ```""" 91 | prefix = "" 92 | prefix_template = os.getenv("PREFIX_TEMPLATE", "```python") 93 | response_prefix = prefix + ( 94 | prefix_template.replace("{prompt}", prompt) 95 | if "{prompt}" in prefix_template 96 | else prefix_template 97 | ) 98 | # response_prefix = f"""{prefix}```python 99 | # {prompt}""" 100 | return Text2CodeProblem( 101 | id=id, 102 | prompt=prompt, 103 | instruction=instruction, 104 | response_prefix=response_prefix, 105 | ) 106 | 107 | 108 | @dataclass(frozen=True) 109 | class Args: 110 | model_key: str 111 | dataset: Literal[ 112 | "humaneval", 113 | "mbpp", 114 | "EvoEval_difficult", 115 | "EvoEval_creative", 116 | "EvoEval_subtle", 117 | "EvoEval_combine", 118 | "EvoEval_tool_use", 119 | "EvoEval_verbose", 120 | "EvoEval_concise", 121 | ] 122 | save_path: str 123 | n_samples_per_problem: int = field(default=1) 124 | max_new_tokens: int = field(default=1024) 125 | top_p: float = field(default=1.0) 126 | temperature: float = field(default=0.0) 127 | model_name_or_path: str | None = None 128 | 129 | 130 | def main(): 131 | args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) 132 | raw_problem_fn, map_problem_fn = ( 133 | (get_evoeval_raw_problems(args.dataset), map_humaneval_problem) 134 | if args.dataset.startswith("EvoEval_") 135 | else ( 136 | (get_humaneval_raw_problems, map_humaneval_problem) 137 | if args.dataset == "humaneval" 138 | else (get_mbpp_raw_problems, map_mbpp_problem) 139 | ) 140 | ) 141 | raw_problems = raw_problem_fn() 142 | problems = list(map(map_problem_fn, raw_problems)) 143 | 144 | engine = LLM( 145 | tokenizer=args.model_key, model=args.model_name_or_path or args.model_key 146 | ) 147 | 148 | base_model_prompt = is_base_model(args.model_key) 149 | 150 | stop: str | list[str] = ( 151 | "\n```\n" 152 | if not base_model_prompt 153 | else ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert ", "\n# "] 154 | ) 155 | sampling_params = SamplingParams( 156 | n=args.n_samples_per_problem, 157 | 
temperature=args.temperature, 158 | max_tokens=args.max_new_tokens, 159 | top_k=-1, 160 | top_p=args.top_p, 161 | stop=stop, 162 | ) 163 | 164 | if base_model_prompt: 165 | print("Base model") 166 | else: 167 | prompt_template = infer_prompt_template( 168 | os.getenv("TOKENIZER") or args.model_name_or_path or args.model_key 169 | ) 170 | # prompt_template = PROMPT_TEMPLATE 171 | print("Using:", prompt_template) 172 | 173 | prompts: list[str] = [] 174 | for problem in problems: 175 | if not base_model_prompt: 176 | prompt = prompt_template.format( 177 | instruction=problem["instruction"], response=problem["response_prefix"] 178 | ) 179 | else: 180 | prompt = problem["prompt"] 181 | prompts.append(prompt) 182 | 183 | results = engine.generate(prompts, sampling_params) 184 | Path(args.save_path).write_text("") 185 | 186 | step = 20 187 | print_or_not = [idx == 0 or idx % step == 0 for idx in range(len(problems))] 188 | 189 | def sanitize(output: str) -> str: 190 | if not base_model_prompt: 191 | return output.split("```python")[-1].split("```")[0] 192 | for s in stop: 193 | output = output.rsplit(s, 1)[0] 194 | return output 195 | 196 | for problem, prompt, result, print_debug in zip( 197 | problems, prompts, results, print_or_not 198 | ): 199 | if print_debug: 200 | print("[Example Prompt]") 201 | print(prompt) 202 | print("[Example Completion]") 203 | print(result.outputs[0].text) 204 | samples = [ 205 | dict( 206 | task_id=problem["id"], 207 | completion=sanitize(output.text), 208 | ) 209 | for output in result.outputs 210 | ] 211 | write_jsonl(args.save_path, samples, append=True) 212 | 213 | 214 | if __name__ == "__main__": 215 | main() 216 | -------------------------------------------------------------------------------- /src/star_align/utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import hashlib 3 | import json 4 | import os 5 | import time 6 | from pathlib import Path 7 | from typing import Any, Iterable, Literal, Mapping, Sequence, TypeVar 8 | 9 | import openai 10 | import tenacity 11 | import tiktoken 12 | 13 | N_CORES = 1 if (count := os.cpu_count()) is None or count == 0 else count // 2 14 | 15 | 16 | def read_jsonl(path: str | Path) -> list[Any]: 17 | """Read lines of JSON from a file (including '\n').""" 18 | with Path(path).open("r") as f: 19 | return [json.loads(line) for line in f] 20 | 21 | 22 | def write_jsonl(path: str | Path, data: Sequence[Mapping], mode: str = "w"): 23 | # cannot use `dict` here as it is invariant 24 | with Path(path).open(mode) as f: 25 | for item in data: 26 | f.write(json.dumps(item) + "\n") 27 | 28 | 29 | _T = TypeVar("_T") 30 | 31 | 32 | def chunked(seq: Sequence[_T], n: int) -> Iterable[Sequence[_T]]: 33 | """Yield successive n-sized chunks from seq.""" 34 | return (seq[i : i + n] for i in range(0, len(seq), n)) 35 | 36 | 37 | def retry(errors: Any, max_attempts: int = 5): 38 | return tenacity.retry( 39 | retry=tenacity.retry_if_exception_type(errors), 40 | wait=tenacity.wait_exponential(multiplier=1, min=5, max=20), 41 | stop=tenacity.stop_after_attempt(max_attempts), 42 | before_sleep=print, 43 | ) 44 | 45 | 46 | ERRORS = ( 47 | openai.RateLimitError, 48 | openai.APIError, 49 | openai.APIConnectionError, 50 | openai.InternalServerError, 51 | ) 52 | 53 | 54 | class OpenAIClient: 55 | def __init__(self): 56 | self.client = openai.OpenAI() 57 | self.async_client = openai.AsyncClient() 58 | 59 | @retry(ERRORS) 60 | def chat_completions_with_backoff(self, *args, **kwargs): 61 | 
return self.client.chat.completions.create(*args, **kwargs) 62 | 63 | @retry(ERRORS) 64 | def completions_with_backoff(self, *args, **kwargs): 65 | return self.client.completions.create(*args, **kwargs) 66 | 67 | @retry(ERRORS) 68 | async def chat_completions_with_backoff_async(self, *args, **kwargs): 69 | return await self.async_client.chat.completions.create(*args, **kwargs) 70 | 71 | @retry(ERRORS) 72 | async def completions_with_backoff_async(self, *args, **kwargs): 73 | return await self.async_client.completions.create(*args, **kwargs) 74 | 75 | async def delayed_request( 76 | self, 77 | request: dict[str, Any], 78 | mode: Literal["chat", "completion"], 79 | delay: float | None, 80 | ): 81 | """Prevent quantized rate limit: 82 | https://help.openai.com/en/articles/6891753-rate-limit-advice""" 83 | if delay is not None: 84 | # synchronized sleep 85 | time.sleep(delay) 86 | if mode == "chat": 87 | func = self.chat_completions_with_backoff_async 88 | else: 89 | func = self.completions_with_backoff_async 90 | return await func(**request) 91 | 92 | async def dispatch_chat_completions( 93 | self, 94 | requests: list[dict[str, Any]], 95 | delay: float | None = None, 96 | ): 97 | """Dispatch chat completions requests asynchronously. 98 | Args: 99 | requests: a list of API argument names to values. 100 | delay: interval between requests. 101 | """ 102 | 103 | tasks = [self.delayed_request(request, "chat", delay) for request in requests] 104 | return await asyncio.gather(*tasks, return_exceptions=True) 105 | 106 | async def dispatch_completions( 107 | self, 108 | requests: list[dict[str, Any]], 109 | delay: float | None = None, 110 | ): 111 | """Dispatch completions requests asynchronously. 112 | Args: 113 | requests: a list of API argument names to values. 114 | delay: interval between requests. 
115 | """ 116 | 117 | tasks = [ 118 | self.delayed_request(request, "completion", delay) for request in requests 119 | ] 120 | return await asyncio.gather(*tasks, return_exceptions=True) 121 | 122 | 123 | # https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb 124 | def num_tokens_from_string(string: str, model: str) -> int: 125 | """Returns the number of tokens in a text string.""" 126 | encoding = tiktoken.encoding_for_model(model) 127 | # encoding = tiktoken.get_encoding(encoding_name) 128 | num_tokens = len(encoding.encode(string, disallowed_special=())) 129 | return num_tokens 130 | 131 | 132 | def timestamp() -> str: 133 | return time.strftime("%Y%m%d_%H%M%S") 134 | 135 | 136 | def compute_fingerprint(*args: Any, hash_length: int | None = None) -> str: 137 | combined = "".join(map(str, args)) 138 | content = hashlib.sha256(combined.encode()).hexdigest() 139 | if hash_length is not None: 140 | content = content[:hash_length] 141 | return content 142 | 143 | 144 | def find_code_blocks(response: str, tag: str | None = None) -> list[str]: 145 | """Find all enclosed code blocks in the response, optionally filtering by language tag.""" 146 | all_indices = find_codeblock_indices(response, tag) 147 | return [response[start:end].strip() for start, end in all_indices] 148 | 149 | 150 | def find_codeblock_indices( 151 | response: str, tag: str | None = None 152 | ) -> list[tuple[int, int]]: 153 | """Find all enclosed code blocks in the response, optionally filtering by language tag.""" 154 | all_indices: list[tuple[int, int]] = [] 155 | search_start = ( 156 | 0 # Variable to keep track of where to start searching for the next code block 157 | ) 158 | 159 | while "```" in response[search_start:]: 160 | # Find the start of the code block (excluding the backticks) 161 | code_start_index = response.find("```", search_start) + 3 162 | 163 | # Find the end of the language tag line (or the start of the code if no tag line) 164 | code_start_endline = response.find("\n", code_start_index) 165 | if code_start_endline == -1: # Handle case where there's no newline after ``` 166 | code_start_endline = code_start_index 167 | 168 | # Extract the language tag (if any) 169 | extracted_tag = response[code_start_index:code_start_endline].strip() 170 | 171 | # Adjust the start index if a language tag is found 172 | if extracted_tag: 173 | actual_code_start = code_start_endline + 1 174 | else: 175 | actual_code_start = code_start_index 176 | 177 | # Find the end of the code block 178 | code_end_index = response.find("```", actual_code_start) 179 | if code_end_index == -1: 180 | break # Exit if there's no closing ``` 181 | 182 | # Extract the code 183 | # code = response[actual_code_start:code_end_index].strip() 184 | 185 | # Check if the extracted code block matches the requested language tag (if any) 186 | if tag is None or extracted_tag.lower() == tag.lower(): 187 | all_indices.append((actual_code_start, code_end_index)) 188 | 189 | # Update the search_start to look for the next code block 190 | search_start = code_end_index + 3 191 | 192 | return all_indices 193 | 194 | 195 | DEFAULT_TEMPLATE = """\ 196 | ### Instruction 197 | {instruction} 198 | 199 | ### Response 200 | {response}""" 201 | 202 | 203 | def is_base_model(tokenizer_name: str) -> bool: 204 | from transformers import AutoTokenizer 205 | 206 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 207 | return tokenizer.chat_template is None and "octocoder" not in tokenizer_name 208 | 209 | 210 | 
OCTOCODER_CHAT_TEMPLATE = """\ 211 | {%- for message in messages %} 212 | {%- if message['role'] == 'system' %} 213 | {{ raise_exception('System messages are not allowed in this template.') }} 214 | {%- else %} 215 | {%- if message['role'] == 'user' %} 216 | {{'Question: ' + message['content'] + '\n\n'}} 217 | {%- else %} 218 | {{'Answer: ' + message['content'] + '\n\n'}} 219 | {%- endif %} 220 | {%- endif %} 221 | {%- endfor %} 222 | {{'Question: '}}""" 223 | 224 | 225 | def infer_prompt_template(tokenizer_name: str) -> str: 226 | from transformers import AutoTokenizer 227 | 228 | tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 229 | if "octocoder" in tokenizer_name: 230 | tokenizer.chat_template = OCTOCODER_CHAT_TEMPLATE 231 | if tokenizer.chat_template is not None: 232 | template = tokenizer.apply_chat_template( 233 | [ 234 | {"role": "user", "content": "{instruction}"}, 235 | {"role": "assistant", "content": "{response}"}, 236 | ], 237 | tokenize=False, 238 | ) 239 | else: 240 | template = DEFAULT_TEMPLATE 241 | end_index = template.rindex("{response}") + len("{response}") 242 | template = template[:end_index] 243 | return template 244 | -------------------------------------------------------------------------------- /src/star_align/train.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import cast 3 | 4 | import torch 5 | from datasets import load_dataset 6 | from transformers import HfArgumentParser, Trainer, TrainingArguments 7 | 8 | from star_align.llm_wrapper import ( 9 | DecodingConfig, 10 | EncodingConfig, 11 | TokenizationContext, 12 | get_model_context, 13 | pad_sequences, 14 | ) 15 | from star_align.prompt_template import CHAT_TEMPLATE, SC2_INSTRUCT_PROMPT as PROMPT_TEMPLATE 16 | from star_align.utils import N_CORES 17 | 18 | 19 | @dataclass(frozen=True) 20 | class ModelArguments: 21 | model_key: str 22 | model_name_or_path: str | None = None 23 | attention_dropout: float | None = field(default=None) 24 | residual_dropout: float | None = field(default=None) 25 | embedding_dropout: float | None = field(default=None) 26 | 27 | 28 | # Ignored index in CrossEntropyLoss 29 | IGNORED_INDEX = -100 30 | 31 | def map_dataset( 32 | examples: dict[str, list[str]], 33 | args: "Args", 34 | context: TokenizationContext, 35 | ) -> dict: 36 | if args.prompt_completion_mode: 37 | prompts = examples["prompt"] 38 | completions = examples["completion"] 39 | else: 40 | instructions = examples["instruction"] 41 | responses = examples["response"] 42 | 43 | prompts = [ 44 | PROMPT_TEMPLATE.format(instruction=instruction, response="") 45 | for instruction in instructions 46 | ] 47 | completions = responses 48 | 49 | assert len(prompts) == len(completions) 50 | prompt_config = EncodingConfig(add_bos=True, add_eos=False) 51 | completion_config = EncodingConfig(add_bos=False, add_eos=True) 52 | prompt_id_batches = context.encode(prompt_config, prompts) 53 | completion_id_batches = context.encode(completion_config, completions) 54 | # prompt_id_batches = context.tokenization_context.encode(prompt_config, prompts) 55 | # completion_id_batches = context.tokenization_context.encode( 56 | # completion_config, completions 57 | # ) 58 | assert len(prompt_id_batches) == len(completion_id_batches) 59 | untruncated_input_ids = [ 60 | (instruction_ids + response_ids) 61 | for instruction_ids, response_ids in zip( 62 | prompt_id_batches, completion_id_batches 63 | ) 64 | ] 65 | exceeding_length = [ 66 | 
len(input_id) > args.max_training_seq_length 67 | for input_id in untruncated_input_ids 68 | ] 69 | input_ids = [ 70 | input_id[: args.max_training_seq_length] for input_id in untruncated_input_ids 71 | ] 72 | # NOTE: no need to set EOF to IGNORED_INDEX as it is *implicitly* ignored inside 73 | # the model.forward that shifts the logits left by 1 74 | labels = [ 75 | (list(map(lambda _: IGNORED_INDEX, instruction_ids)) + response_ids)[ 76 | : args.max_training_seq_length 77 | ] 78 | for instruction_ids, response_ids in zip( 79 | prompt_id_batches, completion_id_batches 80 | ) 81 | ] 82 | # `len` of each returned value must be the same, which is required by `tokenizer.map` 83 | # After `map`, they are treated as individual pieces of data, not as a batch. 84 | assert len(input_ids) == len(labels) 85 | for input_id_batch, label_batch in zip(input_ids, labels): 86 | assert len(input_id_batch) == len(label_batch) 87 | print(context.decode(DecodingConfig.default(), input_ids[0:])[0]) 88 | return { 89 | "input_ids": input_ids, 90 | "labels": labels, 91 | "exceeding_length": exceeding_length, 92 | } 93 | 94 | 95 | def get_data_collator(model_key: str, args: "Args", pad_token_id: int): 96 | """Pad input_ids to the right, create labels by setting the padding tokens to -100, and 97 | create attention_mask to ignore the padding tokens""" 98 | 99 | def collate(examples: list[dict[str, list[int]]]) -> dict[str, torch.Tensor]: 100 | input_ids_unpadded = [example["input_ids"] for example in examples] 101 | labels_unpadded = [example["labels"] for example in examples] 102 | padding_length = ( 103 | args.max_training_seq_length if args.pad_to_max_length else None 104 | ) 105 | input_ids = pad_sequences( 106 | input_ids_unpadded, pad_token_id, "right", padding_length=padding_length 107 | ) 108 | labels = pad_sequences( 109 | labels_unpadded, IGNORED_INDEX, "right", padding_length=padding_length 110 | ) 111 | 112 | assert input_ids.shape == labels.shape 113 | assert len(input_ids) == len(examples) 114 | # Enforced in `map_raw_dataset` 115 | assert input_ids.shape[-1] <= args.max_training_seq_length 116 | if args.pad_to_max_length: 117 | assert input_ids.shape[-1] == args.max_training_seq_length 118 | 119 | if "starcoder2" in model_key: 120 | attention_mask = torch.ones(input_ids.shape, dtype=torch.bool) 121 | else: 122 | attention_mask = input_ids.ne(pad_token_id) 123 | # when bos == eos, the first token will be masked by mistake 124 | attention_mask[:, 0] = True 125 | return { 126 | "input_ids": input_ids, 127 | "labels": labels, 128 | "attention_mask": attention_mask, 129 | } 130 | 131 | return collate 132 | 133 | 134 | @dataclass(frozen=True) 135 | class Args: 136 | datafile_paths: list[str] = field(default_factory=list) 137 | max_training_seq_length: int = field(default=1216) 138 | pad_to_max_length: bool = field(default=False) 139 | eval_dataset_size: float = field( 140 | default=0.05, metadata={"help": "0--1 means ratio, >1 means number of examples"} 141 | ) 142 | use_flash_attention: bool = field(default=False) 143 | prompt_completion_mode: bool = field(default=False) 144 | 145 | 146 | def train(): 147 | parser = HfArgumentParser((ModelArguments, TrainingArguments, Args)) 148 | model_args, training_args, args = cast( 149 | tuple[ModelArguments, TrainingArguments, Args], 150 | parser.parse_args_into_dataclasses(), 151 | ) 152 | dataset = load_dataset("json", data_files=args.datafile_paths, split="train") 153 | 154 | model_key = model_args.model_key 155 | if (model_name_or_path := 
model_args.model_name_or_path) is None: 156 | model_name_or_path = model_key 157 | 158 | tokenization_context = TokenizationContext.from_model_key( 159 | model_key, model_name_or_path 160 | ) 161 | # if dataset_config.dpo_jsonl_path is None or dataset_config.dpo_sft: 162 | train_dataset = dataset.map( 163 | function=map_dataset, 164 | fn_kwargs=dict(args=args, context=tokenization_context), 165 | batched=True, 166 | num_proc=N_CORES, 167 | remove_columns=dataset.column_names, 168 | load_from_cache_file=False, # not args.overwrite_cache 169 | desc="Running tokenizer on train dataset", 170 | ) 171 | msg = f"#Examples truncated: {sum(train_dataset['exceeding_length'])} / {len(train_dataset)}" 172 | print(msg) 173 | # else: 174 | # train_dataset = dataset 175 | 176 | # Shuffling 177 | if training_args.eval_steps is None and training_args.evaluation_strategy == "no": 178 | train_dataset = train_dataset.shuffle(seed=training_args.seed) 179 | eval_dataset = None 180 | else: 181 | print("Splitting dataset") 182 | split_dataset = train_dataset.train_test_split( 183 | test_size=args.eval_dataset_size, 184 | shuffle=True, 185 | seed=training_args.seed, 186 | ) 187 | train_dataset = split_dataset["train"] 188 | eval_dataset = split_dataset["test"] 189 | 190 | state = get_model_context( 191 | model_key, 192 | model_name_or_path, 193 | tokenization_context, 194 | inference_mode=False, 195 | use_flash_attention=args.use_flash_attention, 196 | attention_dropout=model_args.attention_dropout, 197 | residual_dropout=model_args.residual_dropout, 198 | embedding_dropout=model_args.embedding_dropout, 199 | ) 200 | if "codeqwen" in model_key.lower(): 201 | print(f"Hack for {model_key}") 202 | state.model.generation_config.do_sample = True 203 | 204 | print("Parallel mode:", training_args.parallel_mode) 205 | data_collator = get_data_collator( 206 | model_args.model_key, args, state.tokenization_context.pad_token_id 207 | ) 208 | 209 | # neftune_noise_alpha 210 | trainer = Trainer( 211 | model=state.model, 212 | args=training_args, 213 | train_dataset=train_dataset, 214 | eval_dataset=eval_dataset, 215 | data_collator=data_collator, 216 | # eval_dataset=small_eval_dataset, 217 | # compute_metrics=compute_metrics, 218 | ) 219 | 220 | # NOTE: the checkpoint will override the initialized model 221 | trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 222 | trainer.save_state() 223 | trainer.save_model(training_args.output_dir) 224 | state.tokenization_context.tokenizer.chat_template = CHAT_TEMPLATE 225 | state.tokenization_context.tokenizer.save_pretrained(training_args.output_dir) 226 | 227 | 228 | if __name__ == "__main__": 229 | train() 230 | -------------------------------------------------------------------------------- /evaluation/ds_1000.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from pathlib import Path 4 | from typing import Callable, Literal, cast 5 | from transformers import AutoTokenizer 6 | from ds1000 import DS1000Dataset, DS1000Problem 7 | from tqdm.auto import tqdm 8 | from transformers import HfArgumentParser 9 | 10 | from star_align.llm_wrapper import ( 11 | GenerationConfig, 12 | ModelContext, 13 | create_infilling_prompt, 14 | get_model_context, 15 | ) 16 | from star_align.utils import infer_prompt_template 17 | 18 | from vllm import LLM, SamplingParams 19 | 20 | PROMPT = cast(str, None) 21 | 22 | 23 | @dataclass 24 | class Args: 25 | dataset_path: str 26 | model_key: 
str 27 | model_name_or_path: str 28 | mode: Literal["Insertion", "Completion"] 29 | output_dir: str 30 | 31 | temperature: float = field(default=0.2) 32 | top_p: float = field(default=0.95) 33 | max_length: int = field(default=1024) 34 | n_samples_per_batch: int = field(default=5) 35 | n_batches: int = field(default=8) 36 | 37 | def to_generation_config(self) -> GenerationConfig: 38 | return GenerationConfig( 39 | # Use max_length to control 40 | max_new_tokens=9999999999999, 41 | top_p=self.top_p, 42 | temperature=self.temperature, 43 | max_length=self.max_length, 44 | ) 45 | 46 | 47 | def postprocess(text: str) -> str: 48 | return text.split("```")[0] 49 | 50 | 51 | def create_prompt(args: Args, tokenizer: AutoTokenizer, problem: DS1000Problem) -> str: 52 | prompt = problem["prompt"] 53 | if args.mode == "Insertion": 54 | prompt = preprocess_insertion_prompt(prompt) 55 | assert prompt.count("[insert]") == 1 56 | prefix, suffix = prompt.split("[insert]") 57 | prompt = create_infilling_prompt( 58 | model_key=args.model_key, 59 | prefix=prefix, 60 | suffix=suffix, 61 | tokenizer=tokenizer, 62 | ) 63 | else: 64 | assert args.mode == "Completion" 65 | instruction, response_prefix = preprocess_completion_prompt(problem["prompt"]) 66 | prompt = PROMPT.format( 67 | instruction=instruction, 68 | response=response_prefix, 69 | ) 70 | return prompt 71 | 72 | 73 | def generate( 74 | args: Args, 75 | # model_context: ModelContext, 76 | engine: LLM, 77 | problem: DS1000Problem, 78 | ): 79 | lib: str = problem["lib"] 80 | model_key = args.model_key.replace("/", "-") 81 | problem_id: str = f"q{problem.problem_id}" 82 | path = Path(args.output_dir) / model_key / lib / args.mode / problem_id 83 | finishing_signal = path / "FINISHED" 84 | if finishing_signal.exists(): 85 | print("Skipping:", path) 86 | return 87 | if not path.exists(): 88 | print("Making directory:", path) 89 | path.mkdir(parents=True, exist_ok=True) 90 | # config = args.to_generation_config() 91 | prompt = create_prompt(args, engine.get_tokenizer(), problem) 92 | print("========PROMPT=======") 93 | print(prompt) 94 | print("========PROMPT=======") 95 | 96 | sampling_params = SamplingParams( 97 | n=args.n_batches * args.n_samples_per_batch, 98 | temperature=args.temperature, 99 | max_tokens=args.max_length, 100 | top_k=-1, 101 | top_p=args.top_p, 102 | stop=["```"], 103 | ) 104 | 105 | # for batch_idx in range(args.n_batches): 106 | # print(f"Generating batch {batch_idx} of {args.n_batches}") 107 | # response = model_context.complete( 108 | # config=config, 109 | # prompts=[prompt] * args.n_samples_per_batch, 110 | # stop_tokens=["```"] if os.getenv("STOP") is not None else None, 111 | # ) 112 | print(f"Generating {args.n_batches * args.n_samples_per_batch} samples") 113 | results = engine.generate(prompt, sampling_params) 114 | assert len(results) == 1 115 | print("=======RESPOSE[-1]=======") 116 | # postprocess_fn: Callable[[str], str] = ( 117 | # (lambda x: x) if args.mode == "Insertion" else postprocess 118 | # ) 119 | postprocess_fn = postprocess 120 | print(postprocess_fn(results[0].outputs[-1].text)) 121 | # print("=======RESPOSE[-1]=======") 122 | # print("=======RESPOSE[RAW]=======") 123 | # print(response.decoded_outputs[-1]) 124 | # print("=======RESPOSE[RAW]=======") 125 | # exit() 126 | assert len(results[0].outputs) == args.n_batches * args.n_samples_per_batch 127 | for idx, output in enumerate(results[0].outputs): 128 | sample = output.text 129 | sample = postprocess_fn(sample) 130 | # global_index = batch_idx * 
args.n_samples_per_batch + idx 131 | global_index = idx 132 | output_file = path / f"{global_index}.py" 133 | output_file.write_text(sample) 134 | finishing_signal.touch() 135 | 136 | 137 | def preprocess_completion_prompt(prompt: str) -> tuple[str, str]: 138 | """Preprocess the DS-1000 prompt (Completion mode) into instruction and response prefix""" 139 | # hit = False 140 | if not "SOLUTION START" in prompt: 141 | answer_index = prompt.rindex("A:") 142 | answer = prompt[answer_index + 2 :].strip() 143 | instruction: str = prompt[:answer_index].strip() 144 | if instruction.startswith("Problem:"): 145 | instruction = instruction[len("Problem:") :].strip() 146 | if "### BEGIN SOLUTION" in prompt: 147 | assert prompt.count("") == 1 148 | assert prompt.count("") == 0 149 | lines = answer.splitlines(keepends=True) 150 | return_line, result_line, begin_line = lines[-3:] 151 | assert return_line.strip().startswith("# return") 152 | assert result_line.strip().startswith("# ") 153 | assert begin_line.strip() == "### BEGIN SOLUTION" 154 | response = "".join(lines[:-3]).strip() 155 | hint = begin_line.replace("###", "#").replace("BEGIN SOLUTION", "Solution") 156 | response += f"\n{hint}\n" 157 | else: 158 | assert "BEGIN SOLUTION" in prompt 159 | assert prompt.count("") == 2 160 | assert prompt.count("") == 1 161 | first_block_start = prompt.index("") 162 | first_block_end = prompt.index("") 163 | second_block_start = prompt.index("", first_block_start + 1) 164 | assert first_block_end < second_block_start 165 | lines = answer.splitlines(keepends=True) 166 | block_end, instruction_line, begin_line, block_start = lines[-4:] 167 | assert begin_line.strip() == "BEGIN SOLUTION" 168 | assert block_start.strip() == "" 169 | if not block_end.strip() == "": 170 | if lines[-6].strip() == "": 171 | response_prefix = lines[:-6] 172 | starting_lines = lines[-5:-2] 173 | else: 174 | assert instruction_line.strip() == "" 175 | response_prefix = lines[:-3] 176 | starting_lines = lines[-2:-2] 177 | else: 178 | response_prefix = lines[:-4] 179 | starting_lines = lines[-3:-2] 180 | starting_lines = [f"# {line.lstrip()}" for line in starting_lines] 181 | response = "".join([*response_prefix, *starting_lines]).strip() 182 | response += "\n# Solution\n" 183 | else: 184 | # hit = True 185 | assert prompt.count("") == 0 186 | assert prompt.count("") == 0 187 | assert prompt.strip().endswith("# SOLUTION START") 188 | code_prefix = prompt[: prompt.rindex("# SOLUTION START")].strip() 189 | instruction = f"""Write a solution to the following problem: 190 | ```python 191 | {code_prefix} 192 | ```""" 193 | response = f"```python\n{code_prefix}\n# Solution\n" 194 | instruction = instruction.replace("", "```python").replace("", "```") 195 | response = response.replace("", "```python").replace("", "```") 196 | # if hit: 197 | # print("[Instruction]") 198 | # print(instruction) 199 | # print("[Response]") 200 | # print(response) 201 | # breakpoint() 202 | return instruction, response 203 | 204 | 205 | def preprocess_insertion_prompt(prompt: str) -> str: 206 | pattern = """ 207 | BEGIN SOLUTION 208 | 209 | [insert] 210 | 211 | END SOLUTION""" 212 | pattern_index = prompt.index(pattern) 213 | # pattern_block = prompt[pattern_index:] 214 | prefix = prompt[:pattern_index] 215 | # hit = False 216 | if pattern + "\n" in prompt: 217 | index = prompt.index("", pattern_index + len(pattern)) 218 | suffix = prompt[index + len("") :] 219 | else: 220 | # hit = True 221 | assert pattern in prompt 222 | suffix = "" 223 | final_prompt = 
prefix.strip() + "\n[insert]\n" + suffix.strip() 224 | final_prompt = final_prompt.replace("", "```python").replace("", "```") 225 | # if hit: 226 | # print(final_prompt) 227 | # breakpoint() 228 | return final_prompt 229 | 230 | 231 | def main(): 232 | args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) 233 | dataset = DS1000Dataset(args.dataset_path, mode=args.mode) 234 | 235 | global PROMPT 236 | if (inferred := os.getenv("INFER")) is not None: 237 | if inferred == "1": 238 | PROMPT = infer_prompt_template(args.model_name_or_path) 239 | else: 240 | PROMPT = infer_prompt_template(inferred) 241 | 242 | print("Using prompt:") 243 | print(PROMPT) 244 | 245 | all_problems = [ 246 | problem 247 | for problems in dataset.data.values() 248 | for problem in problems 249 | if args.mode == "Completion" or problem["lib"] != "Matplotlib" 250 | ] 251 | engine = LLM( 252 | tokenizer=args.model_key, model=args.model_name_or_path or args.model_key 253 | ) 254 | # model_context = get_model_context( 255 | # model_key=args.model_key, 256 | # model_name_or_path=args.model_name_or_path, 257 | # ) 258 | for problem in tqdm(all_problems): 259 | # generate(args, model_context, problem) 260 | generate(args, engine, problem) 261 | 262 | 263 | if __name__ == "__main__": 264 | main() 265 | -------------------------------------------------------------------------------- /src/star_align/execution_filter.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import os 4 | import shutil 5 | import sys 6 | from concurrent.futures import ProcessPoolExecutor, as_completed 7 | from multiprocessing import Process, cpu_count 8 | from evalplus.eval.utils import ( 9 | create_tempdir, 10 | reliability_guard, 11 | swallow_io, 12 | time_limit, 13 | ) 14 | from tqdm.auto import tqdm 15 | 16 | from datasets import load_dataset 17 | from star_align.utils import chunked, find_code_blocks 18 | from transformers import HfArgumentParser 19 | from dataclasses import dataclass, field 20 | from typing import cast 21 | 22 | 23 | _magic_splitter_ = "### -- what do you think? 
-- ###" 24 | 25 | 26 | def make_python_membound_code_prefix(limit_mb): 27 | maximum_memory_bytes = limit_mb * 1024 * 1024 28 | return f"""\ 29 | import resource 30 | import platform 31 | 32 | resource.setrlimit( 33 | resource.RLIMIT_AS, ({maximum_memory_bytes}, {maximum_memory_bytes}) 34 | ) 35 | resource.setrlimit( 36 | resource.RLIMIT_DATA, ({maximum_memory_bytes}, {maximum_memory_bytes}) 37 | ) 38 | if not platform.uname().system == "Darwin": 39 | resource.setrlimit( 40 | resource.RLIMIT_STACK, ({maximum_memory_bytes}, {maximum_memory_bytes}) 41 | ) 42 | {_magic_splitter_} 43 | """ 44 | 45 | 46 | @dataclass(frozen=True) 47 | class Args: 48 | response_paths: list[str] 49 | result_path: str 50 | save_request_errors: bool = False 51 | shuffle: bool = field(default=True) 52 | cache_paths: list[str] = field(default_factory=list) 53 | load_pass_only_cache: bool = field(default=False) 54 | max_batched_tasks: int = 10000 55 | max_workers: int = cpu_count() 56 | container_server: str | None = None 57 | 58 | 59 | def suppress_output(func): 60 | def wrapper(*args, **kwargs): 61 | original_stdout = sys.stdout 62 | original_stderr = sys.stderr 63 | sys.stdout = open(os.devnull, "w") 64 | sys.stderr = sys.stdout 65 | try: 66 | result = func(*args, **kwargs) 67 | finally: 68 | sys.stdout.close() 69 | sys.stdout = original_stdout 70 | sys.stderr = original_stderr 71 | return result 72 | 73 | return wrapper 74 | 75 | 76 | # Note: only run this within a safe subprocess 77 | def _run(code) -> None: 78 | with create_tempdir(): 79 | # These system calls are needed when cleaning up tempdir. 80 | rmtree = shutil.rmtree 81 | rmdir = os.rmdir 82 | chdir = os.chdir 83 | getcwd = os.getcwd 84 | 85 | maximum_memory_bytes = 1 * 1024 * 1024 * 1024 86 | reliability_guard(maximum_memory_bytes=maximum_memory_bytes) 87 | 88 | # Disable functionalities that can make destructive changes to the test. 89 | # allow only 1GB memory usage 90 | 91 | # run the function 92 | with swallow_io(): 93 | with time_limit(4): # max 4 seconds 94 | # run the function 95 | exec(code) 96 | 97 | # Needed for cleaning up. 98 | shutil.rmtree = rmtree 99 | os.rmdir = rmdir 100 | os.chdir = chdir 101 | os.getcwd = getcwd 102 | 103 | 104 | def containerized_run(item, limit_mb=4 * 1024): 105 | from star_align.code_exec_server.code_exec_reqs import exec_test 106 | 107 | idx, result, code, srv = item 108 | membound_code = make_python_membound_code_prefix(limit_mb) + code 109 | passed, output = exec_test( 110 | srv, membound_code, "", timeout=10, timeout_on_client=True 111 | ) 112 | return (idx, result, code, passed, output) 113 | 114 | 115 | def fork_run(item): 116 | idx, response, code, _ = item 117 | sys.stdout = open(os.devnull, "w") 118 | sys.stderr = sys.stdout 119 | p = Process(target=_run, args=(code,)) 120 | p.start() 121 | p.join(timeout=10) 122 | passed = p.exitcode == 0 123 | return (idx, response, code, passed, "NOT SUPPORTED") 124 | 125 | 126 | def is_compilable(code): 127 | try: 128 | ast.parse(code) 129 | return True 130 | except (SyntaxError, ValueError): 131 | return False 132 | 133 | 134 | def extract_code(response: str) -> str: 135 | def sanitize_codeblock(code: str) -> str: 136 | if "input" not in code: 137 | return code.strip() 138 | # Only remove the `if __name__..` when `input` is present because 139 | # it will block the code execution. 
140 | key = 'if __name__ == "__main__":' 141 | key_alt = "if __name__ == '__main__':" 142 | index = code.find(key) 143 | if index == -1: 144 | index = code.find(key_alt) 145 | if index == -1: 146 | return code.strip() 147 | assert index != -1 148 | code = code[:index].strip() 149 | return code 150 | 151 | code_blocks = list(map(sanitize_codeblock, find_code_blocks(response))) 152 | return "\n\n".join(code_blocks) 153 | 154 | 155 | def form_new_data( 156 | item: dict, 157 | response: str, 158 | extracted_code: str, 159 | pass_execution: bool, 160 | output: str, 161 | ) -> dict: 162 | newdata = {k: v for k, v in item.items() if k not in ["response", "parsing_result"]} 163 | newdata["response"] = response 164 | newdata["extracted_code"] = extracted_code 165 | newdata["pass"] = pass_execution 166 | newdata["output"] = output 167 | return newdata 168 | 169 | 170 | def main(): 171 | args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) 172 | if args.container_server is None: 173 | option = input( 174 | "WARNING: container_server is not set. You will run the code locally, which can lead to unexpected side effects. Continue? (y/n): " 175 | ).strip() 176 | if option.lower() != "y": 177 | return 178 | 179 | if os.path.exists(args.result_path): 180 | option = input( 181 | f"WARNING: {args.result_path} already exists. Overwrite? (y/n): " 182 | ).strip() 183 | if option.lower() != "y": 184 | return 185 | 186 | cleanup_command = os.getenv("CLEANUP_COMMAND", None) 187 | if cleanup_command is not None: 188 | print(f"NOTE: the cleanup command is set to:") 189 | print(cleanup_command) 190 | 191 | raw_data = load_dataset("json", data_files=args.response_paths, split="train") 192 | if args.shuffle: 193 | raw_data = raw_data.shuffle() 194 | if len(args.cache_paths) > 0: 195 | cached_data = load_dataset("json", data_files=args.cache_paths, split="train") 196 | if args.load_pass_only_cache: 197 | cached_dict: dict[str, dict] = { 198 | item["extracted_code"]: item for item in cached_data if item["pass"] 199 | } 200 | else: 201 | cached_dict = {item["extracted_code"]: item for item in cached_data} 202 | else: 203 | cached_dict = {} 204 | 205 | all_tasks: list[tuple[int, str, str, str | None]] = [] 206 | eval_results: list[dict] = [] 207 | for idx, item in enumerate(tqdm(raw_data, desc="Preprocessing: extracting code")): 208 | # passing_results = [] 209 | if "parsing_result" not in item: 210 | item["parsing_result"] = [dict(response=item["response"])] 211 | for result in item["parsing_result"]: 212 | response = result["response"] 213 | code = extract_code(response) 214 | if (hit_item := cached_dict.get(code, None)) is not None: 215 | assert code == hit_item["extracted_code"] 216 | new_data = form_new_data( 217 | item=item, 218 | response=response, 219 | extracted_code=code, 220 | pass_execution=hit_item["pass"], 221 | output=hit_item["output"], 222 | ) 223 | eval_results.append(new_data) 224 | else: 225 | all_tasks.append((idx, response, code, args.container_server)) 226 | 227 | def pass_rate_str(passed: int, total: int, tag: str = "") -> str: 228 | percentage = f"{passed/total * 100:.2f}%" if total > 0 else "N/A" 229 | ratio = f"{passed}/{total}" 230 | tag = f"{tag} " if len(tag) > 0 else "" 231 | return f"{tag}Passed: {ratio} ({percentage})" 232 | 233 | n_cached_passed = sum(item["pass"] for item in eval_results) 234 | n_cached_total = len(eval_results) 235 | 236 | print(f"Cached: {len(eval_results)}, Active: {len(all_tasks)}") 237 | print(pass_rate_str(n_cached_passed, n_cached_total, 
"Cached")) 238 | 239 | run_func = containerized_run if args.container_server else fork_run 240 | tasks_chunks = list(chunked(all_tasks, args.max_batched_tasks)) 241 | n_processed = 0 242 | n_passed = 0 243 | with open(args.result_path, "w") as f: 244 | for cached_result in eval_results: 245 | f.write(json.dumps(cached_result) + "\n") 246 | with ProcessPoolExecutor(max_workers=args.max_workers) as executor: 247 | pbar = tqdm(tasks_chunks) 248 | for chunked_tasks in pbar: 249 | futures = [executor.submit(run_func, task) for task in chunked_tasks] 250 | # NOTE: futures do not return in the same order as before 251 | pbar_inner = tqdm( 252 | as_completed(futures), 253 | total=len(futures), 254 | leave=False, 255 | ) 256 | n_passed_inner = 0 257 | for n_processed_inner, future in enumerate(pbar_inner, start=1): 258 | n_processed += 1 259 | try: 260 | future_result = future.result() 261 | except Exception as e: 262 | continue 263 | idx, response, code, passed, output = future_result 264 | if "Failed to execute program" in output: 265 | if not args.save_request_errors: 266 | continue 267 | newdata = form_new_data( 268 | item=raw_data[idx], 269 | response=response, 270 | extracted_code=code, 271 | pass_execution=passed, 272 | output=output, 273 | ) 274 | f.write(json.dumps(newdata) + "\n") 275 | n_passed += passed 276 | n_passed_inner += passed 277 | pbar_inner.set_description( 278 | pass_rate_str(n_passed_inner, n_processed_inner) 279 | ) 280 | pbar.set_description(pass_rate_str(n_passed, n_processed)) 281 | if cleanup_command is not None: 282 | print(f"Cleaning up: {cleanup_command}") 283 | os.system(cleanup_command) 284 | print("Cleanup done.") 285 | 286 | n_total_passed = n_cached_passed + n_passed 287 | n_total = len(all_tasks) + n_cached_total 288 | print(pass_rate_str(n_total_passed, n_total, "Total")) 289 | 290 | 291 | if __name__ == "__main__": 292 | main() 293 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README-SC2INST.md: -------------------------------------------------------------------------------- 1 | # StarCoder2-Instruct: Fully Transparent and Permissive Self-Alignment for Code Generation 2 | 3 |

4 | ⭐️ About 5 | | 🚀 Quick start 6 | | 📚 Data generation 7 | | 🧑‍💻 Training 8 | | 📊 Evaluation 9 | | ⚠️ Limitations 10 |

11 |
12 | ![Banner](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/banner.png)
13 |
14 |
17 |
18 | ## About
19 |
20 | We introduce StarCoder2-15B-Instruct-v0.1, the very first entirely self-aligned code Large Language Model (LLM) trained with a fully permissive and transparent pipeline. Our open-source pipeline uses StarCoder2-15B to generate thousands of instruction-response pairs, which are then used to fine-tune StarCoder2-15B itself without any human annotations or distilled data from huge, proprietary LLMs.
21 |
22 | - **Model:** [bigcode/starcoder2-15b-instruct-v0.1](https://huggingface.co/bigcode/starcoder2-instruct-15b-v0.1)
23 | - **Code:** [bigcode-project/starcoder2-self-align](https://github.com/bigcode-project/starcoder2-self-align)
24 | - **Dataset:** [bigcode/self-oss-instruct-sc2-exec-filter-50k](https://huggingface.co/datasets/bigcode/self-oss-instruct-sc2-exec-filter-50k/)
25 | - **Authors:**
26 | [Yuxiang Wei](https://yuxiang.cs.illinois.edu),
27 | [Federico Cassano](https://federico.codes/),
28 | [Jiawei Liu](https://jw-liu.xyz),
29 | [Yifeng Ding](https://yifeng-ding.com),
30 | [Naman Jain](https://naman-ntc.github.io),
31 | [Harm de Vries](https://www.harmdevries.com),
32 | [Leandro von Werra](https://twitter.com/lvwerra),
33 | [Arjun Guha](https://www.khoury.northeastern.edu/home/arjunguha/main/home/),
34 | [Lingming Zhang](https://lingming.cs.illinois.edu).
35 |
36 | ![self-alignment pipeline](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/method.png)
37 |
38 | ## Quick start
39 |
40 | Here is an example to get started with StarCoder2-15B-Instruct-v0.1 using the [transformers](https://huggingface.co/docs/transformers/index) library:
41 |
42 | ```python
43 | import transformers
44 | import torch
45 |
46 | pipeline = transformers.pipeline(
47 |     model="bigcode/starcoder2-15b-instruct-v0.1",
48 |     task="text-generation",
49 |     torch_dtype=torch.bfloat16,
50 |     device_map="auto",
51 | )
52 |
53 | def respond(instruction: str, response_prefix: str) -> str:
54 |     messages = [{"role": "user", "content": instruction}]
55 |     prompt = pipeline.tokenizer.apply_chat_template(messages, tokenize=False)
56 |     prompt += response_prefix
57 |
58 |     terminators = [
59 |         pipeline.tokenizer.eos_token_id,
60 |         pipeline.tokenizer.convert_tokens_to_ids("###"),
61 |     ]
62 |
63 |     result = pipeline(
64 |         prompt,
65 |         max_length=256,
66 |         num_return_sequences=1,
67 |         do_sample=False,
68 |         eos_token_id=terminators,
69 |         pad_token_id=pipeline.tokenizer.eos_token_id,
70 |         truncation=True,
71 |     )
72 |     response = response_prefix + result[0]["generated_text"][len(prompt) :].split("###")[0].rstrip()
73 |     return response
74 |
75 |
76 | instruction = "Write a quicksort function in Python with type hints and a 'less_than' parameter for custom sorting criteria."
77 | response_prefix = ""
78 |
79 | print(respond(instruction, response_prefix))
80 | ```
81 |
82 | ## Data generation pipeline
83 |
84 | > Run `pip install -e .` first to install the package locally. Check [seed_gathering](seed_gathering/) for details on how we collected the seeds.
85 |
86 | By default, we use an in-memory vLLM engine for data generation, but we also provide an option to use vLLM's [OpenAI-compatible server](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html) instead.
87 |
88 | Set `CUDA_VISIBLE_DEVICES=...` to specify the GPU devices to use for the vLLM engine.
89 |
90 | To maximize data generation efficiency, we recommend invoking the script multiple times with different `seed_code_start_index` and `max_new_data` values, each with a vLLM engine running on a separate set of GPUs. For example, for a 100k seed dataset on a 2-GPU machine, you can run 2 processes that each generate 50k samples by setting `CUDA_VISIBLE_DEVICES=0 --seed_code_start_index 0 --max_new_data 50000` and `CUDA_VISIBLE_DEVICES=1 --seed_code_start_index 50000 --max_new_data 50000`, as sketched below.
91 |
92 |
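A minimal sketch of such a parallel launch, assuming a 100k-entry seed file (the file path is a placeholder, and the stage-specific flags elided with `...` are the ones shown in the sections below):

```shell
# One generation process per GPU; each process covers a disjoint slice of the seed file.
CUDA_VISIBLE_DEVICES=0 python src/star_align/self_ossinstruct.py \
    --seed_data_files /path/to/seeds.jsonl \
    --seed_code_start_index 0 --max_new_data 50000 ... &
CUDA_VISIBLE_DEVICES=1 python src/star_align/self_ossinstruct.py \
    --seed_data_files /path/to/seeds.jsonl \
    --seed_code_start_index 50000 --max_new_data 50000 ... &
wait
```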
93 | 94 | Click to see how to run with vLLM's OpenAI compatible API 95 | 96 | To do so, make sure the vLLM server is running and the associated `openai` environment variables are set. 97 | 98 | For example, you can start a vLLM server with `docker`: 99 | 100 | ```shell 101 | docker run --gpus '"device=0"' \ 102 | -v $HF_HOME:/root/.cache/huggingface \ 103 | -p 10000:8000 \ 104 | --ipc=host \ 105 | vllm/vllm-openai:v0.3.3 \ 106 | --model bigcode/starcoder2-15b \ 107 | --tensor-parallel-size 1 --dtype bfloat16 108 | ``` 109 | 110 | Then set the environment variables as follows: 111 | 112 | ```shell 113 | export OPENAI_API_KEY="EMPTY" 114 | export OPENAI_BASE_URL="http://localhost:10000/v1/" 115 | ``` 116 | 117 | You will also need to set `--use_vllm_server True` in the following commands. 118 | 119 |
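If you go the server route, you can quickly confirm that the endpoint is reachable before launching generation (the port matches the `docker` command above):

```shell
curl http://localhost:10000/v1/models
```

Whichever engine you use, here is a minimal sketch of the two-process split described earlier, assuming a 100k-seed dataset, two GPUs, and the in-memory engine. The flags are abbreviated and the paths are placeholders; see the per-stage snippets below for the full flag sets:

```shell
# Sketch only: each process handles half of the seeds on its own GPU.
CUDA_VISIBLE_DEVICES=0 python src/star_align/self_ossinstruct.py \
    --use_vllm_server False --instruct_mode "S->C" \
    --seed_data_files /path/to/seeds.jsonl \
    --seed_code_start_index 0 --max_new_data 50000 \
    --model bigcode/starcoder2-15b &

CUDA_VISIBLE_DEVICES=1 python src/star_align/self_ossinstruct.py \
    --use_vllm_server False --instruct_mode "S->C" \
    --seed_data_files /path/to/seeds.jsonl \
    --seed_code_start_index 50000 --max_new_data 50000 \
    --model bigcode/starcoder2-15b &

wait
```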
120 | 121 |
122 | 123 | Snippet to concepts generation 124 | 125 | ```shell 126 | MODEL=bigcode/starcoder2-15b 127 | MAX_NEW_DATA=1000000 128 | python src/star_align/self_ossinstruct.py \ 129 | --use_vllm_server False \ 130 | --instruct_mode "S->C" \ 131 | --seed_data_files /path/to/seeds.jsonl \ 132 | --max_new_data $MAX_NEW_DATA \ 133 | --tag concept_gen \ 134 | --temperature 0.7 \ 135 | --seed_code_start_index 0 \ 136 | --model $MODEL \ 137 | --num_fewshots 8 \ 138 | --num_batched_requests 2000 \ 139 | --num_sample_per_request 1 140 | ``` 141 | 142 |
143 | 144 |
145 | 146 | Concepts to instruction generation 147 | 148 | ```shell 149 | MODEL=bigcode/starcoder2-15b 150 | MAX_NEW_DATA=1000000 151 | python src/star_align/self_ossinstruct.py \ 152 | --instruct_mode "C->I" \ 153 | --seed_data_files /path/to/concepts.jsonl \ 154 | --max_new_data $MAX_NEW_DATA \ 155 | --tag instruction_gen \ 156 | --temperature 0.7 \ 157 | --seed_code_start_index 0 \ 158 | --model $MODEL \ 159 | --num_fewshots 8 \ 160 | --num_sample_per_request 1 \ 161 | --num_batched_requests 2000 162 | ``` 163 | 164 |
165 | 166 |
167 | 168 | Instruction to response (with self-validation code) generation 169 | 170 | ```shell 171 | MODEL=bigcode/starcoder2-15b 172 | MAX_NEW_DATA=1000000 173 | python src/star_align/self_ossinstruct.py \ 174 | --instruct_mode "I->R" \ 175 | --seed_data_files /path/to/instructions.jsonl \ 176 | --max_new_data $MAX_NEW_DATA \ 177 | --tag response_gen \ 178 | --seed_code_start_index 0 \ 179 | --model $MODEL \ 180 | --num_fewshots 1 \ 181 | --num_batched_requests 500 \ 182 | --num_sample_per_request 10 \ 183 | --temperature 0.7 184 | ``` 185 | 186 |
187 | 188 |
189 | 190 | Execution filter 191 | 192 | > **Warning:** Though we implemented reliability guards, it is highly recommended to run execution in the sandboxed environment we provide. 193 | 200 | 201 | To use the Docker container for executing code, you will first need to run `git submodule update --init --recursive` to clone the server, then run: 202 | 203 | ```shell 204 | pushd ./src/star_align/code_exec_server 205 | ./pull_and_run.sh 206 | popd 207 | python src/star_align/execution_filter.py \ 208 | --response_paths /path/to/response.jsonl \ 209 | --result_path /path/to/filtered.jsonl \ 210 | --max_batched_tasks 10000 \ 211 | --container_server http://127.0.0.1:8000 212 | ``` 213 | 214 | The execution filter produces a flattened list of JSONL entries with a `pass` field indicating whether the execution passed. **It also dumps the results incrementally and can load a cached partial data file.** You can resume an interrupted run with: 215 | 216 | ```shell 217 | python src/star_align/execution_filter.py \ 218 | --response_paths /path/to/response.jsonl* \ 219 | --cache_paths /path/to/filtered.jsonl* \ 220 | --result_path /path/to/filtered-1.jsonl \ 221 | --max_batched_tasks 10000 \ 222 | --container_server http://127.0.0.1:8000 223 | ``` 224 | 225 | Note that execution can sometimes slow down significantly due to excessive resource consumption. To alleviate this, you can limit the Docker container's CPU usage (e.g., `docker run --cpuset-cpus="0-31"`). You can also set a cleanup command that runs after each batch: 226 | 227 | ```shell 228 | # For example, you can set the command to be `sudo pkill -f '/tmp/codeexec'` 229 | export CLEANUP_COMMAND="the command to execute after each batch" 230 | python src/star_align/execution_filter.py... 231 | ``` 232 | 233 | The connection to the container may also be lost during execution. In that case, simply re-run the script and rely on the caching mechanism described above. 234 | 235 |
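Before moving on to sanitization, it is often useful to check how many entries actually passed execution. Below is a minimal sketch, assuming the result path used above and relying on the `pass` field described earlier:

```shell
python -c "
import json
rows = [json.loads(line) for line in open('/path/to/filtered.jsonl')]
passed = sum(1 for r in rows if r.get('pass'))
print(f'{passed}/{len(rows)} entries passed execution')
"
```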
236 | 237 |
238 | 239 | Data sanitization and selection 240 | 241 | ```shell 242 | # Uncomment to do decontamination 243 | # export MBPP_PATH="/path/to/mbpp.jsonl" 244 | # export DS1000_PATH="/path/to/ds1000_data" 245 | # export DECONTAMINATION=1 246 | ./sanitize.sh /path/to/exec-filtered.jsonl /path/to/sanitized.jsonl 247 | ``` 248 | 249 |
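After sanitization, you can spot-check the output before training. The sketch below prints the fields of the first record and a preview of its instruction; the path is a placeholder, and the exact field set depends on the sanitization flags (e.g., `code_representation` is only present when enabled):

```shell
python -c "
import json
first = json.loads(open('/path/to/sanitized.jsonl').readline())
print(sorted(first.keys()))
print(first['instruction'][:200])
"
```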
250 | 251 | ## Training Details 252 | 253 | > Run `pip install -e .` first to install the package locally, and install [Flash Attention](https://github.com/Dao-AILab/flash-attention) to speed up training. 254 | 255 | ### Hyperparameters 256 | 257 | - **Optimizer:** Adafactor 258 | - **Learning rate:** 1e-5 259 | - **Epochs:** 4 260 | - **Batch size:** 64 261 | - **Warmup ratio:** 0.05 262 | - **Scheduler:** Linear 263 | - **Sequence length:** 1280 264 | - **Dropout:** Not applied 265 | 266 | ### Hardware 267 | 268 | 1 x NVIDIA A100 80GB. Yes, you just need one A100 to finetune StarCoder2-15B! 269 | 270 | ### Script 271 | 272 | The following script produces StarCoder2-15B-Instruct-v0.1 by finetuning the base StarCoder2-15B model. `/path/to/50k-dataset.jsonl` is the JSONL version of the [50k dataset](https://huggingface.co/datasets/bigcode/self-oss-instruct-sc2-exec-filter-50k) we generated. You can dump the dataset to JSONL to fit the training script (a sketch is given after the script below). 273 | 274 |
275 | 276 | Click to see the training script 277 | 278 | NOTE: StarCoder2-15B sets dropout values to 0.1 by default. We did not apply dropout during finetuning and thus set them to 0.0. 279 | 280 | ```shell 281 | MODEL_KEY=bigcode/starcoder2-15b 282 | LR=1e-5 283 | EPOCH=4 284 | SEQ_LEN=1280 285 | WARMUP_RATIO=0.05 286 | OUTPUT_DIR=/path/to/output_model 287 | DATASET_FILE=/path/to/50k-dataset.jsonl 288 | accelerate launch -m star_align.train \ 289 | --model_key $MODEL_KEY \ 290 | --model_name_or_path $MODEL_KEY \ 291 | --use_flash_attention True \ 292 | --datafile_paths $DATASET_FILE \ 293 | --output_dir $OUTPUT_DIR \ 294 | --bf16 True \ 295 | --num_train_epochs $EPOCH \ 296 | --max_training_seq_length $SEQ_LEN \ 297 | --pad_to_max_length False \ 298 | --per_device_train_batch_size 1 \ 299 | --gradient_accumulation_steps 64 \ 300 | --group_by_length False \ 301 | --ddp_find_unused_parameters False \ 302 | --logging_steps 1 \ 303 | --log_level info \ 304 | --optim adafactor \ 305 | --max_grad_norm -1 \ 306 | --warmup_ratio $WARMUP_RATIO \ 307 | --learning_rate $LR \ 308 | --lr_scheduler_type linear \ 309 | --attention_dropout 0.0 \ 310 | --residual_dropout 0.0 \ 311 | --embedding_dropout 0.0 312 | ``` 313 | 314 |
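As noted above, the training script expects the dataset as a local JSONL file. One way to produce it from the released 50k dataset (assuming the default `train` split; the output path is a placeholder):

```shell
python -c "
from datasets import load_dataset
ds = load_dataset('bigcode/self-oss-instruct-sc2-exec-filter-50k', split='train')
ds.to_json('/path/to/50k-dataset.jsonl')  # writes JSON Lines by default
"
```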
315 | 316 | ## Evaluation on EvalPlus, LiveCodeBench, and DS-1000 317 | 318 | > Check [evaluation](evaluation/) for more details. 319 | 320 | ![EvalPlus](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/evalplus.png) 321 | 322 | ![LiveCodeBench and DS-1000](https://huggingface.co/datasets/bigcode/starcoder2-instruct-assets/resolve/main/lcb-ds1000.png) 323 | 324 | ## Bias, Risks, and Limitations 325 | 326 | StarCoder2-15B-Instruct-v0.1 is primarily finetuned for Python code generation tasks that can be verified through execution, which may lead to certain biases and limitations. For example, the model might not adhere strictly to instructions that dictate the output format. In these situations, it's beneficial to provide a **response prefix** or a **one-shot example** to steer the model’s output. Additionally, the model may have limitations with other programming languages and out-of-domain coding tasks. 327 | 328 | The model also inherits the bias, risks, and limitations from its base StarCoder2-15B model. For more information, please refer to the [StarCoder2-15B model card](https://huggingface.co/bigcode/starcoder2-15b). 329 | -------------------------------------------------------------------------------- /seed_gathering/filter_dataset.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import os 3 | from tree_sitter_parser import global_parser, LANGUAGE, does_have_return, make_parser 4 | import benchmark_data 5 | from tqdm import tqdm 6 | import torch 7 | import argparse 8 | from vllm import LLM, SamplingParams 9 | import random 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--dataset', type=str, required=True) 13 | parser.add_argument('--model', type=str, 14 | default="bigcode/starcoder2-15b") 15 | parser.add_argument('--batch-size', type=int, default=512) 16 | parser.add_argument('--sample-size', type=int, default=None) 17 | parser.add_argument('--num-gpus', type=int, default=1) 18 | parser.add_argument('--content_col', type=str, default="content") 19 | parser.add_argument('--push', type=str, required=True) 20 | args = parser.parse_args() 21 | random.seed(42) 22 | 23 | FN_BLOCK_QUERY = LANGUAGE.query(""" 24 | (function_definition 25 | body: (block) @fn-block) 26 | """) 27 | 28 | 29 | def template_few_shot(code, answer, rationale): 30 | doc, code = py_extract_docstring(code) 31 | assert answer == "No" or answer == "Yes" 32 | prompt = f"""username_0: I have a function in Python and I'd like someone to check my description of this function. 33 | I'm doing this so that I can write a good docstring for this function. 34 | 35 | Here is the code for the function: 36 | ```py 37 | {code} 38 | ``` 39 | 40 | Here is my description of this program: 41 | ``` 42 | {doc} 43 | ``` 44 | 45 | Do not attempt to execute the function or to judge its correctness. 46 | Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function. 47 | Also, answer with "No" if the description does not match the function.username_1: Sure, no problem. I will be able to help. 48 | My answer is: {answer} 49 | 50 | {rationale} 51 | 52 | Upvotes: 200""" 53 | return prompt 54 | 55 | 56 | FEW_SHOTS = [ 57 | ( 58 | '''def simple_scan_network(): 59 | """ 60 | Do a simple network scan, which only works if your network configuration 61 | is 192.168.1.x 62 | """ 63 | base_ip = "192.168.1." 
64 | addresses = ['127.0.0.1'] 65 | 66 | for index in range(1, 255): 67 | addresses.extend([base_ip + str(index)]) 68 | 69 | return addresses''', 70 | "No", 71 | "The simple_scan_network function you have provided seems to generate addresses that then would be used for a network scan, but does not actually perform it, unlike the function claims.", 72 | ), 73 | ( 74 | '''import pandas 75 | 76 | 77 | def coerce_integer(df): 78 | """ 79 | Loop through the columns of a df, if it is numeric, 80 | convert it to integer and fill nans with zeros. 81 | This is somewhat heavy-handed in an attempt to force 82 | Esri to recognize sparse columns as integers. 83 | """ 84 | # Numeric columns to not coerce to integer 85 | EXCEPT = ["latitude", "longitude", "zipCode"] 86 | 87 | def numeric_column_to_int(series): 88 | return ( 89 | series.fillna(0).astype(int) 90 | if pandas.api.types.is_numeric_dtype(series) and series.name not in EXCEPT 91 | else series 92 | ) 93 | 94 | return df.transform(numeric_column_to_int, axis=0)''', 95 | "Yes", 96 | "The docstring does seem to match the implementation! The function loops through the columns of a df and coerces it as explained.", 97 | ), 98 | ('''def __trans_df_into_dict(data): 99 | """Converte DataFrame to dictionary. 100 | 101 | Args: 102 | data (pandas.DataFrame): DataFrame. 103 | 104 | Returns: 105 | dict: Name dictionary. 106 | """ 107 | data["en_name"] = data["en_name"].str.upper() 108 | data["en_name_f"] = data["en_name"].str.split(" ", expand=True)[0] 109 | data["en_name_l"] = data["en_name"].str.split(" ", expand=True)[1] 110 | data["jp_name_f"] = data["jp_name"].str.split("・", expand=True)[0] 111 | data["jp_name_l"] = data["jp_name"].str.split("・", expand=True)[1] 112 | fullname_dict = dict(zip(data["en_name"], data["jp_name"])) 113 | fname_dict = dict(zip(data["en_name_f"], data["jp_name_f"])) 114 | lname_dict = dict(zip(data["en_name_l"], data["jp_name_l"])) 115 | return fullname_dict, fname_dict, lname_dict''', 116 | "No", 117 | "The function__trans_df_into_dict does indeed convert a dataframe into a dictionary, however, it converts various columns that were not described in the docstring.\nFor instance, nowhere in the docstring it mentions handling japanese characters or the name of the column.", 118 | ), 119 | ( 120 | '''def inchesToMeters(inches): 121 | """Convert inches to meters.""" 122 | return inches * 0.0254''', 123 | "Yes", 124 | "inchesToMeters is a very simple function, the doccstring explains concisely its purpose, which is of converting inches to meters.", 125 | ), 126 | ('''def square_crop(im, target_size=None): 127 | """ Crop image to `target_size`. If that's None the image is squared 128 | to the smallest size 129 | """ 130 | 131 | w = im.size[0] 132 | h = im.size[1] 133 | 134 | target_size = target_size if target_size else min(w, h) 135 | 136 | dx = (w - target_size) / 2 137 | dy = (h - target_size) / 2 138 | 139 | return im.crop((dx, dy, dx + target_size, dy + target_size))''', 140 | "Yes", 141 | "Following the standard description for docstrings for functions and methods, the square_crop function description tells exactly what the function does." 
142 | ), 143 | ('''def _setup_motifs_files(args): 144 | """convenience fn, make sure setup is same across 145 | multiplicity/orientation/spacing workflows 146 | """ 147 | motifs_files = {} 148 | motifs_files["early"] = "{}/{}/ggr.scanmotifs.h5".format( 149 | args.inputs["inference"][args.cluster]["scanmotifs_dir"], 150 | args.inputs["inference"][args.cluster]["scanmotifs_early_dir"]) 151 | motifs_files["mid"] = "{}/{}/ggr.scanmotifs.h5".format( 152 | args.inputs["inference"][args.cluster]["scanmotifs_dir"], 153 | args.inputs["inference"][args.cluster]["scanmotifs_mid_dir"]) 154 | motifs_files["late"] = "{}/{}/ggr.scanmotifs.h5".format( 155 | args.inputs["inference"][args.cluster]["scanmotifs_dir"], 156 | args.inputs["inference"][args.cluster]["scanmotifs_late_dir"]) 157 | 158 | return motifs_files''', 159 | "No", 160 | "The docstring for _setup_motifs_files just says this is a convenience function. There is definitely not enough information to re-implement this function from the docstring alone.", 161 | ), 162 | ('''def trip(u, v): 163 | """ 164 | Returns the scalar triple product of vectors u and v and z axis. 165 | The convention is z dot (u cross v). Dotting with the z axis simplifies 166 | it to the z component of the u cross v 167 | The product is: 168 | positive if v is to the left of u, that is, 169 | the shortest right hand rotation from u to v is ccw 170 | negative if v is to the right of u, that is, 171 | the shortest right hand rotation from u to v is cw 172 | zero if v is colinear with u 173 | Essentially trip is the z component of the cross product of u x v 174 | """ 175 | return (u[0] * v[1] - u[1] * v[0])''', 176 | "Yes", 177 | "The docstring for the trip function is very detailed and describes the function's purpose and the mathematical formula used to calculate the scalar triple product.", 178 | ) 179 | ] 180 | 181 | 182 | def prompt_fmt(code): 183 | doc, code = py_extract_docstring(code) 184 | random.shuffle(FEW_SHOTS) 185 | buf = "" 186 | for few in FEW_SHOTS: 187 | buf += template_few_shot(*few) 188 | buf += f"""username_0: I have a function in Python and I'd like someone to check my description of this function. 189 | I'm doing this so that I can write a good docstring for this function. 190 | 191 | Here is the code for the function: 192 | ```py 193 | {code} 194 | ``` 195 | 196 | Here is my description of this program: 197 | ``` 198 | {doc} 199 | ``` 200 | 201 | Do not attempt to execute the function or to judge its correctness. 202 | Answer with "Yes" or "No" depending on if my description has enough information alone to re-implement the function. 203 | Also, answer with "No" if the description does not match the function. 204 | Upvotes: 100username_1: Sure, no problem. I will be able to help. 205 | My answer is:""" 206 | return buf 207 | 208 | 209 | def auto_dtype(): 210 | if torch.cuda.is_bf16_supported(): 211 | return "bfloat16" 212 | return "auto" 213 | 214 | 215 | def chunkify(lst, n): 216 | chunks = [] 217 | for i in range(0, len(lst), n): 218 | chunk = [] 219 | for j in range(n): 220 | if i + j < len(lst): 221 | chunk.append(lst[i + j]) 222 | chunks.append(chunk) 223 | return chunks 224 | 225 | 226 | dataset = datasets.load_dataset(args.dataset, split="train") 227 | print(f"Loaded {len(dataset)} examples. 
Running pre-filtering...") 228 | 229 | BAD_WORDS = ["todo", "fixme", "bug"] 230 | BAD_IMPORTS = ["argparse", "os", "subprocess", "sys", "setuptools", 231 | "distutils", "matplotlib", "seaborn"] 232 | BAD_IMPORTS = [f"import {b}" for b in BAD_IMPORTS] + \ 233 | [f"from {b}" for b in BAD_IMPORTS] 234 | BAD_SUBSTRINGS = BAD_WORDS + BAD_IMPORTS 235 | 236 | bench_filter = benchmark_data.filter_out() 237 | all_bench = bench_filter["human_eval_docstrings"] + \ 238 | bench_filter["human_eval_solutions"] + \ 239 | bench_filter["mbpp_docstrings"] + \ 240 | bench_filter["mbpp_solutions"] 241 | 242 | 243 | def pre_filtering(ex): 244 | code = ex[args.content_col] 245 | code_bytes = code.encode('utf-8') 246 | 247 | # filter out bad substrings 248 | lower = code.lower() 249 | for word in BAD_SUBSTRINGS: 250 | if word in lower: 251 | return False 252 | 253 | for b in all_bench: 254 | if b in code: # contaminated sample! 255 | return False 256 | 257 | # too many lines of code -- say 150 258 | lines = code.split("\n") 259 | if len(lines) > 150: 260 | return False 261 | 262 | # filter functions which don't have an argument 263 | # 1. find first def statement in lines 264 | # 2. check if contains (): 265 | for line in lines: 266 | if line.startswith("def ") and "():" in line: 267 | return False 268 | 269 | # filter out functions with no return statement 270 | parser = make_parser() 271 | if not does_have_return(code, parser=parser): 272 | return False 273 | 274 | try: 275 | tree = global_parser.parse(code_bytes) 276 | block, _ = FN_BLOCK_QUERY.captures(tree.root_node)[0] 277 | 278 | # get the docstring, filter if not a docstring 279 | exp = block.children[0] 280 | if not exp.type == 'expression_statement' and not exp.children[0].type == 'string': 281 | return False 282 | 283 | docstring = exp.children[0] 284 | docstring_text = docstring.text.decode('utf-8') 285 | if not docstring_text.startswith('"""') and not docstring_text.endswith('"""'): 286 | return False 287 | except Exception as e: 288 | print(f"Error in filtering: {e}") 289 | return False 290 | 291 | return True # all good! 
292 | 293 | 294 | threads = os.cpu_count() - 1 # type: ignore 295 | dataset = dataset.filter(pre_filtering, num_proc=threads) 296 | 297 | model = LLM(args.model, dtype=auto_dtype(), 298 | gpu_memory_utilization=0.95, tensor_parallel_size=args.num_gpus) 299 | tokenizer = model.get_tokenizer() 300 | 301 | if args.sample_size is not None: 302 | dataset = dataset.shuffle() 303 | dataset = dataset.select(range(args.sample_size)) 304 | 305 | 306 | print(f"Now running stage 3 filtering on {len(dataset)} examples...") 307 | 308 | 309 | def unindent(s): 310 | lines = s.splitlines() 311 | non_blank_lines = [line for line in lines if line.strip()] 312 | min_indent = min(len(line) - len(line.lstrip()) 313 | for line in non_blank_lines) if non_blank_lines else 0 314 | unindented_lines = [line[min_indent:] if len( 315 | line) >= min_indent else line for line in lines] 316 | return '\n'.join(unindented_lines) 317 | 318 | 319 | def py_extract_docstring(code): 320 | first_doc = code.find('"""') 321 | assert first_doc != -1 322 | first_doc = first_doc + 3 323 | second_doc = code[first_doc+1:].find('"""') 324 | assert second_doc != -1 325 | second_doc = second_doc + first_doc + 1 326 | doc = code[first_doc:second_doc] 327 | doc = unindent(doc).strip() 328 | code = code[:first_doc-3] + code[second_doc+3:] 329 | return doc, code 330 | 331 | 332 | # this is such a hack, but it works 333 | dummy = 'def dummy(): \n """\n """\n pass' 334 | dummy_prompt = prompt_fmt(dummy) 335 | few_shot_toks = len(tokenizer.encode( 336 | dummy_prompt)) - len(tokenizer.encode(dummy)) 337 | print(f"Few-shot prompt has {few_shot_toks} tokens") 338 | prompts = [] 339 | for ex in tqdm(dataset, total=len(dataset), desc="Generating prompts"): 340 | code = ex[args.content_col] 341 | toks = len(tokenizer.encode(code)) + few_shot_toks 342 | if toks > 16380: 343 | print(f"Skipping example with {toks} tokens") 344 | # to skip, just add dummy prompt 345 | prompts.append(dummy_prompt) 346 | continue 347 | p = prompt_fmt(code) 348 | prompts.append(p) 349 | 350 | responses = [] 351 | for chunk in tqdm(chunkify(prompts, args.batch_size), desc="Generating responses"): 352 | outs = model.generate(chunk, SamplingParams( 353 | temperature=0.0, stop="\n", max_tokens=5)) 354 | contents = [o.outputs[0].text for o in outs] 355 | for c in contents: 356 | yes_count = c.lower().count("yes") 357 | no_count = c.lower().count("no") 358 | if yes_count > no_count: 359 | responses.append(True) 360 | elif yes_count < no_count: 361 | responses.append(False) 362 | else: 363 | # default to No 364 | responses.append(False) 365 | 366 | 367 | new_ds = dataset.filter( # horrible hack! 368 | lambda ex, i: responses[i] and "def dummy()" not in ex[args.content_col], with_indices=True) 369 | print(f"Filtered {len(dataset) - len(new_ds)} examples") 370 | new_ds.push_to_hub(args.push, private=True) 371 | -------------------------------------------------------------------------------- /src/star_align/decontamination/find_substrings.py: -------------------------------------------------------------------------------- 1 | # type: ignore 2 | """Migrated from: https://github.com/bigcode-project/bigcode-dataset. 
License: Apache 2.0""" 3 | 4 | import argparse 5 | import json 6 | import os 7 | import shutil 8 | from copy import deepcopy 9 | from glob import glob 10 | from pathlib import Path 11 | 12 | from datasets import load_dataset 13 | 14 | from star_align.utils import write_jsonl 15 | 16 | from .benchmark_data import FILTER_OUT 17 | from .utils import add_dict, shard_dataset 18 | 19 | SHARD_SIZE = 1000 << 20 # 1GB 20 | LANGUAGE_COL = "lang" 21 | # LANGUAGES = ["Python", "Java", "JavaScript"] 22 | 23 | 24 | def dump_benchmarks(file_path: str): 25 | """ 26 | Dump the dictionary of benchmark samples that are filtered out 27 | """ 28 | with open(file_path, "w") as f: 29 | json.dump(FILTER_OUT, f, indent=2) 30 | 31 | 32 | def filter_reason_to_benchmark_name(filter_reason: str): 33 | assert filter_reason.endswith("_match") 34 | return filter_reason[:-6] 35 | 36 | 37 | def benchmark_name_to_filter_reason(benchmark_name: str): 38 | return f"{benchmark_name}_match" 39 | 40 | 41 | def update_benchmark_dict( 42 | filter_out: dict, benchmark_cache: str, excluded_data_cache: str 43 | ): 44 | """ 45 | Iterates on current benchmark-samples. If a sample is found in the cached benchmark-samples, it is removed (it does not need to be searched), 46 | and the corresponding data-samples from the cache are added to `exclude_data` 47 | 48 | Returns: 49 | - `updated`: an updated benchmark dict where samples from the cache are removed (they do not need to be searched anymore) 50 | - `exclude_data`: a list of files to remove from the dataset 51 | """ 52 | updated = deepcopy(filter_out) 53 | exclude_data = [] 54 | with open(benchmark_cache) as f: 55 | benchmark_cache = json.load(f) 56 | with open(excluded_data_cache) as f: 57 | excluded_data_cache = json.load(f) 58 | 59 | for bench, samples in filter_out.items(): 60 | for bench_sample in samples: 61 | # Benchmark-sample was found in cache 62 | if bench in benchmark_cache and bench_sample in benchmark_cache[bench]: 63 | # No need to search for this sample in the dataset 64 | updated[bench].remove(bench_sample) 65 | # Corresponding data-samples will be excluded from the dataset. 66 | exclude_data += [ 67 | data_sample 68 | for data_sample in excluded_data_cache 69 | if data_sample["filter_reason"] 70 | == benchmark_name_to_filter_reason(bench) 71 | and data_sample["matched_substring"] == bench_sample 72 | ] 73 | 74 | print("After loading cache, will search for:") 75 | for benchmark, values in updated.items(): 76 | print(f" num strings from {benchmark}: {len(values)}") 77 | # Remove empty benchmarks 78 | updated = {key: value for key, value in updated.items() if len(value) > 0} 79 | return updated, exclude_data 80 | 81 | 82 | def find_substrings(data, columns, filter_out, return_matched=False): 83 | """ 84 | filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be 85 | filtered-out. 86 | Return True, None if the file should be included in the dataset. 
87 | Otherwise return False and some metadata about the file excluded 88 | """ 89 | content = "\n\n".join([data[col].lower() for col in columns]) 90 | # For each substring, try to find it in the file (case insensitive) 91 | for benchmark, substrings in filter_out.items(): 92 | for substring in substrings: 93 | if substring.lower() in content: 94 | if return_matched: 95 | return False, benchmark_name_to_filter_reason(benchmark), substring 96 | else: 97 | return False, benchmark_name_to_filter_reason(benchmark) 98 | 99 | # Return True, None if none of the substrings was found 100 | if return_matched: 101 | return True, None, None 102 | else: 103 | return True, None 104 | 105 | 106 | def aggregate_meta(tmp_meta_dir: str): 107 | res = {} 108 | for file in glob(f"{tmp_meta_dir}/*-meta.json"): 109 | with open(file, "r") as f: 110 | meta = json.load(f) 111 | add_dict(res, meta) 112 | return res 113 | 114 | 115 | def concatenate_meta(tmp_meta_dir: str): 116 | res = [] 117 | for file in glob(f"{tmp_meta_dir}/*-excluded-data.json"): 118 | with open(file, "r") as f: 119 | meta = json.load(f) 120 | res += meta 121 | return res 122 | 123 | 124 | class Meta: 125 | def __init__(self) -> None: 126 | self.meta_dict = dict() 127 | 128 | def update(self, lang: str, filter_reason: str): 129 | if lang not in self.meta_dict: 130 | self.meta_dict[lang] = {} 131 | if filter_reason not in self.meta_dict[lang]: 132 | self.meta_dict[lang][filter_reason] = 0 133 | self.meta_dict[lang][filter_reason] += 1 134 | 135 | 136 | class SubstringFilterer(object): 137 | def __init__( 138 | self, 139 | output_dir: str, 140 | output_file: str, 141 | cached_decontamination_dir: str, 142 | split_languages: bool, 143 | cache_retrieval_key: str, 144 | columns: list[str], 145 | tmp_meta_dir=None, 146 | data_dir=None, 147 | ) -> None: 148 | self.output_dir = output_dir 149 | self.output_file = output_file 150 | self.split_languages = split_languages 151 | self.cache_retrieval_key = cache_retrieval_key 152 | self.columns = columns 153 | self.tmp_meta_dir = ( 154 | tmp_meta_dir if tmp_meta_dir is not None else f"{output_dir}/tmp/meta" 155 | ) 156 | self.data_dir = data_dir if data_dir is not None else f"{output_dir}/data" 157 | os.makedirs(self.tmp_meta_dir, exist_ok=True) 158 | os.makedirs(self.data_dir, exist_ok=True) 159 | # Save benchmark data 160 | self.excluded_data_cache = os.path.join(self.output_dir, "excluded-data.json") 161 | self.benchmarks_cache = os.path.join(output_dir, "benchmarks.json") 162 | dump_benchmarks(self.benchmarks_cache) 163 | 164 | if cached_decontamination_dir is not None: 165 | # Load cache 166 | self.filter_out, self.exclude_data = update_benchmark_dict( 167 | FILTER_OUT, 168 | os.path.join(cached_decontamination_dir, "benchmarks.json"), 169 | os.path.join(cached_decontamination_dir, "excluded-data.json"), 170 | ) 171 | # All hashes should be unique 172 | hash_list = [ 173 | data_sample["data"][self.cache_retrieval_key] 174 | for data_sample in self.exclude_data 175 | ] 176 | assert len(hash_list) == len(set(hash_list)) 177 | # dict: retrieval-key (hash/content) -> data-sample 178 | self.exclude_data_index = { 179 | data_sample["data"][self.cache_retrieval_key]: data_sample 180 | for data_sample in self.exclude_data 181 | } 182 | self.use_cached_decontamination = True 183 | else: 184 | self.filter_out = FILTER_OUT 185 | self.exclude_data = None 186 | self.exclude_data_index = {} 187 | self.use_cached_decontamination = False 188 | 189 | def _filter_file(self, sample): 190 | should_include, filter_reason, 
matched_substring = True, None, None 191 | if self.use_cached_decontamination: 192 | # According to cache, this data sample should be excluded 193 | if sample[self.cache_retrieval_key] in self.exclude_data_index: 194 | should_include = False 195 | filter_reason = self.exclude_data_index[ 196 | sample[self.cache_retrieval_key] 197 | ]["filter_reason"] 198 | matched_substring = self.exclude_data_index[ 199 | sample[self.cache_retrieval_key] 200 | ]["matched_substring"] 201 | # If sample has passed the cache, check the other substrings 202 | if should_include: 203 | should_include, filter_reason, matched_substring = find_substrings( 204 | sample, self.columns, self.filter_out, return_matched=True 205 | ) 206 | return should_include, filter_reason, matched_substring 207 | 208 | def _filter(self, batch: dict, idx): 209 | meta = Meta() 210 | excluded_data = [] 211 | features = batch.keys() 212 | res = {k: [] for k in features} 213 | for sample in zip(*[batch[k] for k in features]): 214 | sample = {k: v for k, v in zip(features, sample)} 215 | should_include, filter_reason, matched_substring = self._filter_file(sample) 216 | if not should_include: 217 | meta.update(sample.get(LANGUAGE_COL, "unknown"), filter_reason) 218 | excluded_data.append( 219 | { 220 | "data": sample, 221 | "filter_reason": filter_reason, 222 | "matched_substring": matched_substring, 223 | } 224 | ) 225 | else: 226 | # Add to output 227 | for k in features: 228 | res[k].append(sample[k]) 229 | 230 | # Record Meta 231 | with open( 232 | os.path.join(self.tmp_meta_dir, f"{idx[0]}-{idx[-1]}-meta.json"), "w" 233 | ) as f: 234 | json.dump(meta.meta_dict, f) 235 | with open( 236 | os.path.join(self.tmp_meta_dir, f"{idx[0]}-{idx[-1]}-excluded-data.json"), 237 | "w", 238 | ) as f: 239 | json.dump(excluded_data, f, indent=2) 240 | return res 241 | 242 | def filter_dataset(self, ds, num_proc, batch_size): 243 | filtered = ds.map( 244 | self._filter, 245 | batched=True, 246 | batch_size=batch_size, 247 | with_indices=True, 248 | num_proc=num_proc, 249 | load_from_cache_file=False, 250 | ) 251 | print("Number of samples in the new dataset: ", len(filtered)) 252 | return filtered 253 | 254 | def finalize(self): 255 | # Dump meta 256 | meta = aggregate_meta(self.tmp_meta_dir) 257 | print(meta) 258 | with open(os.path.join(self.output_dir, "meta.json"), "w") as f: 259 | json.dump(meta, f, indent=2) 260 | # Dump excluded-data.json 261 | meta = concatenate_meta(self.tmp_meta_dir) 262 | print("Number of excluded examples: ", len(meta)) 263 | with open(self.excluded_data_cache, "w") as f: 264 | json.dump(meta, f, indent=2) 265 | # delete temporary meta data 266 | shutil.rmtree(self.tmp_meta_dir) 267 | 268 | # def save(self, filtered, num_proc): 269 | # # Save shards 270 | # if self.split_languages: 271 | # for lang in LANGUAGES: 272 | # print(f"Sharding subset: {lang}") 273 | # target_dir = os.path.join(self.data_dir, lang.lower()) 274 | # os.makedirs(target_dir, exist_ok=True) 275 | # subset = filtered.filter(lambda example: example[LANGUAGE_COL] == lang, num_proc=num_proc) 276 | # shard_dataset(subset, SHARD_SIZE, target_dir, num_proc=16) 277 | # else: 278 | # shard_dataset(filtered, SHARD_SIZE, self.data_dir, num_proc=16) 279 | 280 | def run(self, dataset, num_proc, batch_size): 281 | filtered = self.filter_dataset(dataset, num_proc, batch_size) 282 | write_jsonl(Path(self.output_file), filtered) 283 | # Finalize meta-data 284 | self.finalize() 285 | # Save filtered dataset. 
286 | # NOTE: we save to jsonl so this is not needed 287 | # self.save(filtered, num_proc) 288 | return filtered 289 | 290 | 291 | def arguments(): 292 | parser = argparse.ArgumentParser() 293 | parser.add_argument( 294 | "--dataset_name", 295 | default="json", 296 | type=str, 297 | help="Name or path of the HF dataset to decontaminate", 298 | ) 299 | parser.add_argument("--data_files", nargs="+", default=None, help="Data files") 300 | parser.add_argument( 301 | "--columns", 302 | nargs="+", 303 | required=True, 304 | help="Columns to form the text to search for", 305 | ) 306 | parser.add_argument( 307 | "--output_file", required=True, type=str, help="Path to save output jsonl data" 308 | ) 309 | parser.add_argument( 310 | "--output_dir", 311 | required=True, 312 | type=str, 313 | help="Path to save output data and metadata", 314 | ) 315 | parser.add_argument("--num_proc", type=int, default=200, help="Number of processes") 316 | parser.add_argument( 317 | "--batch_size", 318 | type=int, 319 | default=10000, 320 | help="Size of batches passed to Dataset.map", 321 | ) 322 | parser.add_argument( 323 | "--cached_decontamination_dir", 324 | type=str, 325 | default=None, 326 | help="Directory containing a `benchmarks.json` and `excluded_data.json` files from a previous decontamination run." 327 | "Will use this data to avoid searching again for strings that were previously decontaminated." 328 | "It's up to the user to ensure that the dataset being decontaminated is a subset of the one from the cached decontamination run" 329 | "(Otherwise not all the benchmark samples will be checked against new data samples)", 330 | ) 331 | parser.add_argument( 332 | "--cache_retrieval_key", 333 | type=str, 334 | default="hexsha", 335 | help="Key used to retrieve examples from the cache. Ideally `hexsha`. 
Otherwise, another unique feature in case the hash is not present, like `content`)", 336 | ) 337 | parser.add_argument( 338 | "--split_languages", 339 | action="store_true", 340 | help="If True, will create one subfolder per language for the output dataset.", 341 | ) 342 | return parser.parse_args() 343 | 344 | 345 | def main(): 346 | args = arguments() 347 | 348 | filterer = SubstringFilterer( 349 | output_dir=args.output_dir, 350 | output_file=args.output_file, 351 | columns=args.columns, 352 | cached_decontamination_dir=args.cached_decontamination_dir, 353 | split_languages=args.split_languages, 354 | cache_retrieval_key=args.cache_retrieval_key, 355 | ) 356 | 357 | ds = load_dataset( 358 | args.dataset_name, 359 | split="train", 360 | data_files=args.data_files, 361 | # chunksize=40 << 20 362 | ) 363 | 364 | filterer.run(ds, args.num_proc, args.batch_size) 365 | 366 | 367 | if __name__ == "__main__": 368 | main() 369 | -------------------------------------------------------------------------------- /src/star_align/minhash_dedup.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import random 4 | import multiprocessing as mp 5 | import os 6 | import re 7 | from collections import defaultdict 8 | from typing import Any, Callable, List 9 | 10 | import click 11 | import datasets 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | import pickle # nosec 16 | from collections import Counter 17 | from pathlib import Path 18 | from itertools import tee 19 | 20 | from scipy.integrate import quad as integrate 21 | 22 | import hashlib 23 | import struct 24 | from hashlib import md5 25 | from hashlib import sha256 26 | 27 | import xxhash 28 | from xxhash import xxh3_64 29 | from xxhash import xxh3_64_digest 30 | from xxhash import xxh3_128 31 | from xxhash import xxh3_128_digest 32 | 33 | 34 | parser = argparse.ArgumentParser() 35 | # IO Args 36 | parser.add_argument("--data_files", type=str, required=True) 37 | parser.add_argument("--output", type=str, required=True) 38 | parser.add_argument("--num_proc", type=int, default=os.cpu_count()) 39 | # Meta Args 40 | parser.add_argument("--column", type=str, required=True) 41 | parser.add_argument("--batch_size", type=int, default=10_000) 42 | # MinHash Args 43 | parser.add_argument("--ngram", type=int, default=5) 44 | parser.add_argument("--min_length", type=int, default=5) 45 | parser.add_argument("--ignore_empty", type=bool, default=False) 46 | parser.add_argument("--seed", type=int, default=42) 47 | parser.add_argument("--num_perm", type=int, default=250) 48 | parser.add_argument("--threshold", type=float, default=0.7) 49 | parser.add_argument("--b", type=int, default=None) 50 | parser.add_argument("--r", type=int, default=None) 51 | parser.add_argument("--hash_func", type=str, default="sha1") 52 | parser.add_argument("--hash_bits", type=int, default=64) 53 | args = parser.parse_args() 54 | 55 | 56 | def ngrams(sequence: List[str], n: int, min_length: int = 5): 57 | """ 58 | Return the ngrams generated from a sequence of items, as an iterator. 59 | 60 | This is a modified version of nltk.util.ngrams. 
61 | """ 62 | if len(sequence) < min_length: 63 | return [] 64 | if len(sequence) < n: 65 | return [tuple(sequence)] 66 | iterables = tee(iter(sequence), n) 67 | for i, sub_iterable in enumerate(iterables): 68 | for _ in range(i): 69 | next(sub_iterable, None) 70 | return zip(*iterables) 71 | 72 | 73 | class UnionFind: 74 | """ 75 | A data structure for maintaining disjoint sets. This helps build connected components for given duplicate pairs. 76 | """ 77 | 78 | def __init__(self): 79 | self.parent = {} 80 | # Counter is a subclass of dict with slightly different python and c implementations 81 | # you can think of it as an optimized defaultdict(int) 82 | self.rank = Counter() 83 | 84 | def find(self, x): 85 | try: 86 | # path compression 87 | if self.parent[x] != x: 88 | self.parent[x] = self.find(self.parent[x]) 89 | except KeyError: 90 | # KeyError happens if x not in parent 91 | self.parent[x] = x 92 | finally: 93 | return self.parent[x] 94 | 95 | def union(self, x, y): 96 | px = self.find(x) 97 | py = self.find(y) 98 | 99 | # If both elements are already in the same set, do nothing 100 | # The line in original UnionFind `self.parent[px] = self.parent[py] = min(px, py)` is redundant when px == py 101 | if px == py: 102 | return 103 | 104 | if self.rank[px] == self.rank[py]: 105 | # If ranks are equal, choose one as the new root and increment its rank 106 | # with few duplicates this is likely to be the most common case 107 | self.parent[py] = px 108 | self.rank[px] += 1 109 | # otherwise, assume that leftside is more likely to be higher rank 110 | # Attach the smaller rank tree under the root of the larger rank tree 111 | elif self.rank[px] > self.rank[py]: 112 | self.parent[py] = px 113 | else: 114 | self.parent[px] = py 115 | 116 | def reset(self): 117 | self.parent = {} 118 | self.rank = Counter() 119 | 120 | def dump(self, path: str | Path, id2id=None): 121 | if id2id is not None: 122 | new_uf = UnionFind() 123 | for i in self.parent: 124 | new_uf.union(id2id[i], id2id[self.find(i)]) 125 | else: 126 | new_uf = self 127 | 128 | with open(path, "wb") as f: 129 | pickle.dump(new_uf, f, protocol=pickle.HIGHEST_PROTOCOL) 130 | 131 | 132 | RNG = np.random.RandomState(args.seed) 133 | NON_ALPHA = re.compile(r"\W", re.UNICODE) 134 | datasets.logging.set_verbosity_error() 135 | 136 | SIGNATURE_COLUMN = "__signatures__" 137 | INDEX_COLUMN = "__index__" 138 | CLUSTER_COLUMN = "__cluster__" 139 | 140 | # for is originally used to reduce memory usage in MacOS but also ensures that the Union Find data structure 141 | # is not copied to child processes as long as it is not modified. 142 | mp.set_start_method("fork", force=True) 143 | uf = UnionFind() 144 | 145 | 146 | def sha1_hash(data: bytes, d: int = 32) -> int: 147 | """ 148 | Generate a d-bit hash value from the given data. 149 | """ 150 | if d == 32: 151 | return struct.unpack( 152 | " int: 165 | """ 166 | Generate a 16-bit xxhash based hash value from the given data. 167 | As of python xxhash 3.3.0 (and since 0.3.0) outputs in big-endian. 168 | This is useful as a special purpose xxhash when you only want 16 bits. 169 | bit masked xxh3_64 hashes are faster than xxh32 in modern systems. 170 | """ 171 | return xxhash.xxh3_64_intdigest(data, seed) & 0xFFFF 172 | 173 | 174 | def xxh3_32hash(data: bytes, seed: int = 0) -> int: 175 | """ 176 | Generate a 32-bit xxhash based hash value from the given data. 177 | As of python xxhash 3.3.0 (and since 0.3.0) outputs in big-endian. 
178 | This is useful as a special purpose xxhash when you only want 32bits. 179 | bit masked xxh3_64 hashes are faster than xxh32 in modern systems. 180 | """ 181 | return xxhash.xxh3_64_intdigest(data, seed) & 0xFFFFFFFF 182 | 183 | 184 | def optimal_param( 185 | threshold: float, 186 | num_perm: int, 187 | false_positive_weight: float = 0.5, 188 | false_negative_weight: float = 0.5, 189 | ): 190 | """ 191 | Compute the optimal `MinHashLSH` parameter that minimizes the weighted sum 192 | of probabilities of false positive and false negative, taken from datasketch. 193 | """ 194 | 195 | def false_positive_area(threshold: float, b: int, r: int): 196 | """Source: `datasketch.lsh`""" 197 | 198 | def proba(s): 199 | return 1 - (1 - s ** float(r)) ** float(b) 200 | 201 | a, _ = integrate(proba, 0.0, threshold) 202 | return a 203 | 204 | def false_negative_area(threshold: float, b: int, r: int): 205 | """Source: `datasketch.lsh`""" 206 | 207 | def proba(s): 208 | return 1 - (1 - (1 - s ** float(r)) ** float(b)) 209 | 210 | a, _ = integrate(proba, threshold, 1.0) 211 | return a 212 | 213 | min_error = float("inf") 214 | opt = (0, 0) 215 | for b in range(1, num_perm + 1): 216 | max_r = int(num_perm / b) 217 | for r in range(1, max_r + 1): 218 | fp = false_positive_area(threshold, b, r) 219 | fn = false_negative_area(threshold, b, r) 220 | error = fp * false_positive_weight + fn * false_negative_weight 221 | if error < min_error: 222 | min_error = error 223 | opt = (b, r) 224 | return opt 225 | 226 | 227 | def embed_func( 228 | content: str, 229 | idx: int, 230 | *, 231 | num_perm: int, 232 | ngram_size: int, 233 | min_length: int, 234 | hashranges: list[tuple[int, int]], 235 | permutations: np.ndarray, 236 | hash_func: Callable, 237 | dtype: type, 238 | max_hash: np.uint, 239 | modulo_prime: np.uint, 240 | ) -> dict[str, Any]: 241 | """ 242 | Calculate hash values for the content. 243 | """ 244 | # a, b are each np.ndarray arrays containing {num_perm} pairs of random numbers used for building new hashes 245 | # the formula is a * x(base hash of each shingle) + b 246 | a, b = permutations 247 | # split content on whitespace (NON_ALPHA regex), tokenize with ngrams(), and join these n-grams into a single space separated string. 248 | # we then convert to lower case and then bytestrings which is then hashed. Only unique hashed n-grams are left. 249 | tokens: set[bytes] = { 250 | bytes(" ".join(t).lower(), "utf-8") 251 | for t in ngrams(NON_ALPHA.split(content.lower()), ngram_size, min_length) 252 | } 253 | 254 | hashvalues: np.ndarray = np.array( 255 | [hash_func(token) for token in tokens], dtype=dtype 256 | ).reshape(len(tokens), 1) 257 | # Permute the hash values to produce new universal hashes 258 | # Element-wise multiplication with 'hashvalues' and a (non 0 random value) and then adding b 259 | # Then, take modulo 'MODULO_PRIME' and bitwise_and with 'MAX_HASH' to keep only the necessary bits. 260 | hashvalues = (hashvalues * a + b) % modulo_prime & max_hash 261 | # this part is where the name "min" of minhash comes from 262 | # this stacks all the hashes and then takes the minimum from each column 263 | masks: np.ndarray = np.full(shape=num_perm, dtype=dtype, fill_value=max_hash) 264 | hashvalues = np.vstack([hashvalues, masks]).min(axis=0) 265 | # Originally, byteswap was done for speed. Testing show it has a negligible impact 266 | # keeping for backward compatibility, even though theoretically and empirically 267 | # it doesnt matter if it is there or not. 
github.com/ekzhu/datasketch/issues/114 268 | Hs: list[bytes] = [ 269 | bytes(hashvalues[start:end].byteswap().data) for start, end in hashranges 270 | ] 271 | return {SIGNATURE_COLUMN: Hs, INDEX_COLUMN: idx} 272 | 273 | 274 | def main(): 275 | global uf 276 | uf.reset() 277 | HASH_BITS: int = args.hash_bits 278 | HASH_CONFIG: dict[int, tuple[type, Any, Any]] = { 279 | 64: (np.uint64, np.uint32((1 << 32) - 1), np.uint64((1 << 61) - 1)), 280 | # 32, 16 bit config does not use a mersenne prime. 281 | # The original reason for using mersenne prime was speed. 282 | # Testing reveals, there is no benefit to using a 2^61 mersenne prime for division 283 | 32: (np.uint32, np.uint32((1 << 32) - 1), np.uint32((1 << 32) - 5)), 284 | 16: (np.uint16, np.uint16((1 << 16) - 1), np.uint16((1 << 16) - 15)), 285 | } 286 | DTYPE, MAX_HASH, MODULO_PRIME = HASH_CONFIG.get(HASH_BITS, HASH_CONFIG[64]) 287 | 288 | match args.hash_func: 289 | case "sha1": 290 | 291 | def hash_func(byte_data): 292 | return sha1_hash(byte_data, d=min(HASH_BITS, 32)) 293 | 294 | case "xxh3": 295 | if HASH_BITS == 16: 296 | hash_func = xxh3_16hash 297 | else: 298 | hash_func = xxh3_32hash 299 | 300 | if args.b is not None and args.r is not None: 301 | B, R = args.b, args.r 302 | else: 303 | # Compute the optimal `MinHashLSH` parameter that minimizes the weighted sum 304 | # of probabilities of false positive and false negative, taken from datasketch. 305 | B, R = optimal_param( 306 | args.threshold, 307 | args.num_perm, 308 | false_positive_weight=0.5, 309 | false_negative_weight=0.5, 310 | ) 311 | 312 | HASH_RANGES = [(i * R, (i + 1) * R) for i in range(B)] 313 | HASH_TABLES = [defaultdict(set) for _ in range(B)] 314 | 315 | PERMUTATIONS = ( 316 | RNG.randint( 317 | 1, MODULO_PRIME, size=(args.num_perm,), dtype=DTYPE 318 | ), # a is a multiplier so should not be 0 319 | RNG.randint(0, MODULO_PRIME, size=(args.num_perm,), dtype=DTYPE), # b 320 | ) 321 | 322 | # Loading 323 | data_files_list = [x.strip() for x in args.data_files.split(",")] 324 | ds = datasets.load_dataset("json", data_files=data_files_list, split="train") 325 | ds = ds.map( 326 | lambda x, i: {INDEX_COLUMN: i}, with_indices=True, num_proc=args.num_proc 327 | ) 328 | 329 | if args.ignore_empty: 330 | ds_rest = ds.filter(lambda x: len(x[args.column].strip()) == 0) 331 | ds = ds.filter(lambda x: len(x[args.column].strip()) > 0) 332 | 333 | ds = ds.filter( 334 | lambda x: len(NON_ALPHA.split(x[args.column].lower())) >= args.min_length, 335 | num_proc=args.num_proc, 336 | ) 337 | 338 | LEN_DATASET = len(ds) 339 | if args.ignore_empty: 340 | LEN_DATASET += len(ds_rest) 341 | 342 | # MinHashing 343 | embedded = ds.map( 344 | function=embed_func, 345 | fn_kwargs={ 346 | "num_perm": args.num_perm, 347 | "hashranges": HASH_RANGES, 348 | "ngram_size": args.ngram, 349 | "min_length": args.min_length, 350 | "permutations": PERMUTATIONS, 351 | "hash_func": hash_func, 352 | "dtype": DTYPE, 353 | "max_hash": MAX_HASH, 354 | "modulo_prime": MODULO_PRIME, 355 | }, 356 | input_columns=[args.column, INDEX_COLUMN], 357 | remove_columns=[col for col in ds.column_names if col != INDEX_COLUMN], 358 | num_proc=args.num_proc, 359 | with_indices=False, 360 | desc="Fingerprinting...", 361 | ) 362 | LEN_EMBEDDED = len(embedded) 363 | NUM_SHARDS = np.ceil(LEN_EMBEDDED / args.batch_size).astype(int) 364 | 365 | # Clustering 366 | edges = [] 367 | for i in tqdm( 368 | range(0, NUM_SHARDS), 369 | dynamic_ncols=True, 370 | desc="Iterating MinHashes...", # noqa: E501 371 | ): 372 | embedded_shard = 
embedded.shard( 373 | num_shards=NUM_SHARDS, 374 | index=i, 375 | contiguous=True, 376 | writer_batch_size=args.batch_size, 377 | ) 378 | for key, Hs in zip( 379 | embedded_shard[INDEX_COLUMN], embedded_shard[SIGNATURE_COLUMN] 380 | ): 381 | for i, H in enumerate(Hs): 382 | HASH_TABLES[i][H].add(key) 383 | 384 | print(f"Number of clusters: {len(HASH_TABLES)}") 385 | for table in tqdm(HASH_TABLES, dynamic_ncols=True, desc="Clustering..."): 386 | # cluster: Set[int] 387 | for cluster in table.values(): 388 | if len(cluster) <= 1: 389 | continue 390 | idx = min(cluster) 391 | for x in cluster: 392 | edges.append((x, idx)) 393 | uf.union(x, idx) 394 | print(f"Number of edges: {len(set(edges))}") 395 | 396 | # Filtering 397 | ds = ds.map( 398 | function=lambda record: {CLUSTER_COLUMN: uf.find(record[INDEX_COLUMN])}, 399 | with_indices=False, 400 | num_proc=args.num_proc, 401 | new_fingerprint=str(random.getrandbits(128)), 402 | desc="Finding clusters...", 403 | ) 404 | # This is where the deduplication happens 405 | # Since there is no easy groupby in datasets 406 | # I will use this simple filter for now 407 | final_data = ds.filter( 408 | function=lambda record: record[CLUSTER_COLUMN] == record[INDEX_COLUMN], 409 | with_indices=False, 410 | num_proc=args.num_proc, 411 | desc="Filtering clusters...", 412 | ) 413 | if args.ignore_empty and len(ds_rest) > 0: 414 | final_data = datasets.concatenate_datasets([ds_rest, final_data]) 415 | 416 | # Saving 417 | final_data = final_data.remove_columns([CLUSTER_COLUMN, INDEX_COLUMN]) 418 | final_data.to_json(args.output) 419 | print("Before:", LEN_DATASET) 420 | print("After:", len(final_data)) 421 | 422 | # Cleaning 423 | ds.cleanup_cache_files() 424 | final_data.cleanup_cache_files() 425 | 426 | 427 | if __name__ == "__main__": 428 | main() 429 | -------------------------------------------------------------------------------- /src/star_align/sanitize_data.py: -------------------------------------------------------------------------------- 1 | """Deduplication, filtering, and selection""" 2 | 3 | import random 4 | import os 5 | import ast 6 | import re 7 | import warnings 8 | from dataclasses import dataclass, field 9 | from pathlib import Path 10 | from typing import cast, Literal 11 | from datasets import load_dataset, Dataset 12 | from tqdm.auto import tqdm 13 | from transformers import HfArgumentParser 14 | 15 | from star_align.utils import find_code_blocks, write_jsonl, find_codeblock_indices 16 | 17 | LLAMA3 = os.getenv("LLAMA3") is not None 18 | if LLAMA3: 19 | print("LLAMA3 mode activated") 20 | 21 | 22 | @dataclass(frozen=True) 23 | class Args: 24 | data_files: list[str] 25 | output_file: str 26 | shuffle: bool = field(default=True) 27 | remove_strange: bool = field(default=True) 28 | parse_raw_response: bool = field(default=True) 29 | passing_only: bool = field(default=True) 30 | data_augmentation: bool = field(default=False) 31 | exact_match_dedup: bool = field(default=True) 32 | get_code_representation: bool = field(default=True) 33 | remove_comments_docstrings: bool = field(default=False) 34 | include_left_failed: bool = field(default=False) 35 | n_cores: int = field(default=os.cpu_count() or 1) 36 | diversify_func_names: bool = field(default=True) 37 | align_with: list[str] = field(default_factory=list) 38 | priority: Literal["passed", "failed", "none"] = field(default="none") 39 | seed: int = field(default=6666) 40 | 41 | 42 | def extract_and_concat_function_names(python_content): 43 | """ 44 | Extracts all function names from a given 
Python content string and concatenates them into a single string. 45 | 46 | Parameters: 47 | - python_content: A string containing the Python code to analyze. 48 | 49 | Returns: 50 | - A string containing all function names defined in the content, concatenated. 51 | """ 52 | tree = ast.parse(python_content) 53 | function_names = [] 54 | class_names = [] 55 | 56 | # Define a node visitor that adds the name of each function definition it visits 57 | class FuncClassDefVisitor(ast.NodeVisitor): 58 | def visit_ClassDef(self, node: ast.ClassDef): 59 | class_names.append(node.name) 60 | self.generic_visit(node) 61 | 62 | def visit_FunctionDef(self, node): 63 | function_names.append(node.name) 64 | # Process the subtree for this node 65 | self.generic_visit(node) 66 | 67 | def visit_AsyncFunctionDef(self, node): 68 | function_names.append(node.name) 69 | self.generic_visit(node) 70 | 71 | # Create a node visitor and walk through the AST 72 | visitor = FuncClassDefVisitor() 73 | visitor.visit(tree) 74 | 75 | def compress_name(name: str) -> str: 76 | return name.replace("_", "").lower() 77 | 78 | return frozenset(map(compress_name, function_names)), frozenset( 79 | map(compress_name, class_names) 80 | ) 81 | 82 | 83 | INCOMPLETE_SUBSTRINGS = [ 84 | "todo", 85 | "fixme", 86 | "write your code here", 87 | "your code here", 88 | "your code goes here", 89 | "notimplemented", 90 | ] 91 | 92 | RESPONSE_TEST_SPLIT = "\n\n" 93 | # special handling for llama3 since it has more examples not following the format 94 | LLAMA3_DEFAULT_TEST_SPLIT = r"### Tests \d\n" 95 | LLAMA3_ADDITIONAL_PATTERNS = [ 96 | "We can verify the functionality", 97 | "We can verify the correctness", 98 | "You can verify the correctness", 99 | "You can verify the functionality", 100 | "To ensure the correctness", 101 | "To verify the correctness", 102 | "To test the", 103 | "To test this", 104 | "To test this", 105 | "You can test the", 106 | "We can test the", 107 | "We can test this", 108 | "Now, we'll test", 109 | ] 110 | 111 | 112 | def split_llama3_response_tests(response: str) -> list[str]: 113 | splits = re.split(LLAMA3_DEFAULT_TEST_SPLIT, response) 114 | if len(splits) > 2: 115 | return [] 116 | if len(splits) == 2: 117 | return splits 118 | for pattern in LLAMA3_ADDITIONAL_PATTERNS: 119 | index = response.find(pattern) 120 | if index != -1: 121 | return [response[:index], response[index:]] 122 | return [] 123 | 124 | 125 | def preprocess_and_filter(x: dict) -> dict: 126 | """Filter out responses with wrong format""" 127 | 128 | def wrong_format(x: dict) -> dict: 129 | return {k: v for k, v in x.items()} | dict(wrong_format=True, tests="") 130 | 131 | response: str = x["response"] 132 | if not LLAMA3 and RESPONSE_TEST_SPLIT not in response: 133 | return wrong_format(x) 134 | if any(substring in response.lower() for substring in INCOMPLETE_SUBSTRINGS): 135 | return wrong_format(x) 136 | if LLAMA3: 137 | splits = split_llama3_response_tests(response) 138 | else: 139 | splits = response.split(RESPONSE_TEST_SPLIT) 140 | if len(splits) != 2: 141 | return wrong_format(x) 142 | response, tests = cast(tuple[str, str], tuple(map(str.strip, splits))) 143 | response_codeblocks = find_code_blocks(response, "python") 144 | tests_codeblocks = find_code_blocks(tests, "python") 145 | if len(response_codeblocks) == 0 or len(tests_codeblocks) == 0: 146 | return wrong_format(x) 147 | 148 | tests_content = "\n".join(tests_codeblocks) 149 | if "assert" not in tests or all( 150 | l.startswith("def") 151 | or l.startswith("class") 152 | or 
l.startswith("import") 153 | or l.startswith("from") 154 | for l in tests_content.splitlines() 155 | if len(l) > 0 and l[0].isalpha() 156 | ): 157 | return wrong_format(x) 158 | 159 | newx = {k: v for k, v in x.items() if k != "response"} | dict( 160 | response=response, tests=tests, wrong_format=False 161 | ) 162 | return newx 163 | 164 | 165 | def augment_data(x: dict, index: int) -> dict: 166 | random.seed(index) 167 | tests_content = "\n".join(find_code_blocks(x["tests"])) 168 | lines = tests_content.splitlines() 169 | if all(l.startswith("assert") for l in lines): 170 | ks = [1, 2, 3, 4, 5] 171 | assertions = random.sample(lines, k=min(random.choice(ks), len(lines))) 172 | assertion = "\n".join(assertions) 173 | assertion_term = "assertion" + ("s" if len(assertions) > 1 else "") 174 | else: 175 | assertion = tests_content 176 | assertion_term = "test case" 177 | if ( 178 | "assert" in assertion 179 | # 5 lines augmented block max 180 | and len(assertion.splitlines()) <= 5 181 | and random.random() < 0.5 182 | and "assert" not in x["instruction"] 183 | and "for example" not in x["instruction"].lower() 184 | and "test" not in x["instruction"].lower() 185 | ): 186 | assert "assert" in assertion 187 | assertion_str = ( 188 | f"Your code should pass the following {assertion_term}:\n```python\n" 189 | + assertion.strip() 190 | + "\n```" 191 | ) 192 | new_instruction = f"{x['instruction']}\n\n{assertion_str}" 193 | newx = {k: v for k, v in x.items()} | dict(instruction=new_instruction) 194 | return newx 195 | return x 196 | 197 | 198 | # raw response -> response + test 199 | # response/test -> passing (opt: passing) 200 | # (not)passing -> unique 201 | # unique -> aug / minihash / selection / educational -> final 202 | 203 | 204 | def remove_ast_docstrings(tree): 205 | # ref: https://gist.github.com/phpdude/1ae6f19de213d66286c8183e9e3b9ec1 206 | for node in ast.walk(tree): 207 | # let's work only on functions & classes definitions 208 | if not isinstance(node, (ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef)): 209 | continue 210 | if len(node.body) == 0: 211 | continue 212 | if not isinstance(node.body[0], ast.Expr): 213 | continue 214 | if ( 215 | not hasattr(node.body[0], "value") 216 | or not isinstance(node.body[0].value, ast.Str) 217 | # or not isinstance(node.body[0].value.value, str) 218 | ): 219 | continue 220 | node.body = node.body[1:] # type: ignore 221 | return tree 222 | 223 | 224 | def remove_comments_from_code_blocks( 225 | content: str, 226 | ) -> str: 227 | code_blocks = find_codeblock_indices(content) 228 | # Current index in the original content for tracking purposes 229 | current_index = 0 230 | # Buffer to store the new content 231 | new_content: list[str] = [] 232 | # Iterate over each code block 233 | for start, end in code_blocks: 234 | # Append the content before this code block 235 | new_content.append(content[current_index:start]) 236 | 237 | # Extract the code block content 238 | code_block_content = content[start:end] 239 | 240 | # Split into lines, process, and rejoin 241 | modified_block_content = remove_comments(code_block_content) 242 | 243 | new_content.append(modified_block_content) 244 | 245 | # Update current index 246 | current_index = end 247 | 248 | # Add the remaining part of the original content after the last code block 249 | new_content.append(content[current_index:]) 250 | 251 | # Join all parts to form the final modified content 252 | return "".join(new_content) 253 | 254 | 255 | def remove_comments(code: str) -> str: 256 | """Remove 
comments and docstrings using AST""" 257 | tree = ast.parse(code) 258 | tree = remove_ast_docstrings(tree) 259 | return ast.unparse(tree) 260 | 261 | 262 | def get_code_representation(response: str) -> str: 263 | """Keep classes and functions, removing comments and docstrings""" 264 | raw_code = "\n".join(find_code_blocks(response)) 265 | 266 | tree = ast.parse(raw_code) 267 | 268 | class ClassFunctionTransformer(ast.NodeTransformer): 269 | def visit_Module(self, node): 270 | # Visit all children nodes of the module 271 | node = self.generic_visit(node) 272 | # Filter out only function and class definitions 273 | node.body = [ 274 | n for n in node.body if isinstance(n, (ast.FunctionDef, ast.ClassDef)) 275 | ] 276 | return node 277 | 278 | visitor = ClassFunctionTransformer() 279 | tree = visitor.visit(tree) 280 | tree = remove_ast_docstrings(tree) 281 | return ast.unparse(tree) 282 | 283 | 284 | def map_code_representation(x: dict) -> dict: 285 | try: 286 | representation = get_code_representation(x["response"]) 287 | except SyntaxError: 288 | representation = "" 289 | return {k: v for k, v in x.items()} | dict(code_representation=representation) 290 | 291 | 292 | # def concat_list(lists: list[list]) -> list: 293 | # return [item for sublist in lists for item in sublist] 294 | 295 | 296 | def map_examples_batched(examples: dict, map_one) -> dict: 297 | all_keys = list(examples.keys()) 298 | list_of_examples = [ 299 | {k: examples[k][i] for k in all_keys} for i in range(len(examples[all_keys[0]])) 300 | ] 301 | results = [map_one(example) for example in list_of_examples] 302 | result_dict = {k: [result[k] for result in results] for k in results[0].keys()} 303 | return result_dict 304 | 305 | 306 | def map_remove_comments(x: dict) -> dict: 307 | try: 308 | response = x["response"] 309 | except SyntaxError: 310 | response = "" 311 | return {k: v for k, v in x.items() if k != "response"} | dict(response=response) 312 | 313 | 314 | def main(): 315 | args = cast(Args, HfArgumentParser(Args).parse_args_into_dataclasses()[0]) 316 | 317 | raw_data = load_dataset("json", data_files=args.data_files, split="train") 318 | if args.align_with: 319 | ref_data = load_dataset("json", data_files=args.align_with, split="train") 320 | ref_data_instructions = set(map(lambda x: x["instruction"], ref_data)) 321 | raw_data = raw_data.filter( 322 | lambda x: x["instruction"] in ref_data_instructions, num_proc=args.n_cores 323 | ) 324 | print("Raw samples:", len(raw_data)) 325 | 326 | if args.parse_raw_response: 327 | raw_data = raw_data.map( 328 | map_examples_batched, 329 | fn_kwargs=dict(map_one=preprocess_and_filter), 330 | batched=True, 331 | num_proc=args.n_cores, 332 | ) 333 | raw_data = raw_data.filter( 334 | lambda x: not x["wrong_format"], num_proc=args.n_cores 335 | ) 336 | raw_data = raw_data.remove_columns(["wrong_format"]) 337 | print("Correct format:", len(raw_data)) 338 | 339 | if args.include_left_failed: 340 | failed_data = raw_data.filter(lambda x: not x["pass"], num_proc=args.n_cores) 341 | 342 | if args.passing_only: 343 | raw_data = raw_data.filter(lambda x: x["pass"], num_proc=args.n_cores) 344 | print("Passing only:", len(raw_data)) 345 | 346 | if args.shuffle: 347 | raw_data = raw_data.shuffle(seed=args.seed) 348 | if args.include_left_failed: 349 | failed_data = failed_data.shuffle(seed=args.seed) 350 | 351 | if args.priority != "none": 352 | # Sort the examples such that failed/passed are at first 353 | raw_data = raw_data.map( 354 | map_examples_batched, 355 | 
fn_kwargs=dict(map_one=lambda x: dict(**x, rank=int(x["pass"]))), 356 | batched=True, 357 | num_proc=args.n_cores, 358 | ) 359 | reverse = args.priority == "passed" 360 | raw_data = raw_data.sort(column_names="rank", reverse=reverse) 361 | raw_data = raw_data.remove_columns("rank") 362 | 363 | def mk_key(instruction: str) -> str: 364 | return "".join(instruction.split()) 365 | 366 | seen_ids = set[frozenset[str]]() 367 | seen_keys = set[str]() 368 | if args.exact_match_dedup: 369 | new_data = list[dict]() 370 | 371 | def iterate(dataset: Dataset): 372 | for d in tqdm(dataset): 373 | if args.remove_strange: 374 | # NOTE: newly added 375 | if len(d["instruction"].split()) > 200: 376 | continue 377 | key_i, key_r = mk_key(d["instruction"]), mk_key(d["response"]) 378 | if key_i in seen_keys or key_r in seen_keys: 379 | continue 380 | if args.diversify_func_names: 381 | code_block = find_code_blocks(d["response"])[0] 382 | try: 383 | fn_names, class_names = extract_and_concat_function_names( 384 | code_block 385 | ) 386 | except SyntaxError: 387 | continue 388 | if (len(fn_names) > 0 and fn_names in seen_ids) or ( 389 | len(class_names) > 0 and class_names in seen_ids 390 | ): 391 | continue 392 | seen_ids.add(fn_names) 393 | seen_ids.add(class_names) 394 | new_data.append(d) 395 | seen_keys.add(key_i) 396 | seen_keys.add(key_r) 397 | 398 | iterate(raw_data) 399 | if args.include_left_failed: 400 | iterate(failed_data) 401 | 402 | print("Non exact matches:", len(new_data)) 403 | else: 404 | new_data = raw_data.to_list() 405 | if args.include_left_failed: 406 | new_data.extend(failed_data.to_list()) 407 | new_dataset = Dataset.from_list(new_data) 408 | 409 | if args.get_code_representation: 410 | new_dataset = new_dataset.map( 411 | map_examples_batched, 412 | fn_kwargs=dict(map_one=map_code_representation), 413 | batched=True, 414 | batch_size=1000, 415 | # num_proc=args.n_cores, 416 | ) 417 | new_dataset = new_dataset.filter( 418 | lambda x: x["code_representation"] != "", 419 | num_proc=args.n_cores, 420 | ) 421 | print("Extracted code representation:", len(new_dataset)) 422 | 423 | if args.remove_comments_docstrings: 424 | new_dataset = new_dataset.map( 425 | map_examples_batched, 426 | fn_kwargs=dict(map_one=map_remove_comments), 427 | batched=True, 428 | # num_proc=args.n_cores, 429 | ) 430 | new_dataset = new_dataset.filter( 431 | lambda x: x["response"] != "", 432 | num_proc=args.n_cores, 433 | ) 434 | print("Removed comments/docstrings:", len(new_dataset)) 435 | 436 | if args.data_augmentation: 437 | new_dataset = new_dataset.map( 438 | augment_data, 439 | num_proc=args.n_cores, 440 | with_indices=True, 441 | ) 442 | print("Augmented:", len(new_dataset)) 443 | 444 | write_jsonl(Path(args.output_file), new_dataset) 445 | 446 | 447 | if __name__ == "__main__": 448 | main() 449 | -------------------------------------------------------------------------------- /src/star_align/llm_wrapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from enum import Enum 4 | from typing import Callable, Literal 5 | 6 | import torch 7 | from transformers import AutoModelForCausalLM, AutoModelWithLMHead, AutoTokenizer 8 | from transformers import GenerationConfig as TransformersGenerationConfig 9 | from transformers import ( 10 | PreTrainedModel, 11 | PreTrainedTokenizer, 12 | StoppingCriteria, 13 | StoppingCriteriaList, 14 | ) 15 | 16 | # from peft import PeftModel, PeftConfig 17 | 18 | # Tokenization side 
modeling 19 | PaddingSide = Literal["left", "right"] 20 | # Input: a batch of chat pieces; Output: a batch of instructions and responses 21 | # The instances should encode in a way that the model can predict response from instruction 22 | InputIds = list[int] 23 | 24 | # Adopted from https://github.com/huggingface/transformers/pull/14897 25 | class EndOfFunctionCriteria(StoppingCriteria): 26 | def __init__(self, start_length, eos, tokenizer, *args, **kwargs): 27 | super().__init__(*args, **kwargs) 28 | self.start_length = start_length 29 | self.eos = eos 30 | self.tokenizer = tokenizer 31 | self.end_length = {} 32 | 33 | def __call__(self, input_ids, scores, **kwargs): 34 | """Returns true if all generated sequences contain any of the end-of-function strings.""" 35 | decoded_generations = self.tokenizer.batch_decode( 36 | input_ids[:, self.start_length :] 37 | ) 38 | done = [] 39 | for index, decoded_generation in enumerate(decoded_generations): 40 | finished = any( 41 | [stop_string in decoded_generation for stop_string in self.eos] 42 | ) 43 | if ( 44 | finished and index not in self.end_length 45 | ): # ensures first time we see it 46 | for stop_string in self.eos: 47 | if stop_string in decoded_generation: 48 | self.end_length[index] = len( 49 | input_ids[ 50 | index, # get length of actual generation 51 | self.start_length : -len( 52 | self.tokenizer.encode( 53 | stop_string, 54 | add_special_tokens=False, 55 | return_tensors="pt", 56 | )[0] 57 | ), 58 | ] 59 | ) 60 | done.append(finished) 61 | return all(done) 62 | 63 | 64 | @dataclass(frozen=True) 65 | class DecodingConfig: 66 | skip_special_tokens: bool 67 | 68 | @staticmethod 69 | def default() -> "DecodingConfig": 70 | return DecodingConfig(skip_special_tokens=True) 71 | 72 | 73 | # TransformChatPieceFunc = Callable[[ChatPiece], tuple[str, str]] 74 | 75 | 76 | @dataclass(frozen=True) 77 | class EncodingConfig: 78 | add_bos: bool 79 | add_eos: bool 80 | truncation: int | None = field(default=None) 81 | 82 | @staticmethod 83 | def default() -> "EncodingConfig": 84 | return EncodingConfig(add_bos=False, add_eos=False) 85 | 86 | 87 | @dataclass(frozen=True) 88 | class TokenizationContext: 89 | tokenizer: PreTrainedTokenizer 90 | pad_token_id: int 91 | bos_token: str 92 | eos_token: str 93 | 94 | @property 95 | def eos_token_id(self) -> int: 96 | return self.tokenizer.eos_token_id 97 | 98 | @staticmethod 99 | def from_model_key( 100 | model_key: str, model_name_or_path: str | None = None 101 | ) -> "TokenizationContext": 102 | # use_fast = model_key not in SupportedModelKeys.codellama_models() 103 | use_fast = True 104 | # if model_name_or_path is None: 105 | # model_name_or_path = model_key 106 | # TODO: check if tokenizers cannot be loaded with path 107 | model_name_or_path = model_key 108 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=use_fast) 109 | tokenization_context = TokenizationContext.from_tokenizer(tokenizer) 110 | return tokenization_context 111 | 112 | @staticmethod 113 | def from_tokenizer(tokenizer: PreTrainedTokenizer) -> "TokenizationContext": 114 | if (pad_token_id := tokenizer.pad_token_id) is None: 115 | pad_token_id = tokenizer.eos_token_id 116 | assert pad_token_id is not None 117 | bos_token = tokenizer.bos_token 118 | eos_token = tokenizer.eos_token 119 | return TokenizationContext( 120 | tokenizer=tokenizer, 121 | pad_token_id=pad_token_id, 122 | bos_token=bos_token, 123 | eos_token=eos_token, 124 | ) 125 | 126 | def encode(self, config: EncodingConfig, text_list: list[str]) -> 
list[list[int]]: 127 | # eos_token = self.eos_token if config.add_eos else "" 128 | # bos_token = self.bos_token if config.add_bos else "" 129 | # if eos_token != "" or bos_token != "": 130 | # text_list = [f"{bos_token}{text}{eos_token}" for text in text_list] 131 | # The string concatenation above may not always work for all tokenizers (strange). 132 | # e.g., when codellama's tokenizer is used with "[INST]". 133 | if config.truncation is not None: 134 | extra_args = dict(truncation=True, max_length=config.truncation) 135 | else: 136 | extra_args = {} 137 | input_ids = self.tokenizer( 138 | text_list, 139 | add_special_tokens=False, 140 | **extra_args, 141 | )["input_ids"] 142 | bos_token_id = self.tokenizer.bos_token_id 143 | eos_token_id = self.tokenizer.eos_token_id 144 | bos_token_ids = ( 145 | [bos_token_id] if config.add_bos and bos_token_id is not None else [] 146 | ) 147 | eos_token_ids = ( 148 | [eos_token_id] if config.add_eos and eos_token_id is not None else [] 149 | ) 150 | if len(bos_token_ids) > 0 or len(eos_token_ids) > 0: 151 | input_ids = [ 152 | bos_token_ids + input_id + eos_token_ids for input_id in input_ids 153 | ] 154 | return input_ids 155 | 156 | def decode( 157 | self, config: DecodingConfig, input_ids: list[InputIds] | torch.Tensor 158 | ) -> list[str]: 159 | return self.tokenizer.batch_decode( 160 | input_ids, skip_special_tokens=config.skip_special_tokens 161 | ) 162 | 163 | def encode_with_padding( 164 | self, padding_side: PaddingSide, config: EncodingConfig, text_list: list[str] 165 | ) -> torch.Tensor: 166 | input_ids_unpadded = self.encode(config, text_list) 167 | return pad_sequences( 168 | sequences=input_ids_unpadded, 169 | pad_value=self.pad_token_id, 170 | padding_side=padding_side, 171 | ) 172 | 173 | 174 | def pad_sequences( 175 | sequences: list[list[int]], 176 | pad_value: int, 177 | padding_side: Literal["left", "right"], 178 | dtype: torch.dtype = torch.long, 179 | padding_length: int | None = None, 180 | ) -> torch.Tensor: 181 | tensors = [torch.tensor(sequence, dtype=dtype) for sequence in sequences] 182 | max_len = max(len(sequence) for sequence in sequences) 183 | if padding_length is not None: 184 | assert padding_length >= max_len, "padding_length must be >= max_len" 185 | max_len = padding_length 186 | if padding_side == "right": 187 | result = torch.nn.utils.rnn.pad_sequence( 188 | tensors, batch_first=True, padding_value=pad_value 189 | ) 190 | remaining_length = max_len - result.shape[-1] 191 | # padding matrix of (batch_size * remaining_length) 192 | shape = result.shape[:-1] + (remaining_length,) 193 | padding_matrix = torch.full(shape, pad_value, dtype=dtype) 194 | result = torch.cat([result, padding_matrix], dim=-1) 195 | else: 196 | padded_tensors: list[torch.Tensor] = [] 197 | for tensor in tensors: 198 | n_pad_values = max_len - len(tensor) 199 | padded_values = torch.full((n_pad_values,), pad_value, dtype=dtype) 200 | padded_tensor = torch.cat([padded_values, tensor], dim=0) 201 | assert len(padded_tensor) == max_len 202 | padded_tensors.append(padded_tensor) 203 | result = torch.stack(padded_tensors, dim=0) 204 | assert result.shape == torch.Size([len(sequences), max_len]) 205 | return result 206 | 207 | 208 | # Inference side modeling 209 | @dataclass(frozen=True) 210 | class GenerationConfig: 211 | max_new_tokens: int 212 | top_p: float 213 | temperature: float 214 | max_length: int = field( 215 | default=99999999999999999, 216 | metadata={ 217 | "help": "The max length of the sequence to generate, including inputs." 
218 | "Will be considered in tandem with max_new_tokens. Whichever is more restrictive will be used." 219 | }, 220 | ) 221 | 222 | def to_transformers_generation_config( 223 | self, eos_token_id: int, pad_token_id: int 224 | ) -> TransformersGenerationConfig: 225 | do_sample = self.temperature != 0.0 226 | kwargs = dict( 227 | max_new_tokens=self.max_new_tokens, 228 | top_p=self.top_p, 229 | eos_token_id=eos_token_id, 230 | pad_token_id=pad_token_id, 231 | do_sample=do_sample, 232 | ) 233 | if do_sample: 234 | kwargs["temperature"] = self.temperature 235 | return TransformersGenerationConfig(**kwargs) 236 | 237 | def with_max_new_tokens_being(self, max_new_tokens: int) -> "GenerationConfig": 238 | return GenerationConfig(max_new_tokens, self.top_p, self.temperature) 239 | 240 | @staticmethod 241 | def default() -> "GenerationConfig": 242 | return GenerationConfig(200, 1.0, 1.0) 243 | 244 | 245 | @dataclass(frozen=True) 246 | class Response: 247 | raw_inputs: torch.Tensor 248 | raw_outputs: torch.Tensor 249 | decoded_outputs: list[str] 250 | 251 | 252 | @dataclass 253 | class ModelContext: 254 | tokenization_context: TokenizationContext 255 | model: PreTrainedModel 256 | max_context_size: int 257 | 258 | def generate( 259 | self, 260 | config: GenerationConfig, 261 | input_ids: torch.Tensor, 262 | stop_tokens: list[str] | None = None, 263 | ) -> torch.Tensor: 264 | """Raise ValueError when input_ids exceeds the context.""" 265 | # NOTE: this implementation is only for decoder-only models 266 | # Recalculate the max number of tokens to avoid overflowing the context window 267 | input_len = input_ids.shape[1] 268 | if input_len >= self.max_context_size: 269 | raise ValueError( 270 | f"Input length {input_len} >= Context size {self.max_context_size}" 271 | ) 272 | if input_len >= config.max_length: 273 | raise ValueError( 274 | f"Input length {input_len} >= Max length {config.max_length}" 275 | ) 276 | assert input_len < self.max_context_size 277 | assert input_len < config.max_length 278 | 279 | max_new_tokens = min( 280 | self.max_context_size - input_len, 281 | config.max_new_tokens, 282 | config.max_length - input_len, 283 | ) 284 | config = config.with_max_new_tokens_being(max_new_tokens) 285 | 286 | tf_config = config.to_transformers_generation_config( 287 | eos_token_id=self.tokenization_context.eos_token_id, 288 | pad_token_id=self.tokenization_context.pad_token_id, 289 | ) 290 | attention_mask = input_ids.ne(self.tokenization_context.pad_token_id) 291 | # breakpoint() 292 | extra_kwargs: dict = {} 293 | if stop_tokens is not None: 294 | stopping_criteria = StoppingCriteriaList( 295 | [ 296 | EndOfFunctionCriteria( 297 | start_length=len(input_ids[0]), 298 | eos=stop_tokens, 299 | tokenizer=self.tokenization_context.tokenizer, 300 | ) 301 | ] 302 | ) 303 | extra_kwargs["stopping_criteria"] = stopping_criteria 304 | outputs = self.model.generate( 305 | input_ids=input_ids, 306 | attention_mask=attention_mask, 307 | generation_config=tf_config, 308 | **extra_kwargs, 309 | ) 310 | # input_len = input_ids.shape[1] 311 | return outputs[:, input_len:] 312 | 313 | def complete( 314 | self, 315 | config: GenerationConfig, 316 | prompts: list[str], 317 | stop_tokens: list[str] | None = None, 318 | ) -> Response: 319 | encoding_config = EncodingConfig(add_bos=True, add_eos=False) 320 | input_ids = self.tokenization_context.encode_with_padding( 321 | "left", encoding_config, prompts 322 | ) 323 | input_ids = input_ids.to(self.model.device) 324 | output_ids = self.generate(config, input_ids, 
stop_tokens) 325 | decoding_config = DecodingConfig(skip_special_tokens=True) 326 | output_strings = self.tokenization_context.decode(decoding_config, output_ids) 327 | return Response( 328 | raw_inputs=input_ids, 329 | raw_outputs=output_ids, 330 | decoded_outputs=output_strings, 331 | ) 332 | 333 | class SupportedModelKeys(Enum): 334 | # StarCoder-based models 335 | STARCODER_15B = "bigcode/starcoder" 336 | WIZARDCODER_STARCODER_15B = "WizardLM/WizardCoder-15B-V1.0" 337 | 338 | # CodeLlama-based models 339 | WIZARDCODER_CODELLAMA_PYTHON_7B = "WizardLM/WizardCoder-Python-7B-V1.0" 340 | WIZARDCODER_CODELLAMA_PYTHON_13B = "WizardLM/WizardCoder-Python-13B-V1.0" 341 | WIZARDCODER_CODELLAMA_PYTHON_34B = "WizardLM/WizardCoder-Python-34B-V1.0" 342 | CODELLAMA_PYTHON_7B = "codellama/CodeLlama-7b-Python-hf" 343 | CODELLAMA_PYTHON_13B = "codellama/CodeLlama-13b-Python-hf" 344 | CODELLAMA_PYTHON_34B = "codellama/CodeLlama-34b-Python-hf" 345 | 346 | # DeepSeek-Coder-based models 347 | DEEPSEEK_CODER_1_3B = "deepseek-ai/deepseek-coder-1.3b-base" 348 | DEEPSEEK_CODER_6_7B = "deepseek-ai/deepseek-coder-6.7b-base" 349 | DEEPSEEK_CODER_33B = "deepseek-ai/deepseek-coder-33b-base" 350 | 351 | @staticmethod 352 | def all() -> list[str]: 353 | return [member.value for member in SupportedModelKeys] 354 | 355 | @staticmethod 356 | def codellama_models() -> list[str]: 357 | return [ 358 | SupportedModelKeys.CODELLAMA_PYTHON_7B.value, 359 | SupportedModelKeys.CODELLAMA_PYTHON_13B.value, 360 | SupportedModelKeys.CODELLAMA_PYTHON_34B.value, 361 | # SupportedModelKeys.WIZARDCODER_CODELLAMA_PYTHON_7B.value, 362 | # SupportedModelKeys.WIZARDCODER_CODELLAMA_PYTHON_13B.value, 363 | # SupportedModelKeys.WIZARDCODER_CODELLAMA_PYTHON_34B.value, 364 | ] 365 | 366 | @staticmethod 367 | def codellama_based_models() -> list[str]: 368 | return SupportedModelKeys.codellama_models() + [ 369 | SupportedModelKeys.WIZARDCODER_CODELLAMA_PYTHON_7B.value, 370 | SupportedModelKeys.WIZARDCODER_CODELLAMA_PYTHON_13B.value, 371 | SupportedModelKeys.WIZARDCODER_CODELLAMA_PYTHON_34B.value, 372 | ] 373 | 374 | @staticmethod 375 | def starcoder_based_models() -> list[str]: 376 | return [ 377 | SupportedModelKeys.STARCODER_15B.value, 378 | SupportedModelKeys.WIZARDCODER_STARCODER_15B.value, 379 | ] 380 | 381 | @staticmethod 382 | def deepseekcoder_based_models() -> list[str]: 383 | return [ 384 | SupportedModelKeys.DEEPSEEK_CODER_1_3B.value, 385 | SupportedModelKeys.DEEPSEEK_CODER_6_7B.value, 386 | SupportedModelKeys.DEEPSEEK_CODER_33B.value, 387 | ] 388 | 389 | 390 | def get_model_context( 391 | model_key: str, 392 | model_name_or_path: str | None = None, 393 | tokenization_context: TokenizationContext | None = None, 394 | inference_mode: bool = True, 395 | use_flash_attention: bool = False, 396 | attention_dropout: float | None = None, 397 | residual_dropout: float | None = None, 398 | embedding_dropout: float | None = None, 399 | ) -> ModelContext: 400 | # `model_key` defines the model and the tokenizer to use, while `model_name_or_path` 401 | # defines where to load the weights. It can be from a local directory. 402 | # assert model_key in SupportedModelKeys.all(), model_key 403 | if model_key not in SupportedModelKeys.all(): 404 | import warnings 405 | 406 | warnings.warn( 407 | f"{model_key} not explicitly supported. This may or may not lead to unexpected behaviors." 
408 | ) 409 | if model_name_or_path is None: 410 | model_name_or_path = model_key 411 | if model_key in SupportedModelKeys.codellama_based_models(): 412 | max_context_size = 16384 413 | elif model_key in SupportedModelKeys.starcoder_based_models(): 414 | max_context_size = 8192 415 | elif model_key in SupportedModelKeys.deepseekcoder_based_models(): 416 | max_context_size = 16384 417 | else: 418 | import warnings 419 | 420 | warnings.warn( 421 | f"{model_key} does not have a specified max context, using default 4096" 422 | ) 423 | max_context_size = 4096 424 | if tokenization_context is None: 425 | tokenization_context = TokenizationContext.from_model_key(model_key) 426 | # TODO: check if all these models use bfloat16 427 | dtype = torch.bfloat16 428 | other_kwargs: dict = {} 429 | if inference_mode: 430 | other_kwargs["device_map"] = "auto" 431 | if use_flash_attention: 432 | # if "starcoder2" in model_key: 433 | # other_kwargs["attn_implementation"] = "flash_attention_2" 434 | # else: 435 | import transformers 436 | 437 | if transformers.__version__ <= "4.35.0": 438 | other_kwargs["use_flash_attention_2"] = True 439 | else: 440 | other_kwargs["attn_implementation"] = "flash_attention_2" 441 | # other_kwargs["use_flash_attention_2"] = True 442 | # cls = AutoModelWithLMHead if "starcoder2-3b" in model_key else AutoModelForCausalLM 443 | 444 | if "starcoder" in model_key.lower(): 445 | print("Hack for starcoder") 446 | attention_dropout = attention_dropout or 0.0 447 | residual_dropout = residual_dropout or 0.0 448 | embedding_dropout = embedding_dropout or 0.0 449 | 450 | if attention_dropout is not None: 451 | other_kwargs["attention_dropout"] = attention_dropout 452 | if residual_dropout is not None: 453 | other_kwargs["residual_dropout"] = residual_dropout 454 | if embedding_dropout is not None: 455 | other_kwargs["embedding_dropout"] = embedding_dropout 456 | # if (dropout := os.getenv("ATTENTION_DROPOUT")) is not None: 457 | # other_kwargs["attention_dropout"] = float(dropout) 458 | # print(f"Using attention dropout: {dropout}") 459 | model = AutoModelForCausalLM.from_pretrained( 460 | model_name_or_path, 461 | torch_dtype=dtype, 462 | # hack 463 | # revision=os.getenv("REVISION"), 464 | **other_kwargs, 465 | ) 466 | print("Successfully loaded model.") 467 | print(model.config) 468 | return ModelContext(tokenization_context, model, max_context_size) 469 | 470 | 471 | def form_starcoder_infill(prefix: str, suffix: str) -> str: 472 | FIM_PREFIX = "" 473 | FIM_MIDDLE = "" 474 | FIM_SUFFIX = "" 475 | prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}" 476 | return prompt 477 | 478 | 479 | def form_codellama_infill(prefix: str, suffix: str) -> str: 480 | # NOTE: not using because it's treated as a special token 481 | # but we pass `add_special_tokens=False` to the tokenizer 482 | return f"▁
<PRE>{prefix}▁<SUF>{suffix}▁<MID>"
483 | 
484 | 
485 | def form_deepseekcoder_infill(
486 |     tokenizer: PreTrainedTokenizer, prefix: str, suffix: str
487 | ) -> str:
488 |     def get_token(idx: int) -> str:
489 |         return tokenizer.convert_ids_to_tokens([idx])[0]
490 | 
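    # The ids below are hard-coded vocabulary indices resolved through the
    # tokenizer; the assert that follows sanity-checks that they decode to
    # DeepSeek-Coder's begin/hole/end FIM sentinel tokens.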
491 |     FIM_PREFIX = get_token(32016)
492 |     FIM_MIDDLE = get_token(32015)
493 |     FIM_SUFFIX = get_token(32017)
494 |     assert "begin" in FIM_PREFIX and "hole" in FIM_MIDDLE and "end" in FIM_SUFFIX
495 |     prompt = f"{FIM_PREFIX}{prefix}{FIM_MIDDLE}{suffix}{FIM_SUFFIX}"
496 |     return prompt
497 | 
498 | 
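# Dispatcher that builds a fill-in-the-middle prompt for the supported model
# families; model keys outside these families fall through to the assert at
# the bottom of the function.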
499 | def create_infilling_prompt(
500 |     model_key: str,
501 |     prefix: str,
502 |     suffix: str,
503 |     tokenizer: PreTrainedTokenizer | None = None,
504 | ) -> str:
505 |     if model_key in SupportedModelKeys.starcoder_based_models():
506 |         return form_starcoder_infill(prefix, suffix)
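    # The CodeLlama *Python* variants were reportedly not trained with
    # infilling, hence the extra "python" check in the branch below.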
507 |     elif (
508 |         model_key in SupportedModelKeys.codellama_based_models()
509 |         and "python" not in model_key.lower()
510 |     ):
511 |         return form_codellama_infill(prefix, suffix)
512 |     elif model_key in SupportedModelKeys.deepseekcoder_based_models():
513 |         assert tokenizer is not None
514 |         return form_deepseekcoder_infill(tokenizer, prefix, suffix)
515 | 
516 |     # TODO: other models
517 |     assert False, f"Unsupported model key: {model_key}"
518 | 
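# Illustrative usage sketch (kept commented out); the model key and the
# prefix/suffix fragments are placeholder values:
#
#   prompt = create_infilling_prompt(
#       model_key="bigcode/starcoder",
#       prefix="def add(a, b):\n    return ",
#       suffix="\n",
#   )
#   # A StarCoder-family key delegates to form_starcoder_infill above;
#   # DeepSeek-Coder keys additionally require passing the model's tokenizer.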


--------------------------------------------------------------------------------