├── generative_cert ├── __init__.py ├── utils │ ├── __init__.py │ ├── data_analytic.py │ └── utils.py ├── scripts │ ├── fitcluster │ │ ├── script.sh │ │ ├── README.md │ │ ├── sampling-sensitive │ │ │ ├── run_generate_temp_sampling.sh │ │ │ ├── run_generate_nucleus.sh │ │ │ └── run_generate_top_k.sh │ │ ├── run-fewshot-cot-with-hint.sh │ │ └── run-fewshot-cot.sh │ ├── eval_llm.sh │ └── generation │ │ ├── chatgpt_generation.sh │ │ └── template.sh ├── README.md └── prompts.py ├── .env.example ├── figures ├── framework.png ├── generative_result.png └── discriminative_result.png ├── .gitignore ├── requirements.txt ├── llms ├── __init__.py ├── base_language_model.py ├── model_adapter.py ├── base_hf_causal_model.py └── chatgpt.py ├── scripts ├── gen-cert │ ├── llm_prompting.sh │ ├── run.sh │ ├── subgraph.sh │ ├── extract_triplet.sh │ ├── job_eval_llm_finegrained.sh │ └── job_eval_llm.sh └── disc-cert │ ├── summary_results.py │ └── submit_discriminative_cert.sh ├── preprocess_data ├── readme.md ├── cwq_graph.py ├── splitted_ground_truth_paths.py ├── subgraph_discriminative_cert.py ├── ground_truth_paths.py ├── grailqa_graph.py └── sparql.py ├── discriminative_cert ├── scripts │ ├── summary_results.py │ └── submit_discriminative_cert.sh ├── utils │ └── utils.py └── discriminative_prompts.py ├── llm_generation.py ├── extract_subgraph ├── graph_loader.py ├── deal_cvt.py └── preprocess_data.py ├── finegrained_gen_cert.py ├── virtuoso_db └── virtuoso.py ├── evaluate_results.py ├── finegrained_analysis.py ├── README.md ├── generative-cert.py ├── discriminative-cert.py └── certify_fact.py /generative_cert/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /generative_cert/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.env.example: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY = sk-xx 2 | HF_TOKEN = hf_xxx 3 | OPENAI_ORG = '' 4 | -------------------------------------------------------------------------------- /figures/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinhVuong2000/LLMReasonCert/HEAD/figures/framework.png -------------------------------------------------------------------------------- /figures/generative_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinhVuong2000/LLMReasonCert/HEAD/figures/generative_result.png -------------------------------------------------------------------------------- /figures/discriminative_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MinhVuong2000/LLMReasonCert/HEAD/figures/discriminative_result.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | logs/ 3 | results/ 4 | tmp/ 5 | .vscode 6 | *.pyc 7 | slurm-* 8 | smux-* 9 | virtuoso_db/virtuoso-opensource/* 10 | virtuoso_db/virtuoso_db/* 11 | dis_results/* 12 | data/* 13 | test/* 14 | .DS_Store 15 | .ruff_cache 16 | wandb 17 | new_dis_results/* 18 | new_results* 19 | 
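The `.env.example` file above lists the credentials the code reads at runtime (`OPENAI_API_KEY`, `HF_TOKEN`, `OPENAI_ORG`); the filled-in `.env` itself is git-ignored. A minimal sketch of how these values are picked up, mirroring the `dotenv.load_dotenv()` pattern used in `llms/base_hf_causal_model.py`:

```python
import os
import dotenv

# Copy .env.example to .env and fill in real values first;
# load_dotenv() reads the file into the process environment.
dotenv.load_dotenv()

HF_TOKEN = os.getenv("HF_TOKEN")              # Hugging Face access token
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")  # OpenAI key for the ChatGPT models
```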
-------------------------------------------------------------------------------- /generative_cert/scripts/fitcluster/script.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_NAME=llama 4 | HF_MODEL_NAME=meta-llama/Llama-2-7b-chat-hf 5 | MODEL_PREFIX=llama2-7b 6 | for DATASET in cwq webqsp grail_qa; do 7 | sbatch --job-name=$MODEL_PREFIX-$DATASET run-fewshot-cot.sh $MODEL_NAME $HF_MODEL_NAME $DATASET 8 | done -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch --extra-index-url https://download.pytorch.org/whl/cu117 2 | datasets 3 | networkx 4 | openai==1.3.8 5 | faiss-cpu 6 | wandb 7 | ruff 8 | sentence_transformers 9 | numpy 10 | pandas 11 | fuzzywuzzy 12 | transformers>=4.35.2 13 | transformers_stream_generator 14 | einops 15 | python-dotenv 16 | tiktoken 17 | sentencepiece 18 | accelerate 19 | flash-attn 20 | protobuf 21 | minineedle 22 | -------------------------------------------------------------------------------- /llms/__init__.py: -------------------------------------------------------------------------------- 1 | from .chatgpt import ChatGPT 2 | from .base_language_model import BaseLanguageModel 3 | from .model_adapter import * 4 | 5 | registed_language_models = { 6 | "gpt-4": ChatGPT, 7 | "gpt-3.5-turbo": ChatGPT, 8 | "llama": Llama, 9 | "qwen": Qwen, 10 | "mistral": Mistral, 11 | "vicuna": Vicuna, 12 | } 13 | 14 | 15 | def get_registed_model(model_name) -> BaseLanguageModel: 16 | for key, value in registed_language_models.items(): 17 | if key in model_name.lower(): 18 | return value 19 | raise ValueError(f"No registered model found for name {model_name}") 20 | -------------------------------------------------------------------------------- /scripts/gen-cert/llm_prompting.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL_NAME=gpt-3.5-turbo 3 | N_PROCESS=8 4 | 5 | ## consistency 6 | TOPP=0.9 7 | TEMP=0.7 8 | MODE=fewshot-cot-only 9 | N_SEQ=20 10 | for DATASET in cwq grail_qa; do 11 | sbatch generative_cert/scripts/generation/chatgpt_generation.sh $MODEL_NAME $MODE $DATASET $N_PROCESS $TEMP $TOPP $N_SEQ 12 | done 13 | 14 | ## temprerature topp 15 | TOPP=0.9 16 | TEMP=0.7 17 | N_SEQ=1 18 | for DATASET in cwq grail_qa; do 19 | for MODE in "fewshot-cot-only" "fewshot-cot-hint" "fewshot-cot-hint-ground"; do 20 | sbatch generative_cert/scripts/generation/chatgpt_generation.sh $MODEL_NAME $MODE $DATASET $N_PROCESS $TEMP $TOPP $N_SEQ 21 | done 22 | done -------------------------------------------------------------------------------- /scripts/gen-cert/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=certify_fact 3 | # SBATCH --account=da34 4 | #SBATCH --time=01:00:00 5 | #SBATCH --nodes=1 # node count 6 | # SBATCH --ntasks=1 7 | # SBATCH --ntasks-per-node=1 # total number of tasks per node 8 | # SBATCH --gres=1 # number of gpus 9 | #SBATCH --mem-per-cpu=50000 10 | #SBATCH --cpus-per-task=1 11 | #SBATCH --partition=gpu 12 | # SBATCH --qos=normal 13 | #SBATCH --mail-type=BEGIN,END,FAIL 14 | #SBATCH --mail-user=minhvuong160620@gmail.com 15 | 16 | module load python/3.8.5 17 | 18 | ROOT_DIR='LLMReasoningCert/' 19 | HUGGINGFACE_HUB_CACHE=$ROOT_DIR/envs/huggingface 20 | HF_HOME=$ROOT_DIR/envs/huggingface 21 | 22 | source $ROOT_DIR/envs/vuongntm/bin/activate 
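# The next step runs fact certification: --prob_thres sets the probability
# threshold passed to certify_fact.py (70 here), and the commented-out
# --dataset flag can be re-enabled to restrict the run to a single dataset
# such as cwq instead of all of them.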
23 | cd $ROOT_DIR/LLMReasonCert 24 | python certify_fact.py --prob_thres 70 #--dataset cwq 25 | -------------------------------------------------------------------------------- /scripts/gen-cert/subgraph.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # SBATCH --job-name=get_subgraph 3 | # SBATCH --account=da34 4 | # SBATCH --time=04:00:00 5 | # SBATCH --nodes=1 # node count 6 | # SBATCH --ntasks=1 7 | # SBATCH --ntasks-per-node=1 # total number of tasks per node 8 | # SBATCH --cpus-per-task=16 9 | # SBATCH --gres=None # number of gpus 10 | # SBATCH --mem-per-cpu=40960 11 | # SBATCH --partition=fit 12 | # SBATCH --qos=fit 13 | # SBATCH --mail-type=BEGIN,END,FAIL 14 | # SBATCH --mail-user=minhvuong160620@gmail.com 15 | # SBATCH --output=%x-%j.out 16 | # SBATCH --error=%x-%j.err 17 | 18 | module load python/3.8.5 19 | ROOT_DIR='LLMReasoningCert' 20 | source $ROOT_DIR/envs/vuongntm/bin/activate 21 | 22 | export HUGGINGFACE_HUB_CACHE='envs/huggingface/' 23 | export HF_HOME='envs/huggingface/' 24 | 25 | cd $ROOT_DIR/LLMReasonCert 26 | srun --jobid $SLURM_JOBID bash -c 'python preprocess_data.py' 27 | -------------------------------------------------------------------------------- /generative_cert/README.md: -------------------------------------------------------------------------------- 1 | A Guide for LLM Generation 2 | --- 3 | 4 | ## Run step by step: 5 | 1. Move to the main working directory 6 | ```bash 7 | cd LLMReasoningCert/LLMReasonCert 8 | ``` 9 | 2. Modify bash file at `./scripts/generation/template.sh`. \ 10 | e.g. `sbatch generative_cert/scripts/generation/template.sh mistralai/Mistral-7B-Instruct-v0.1 fewshot-cot-only cwq 50 1 4` 11 | arguments can be found in `generative_cert/llm_generation.py` and `llms/base_hf_causal_model.py` 12 | 13 | 3. Run a few samples to make sure output is OK \ 14 | 3.1. 15 | `--run_sample` \ 16 | Default is DATASET can be `all`, mean ['cwq','FreebaseQA','FreebaseQA]. You can only generate 1 out of 3 datasets by change row DATASET 17 | 3.2. Check Output results to make sure output is OK (3 files)\ 18 | a. `../data/cwq/{model_name}/test/llm_prompt_response.jsonl`\ 19 | b. `../data/FreebaseQA/{model_name}/test/llm_prompt_response.jsonl`\ 20 | b. `../data/grail_qa/{model_name}/test/llm_prompt_response.jsonl`\ 21 | 22 | 4. Run all\ 23 | a. comment the final row `--run_sample` \ 24 | b. run sbatch 25 | 26 | -------------------------------------------------------------------------------- /llms/base_language_model.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | 4 | class BaseLanguageModel(object): 5 | """ 6 | Base lanuage model. 
Define how to generate sentence by using a LM 7 | Args: 8 | args: arguments for LM configuration 9 | """ 10 | 11 | @staticmethod 12 | def add_args(parser): 13 | return 14 | 15 | def __init__(self, args): 16 | self.args = args 17 | 18 | def load_model(self, **kwargs): 19 | raise NotImplementedError 20 | 21 | def prepare_for_inference(self, **model_kwargs): 22 | raise NotImplementedError 23 | 24 | def prepare_model_prompt(self, query): 25 | """ 26 | Add model-specific prompt to the input 27 | 28 | Args: 29 | instruction (str) 30 | input (str): str 31 | """ 32 | raise NotImplementedError 33 | 34 | def generate_sentence(self, llm_input): 35 | """ 36 | Generate sentence by using a LM 37 | 38 | Args: 39 | lm_input (LMInput): input for LM 40 | """ 41 | raise NotImplementedError 42 | -------------------------------------------------------------------------------- /scripts/gen-cert/extract_triplet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=extract_triplet-skip_unknown_ent-only_test 3 | # SBATCH --account=da34 4 | #SBATCH --time=2-00:00:00 5 | #SBATCH --nodes=1 # node count 6 | # SBATCH --ntasks=1 7 | # SBATCH --ntasks-per-node=1 # total number of tasks per node 8 | # SBATCH --gres=0 # number of gpus 9 | #SBATCH --mem-per-cpu=50000 10 | #SBATCH --cpus-per-task=2 11 | # SBATCH --partition=gpu 12 | # SBATCH --qos=normal 13 | #SBATCH --mail-type=BEGIN,END,FAIL 14 | #SBATCH --mail-user=minhvuong160620@gmail.com 15 | #SBATCH --output=LLMReasoningCert/slum/%x-%j.out 16 | #SBATCH --error=LLMReasoningCert/slum/%x-%j.err 17 | 18 | module load python/3.8.5 19 | 20 | ROOT_DIR='LLMReasoningCert/' 21 | export HUGGINGFACE_HUB_CACHE=$ROOT_DIR/huggingface 22 | export HF_HOME=$ROOT_DIR/huggingface 23 | 24 | export PYTHONPATH="${PYTHONPATH}:${ROOT_DIR}/LLMReasonCert" 25 | 26 | source $ROOT_DIR/envs/vuongntm/bin/activate 27 | cd $ROOT_DIR/LLMReasonCert 28 | DATASET=$1 29 | python extract_triplet/extract_triplet.py --dataset $DATASET --create_db ###--query 'Francis Avent Gumm was born in Tennessee' 30 | -------------------------------------------------------------------------------- /preprocess_data/readme.md: -------------------------------------------------------------------------------- 1 | How to create Subgraphs from SPARQL 2 | --- 3 | 4 | ### A. Firstly, you need to construct Freebase Database before querying via SparQL 5 | 1. Download [virtuoso-opensource](https://github.com/openlink/virtuoso-opensource/releases/tag/v7.2.11) 6 | -> Extract and put the folder in the `./virtuoso_db` 7 | 2. Download and [virtuoso_db file](virtuoso_db) 8 | -> unzip and put the folder in the `./virtuoso_db/virtuoso_db` 9 | 10 | Read more in [this link](https://juejin.cn/post/7283690681175113740) 11 | 12 | ### B. After finishing the preparation 13 | It requires a terminal multiplexer, might use `vim` or `tmux`/`smux` (preferred). 14 | ##### Window1. 15 | 1. Change the working dir: `cd ./virtuoso_db` 16 | 2. `python3 virtuoso.py start 3001 -d virtuoso_db`\ 17 | To stop: `python3 virtuoso.py stop 3001` 18 | 19 | ##### Window2. 
20 | Obtain a **raw-subgraph** via `SPARQL` and `CONSTRUCT` query and \ 21 | **subgraph** from the **raw-subgraph** by skipping unnamed entities:`h,r1,*`,`*,r2,t` ->`h,r1/r2,t` 22 | - Handle CWQ: `python ./preprocess_data/cwq_graph.py` 23 | - Handle GrailQA: `python ./preprocess_data/grailqa_graph.py` 24 | -------------------------------------------------------------------------------- /generative_cert/scripts/eval_llm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llm_generate 3 | #SBATCH --time=1-00:00:00 4 | #SBATCH --nodes=1 # node count 5 | #SBATCH --ntasks=1 6 | #SBATCH --ntasks-per-node=1 # total number of tasks per node 7 | #SBATCH --mem-per-cpu=50000 8 | #SBATCH --cpus-per-task=1 9 | #SBATCH --mail-type=BEGIN,END,FAIL 10 | #SBATCH --mail-user=minhvuong160620@gmail.com 11 | #SBATCH --output=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/slum/llm_cert/%x-%j.out 12 | #SBATCH --error=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/slum/llm_cert/%x-%j.err 13 | 14 | module load python/3.8.5 15 | 16 | ROOT_DIR='/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/' 17 | export HUGGINGFACE_HUB_CACHE=$ROOT_DIR/huggingface 18 | export HF_HOME=$ROOT_DIR/huggingface 19 | 20 | source $ROOT_DIR/envs/vuongntm/bin/activate 21 | cd $ROOT_DIR/LLMReasonCert 22 | 23 | DATASET=$1 24 | MODE=$2 25 | DATA_PATH=$3 26 | PROB_THRES=$4 27 | ENT_THRES=$5 28 | IS_SC=$6 29 | 30 | echo "=============================================================" 31 | echo "Running certification of LLM" 32 | echo " * RAW_DATA_PATH = "$DATA_PATH 33 | echo " * MODE = "$MODE 34 | echo " * DATASET = "$DATASET 35 | echo " * PROB_THRES = "$PROB_THRES 36 | echo " * ENT_THRES = "$ENT_THRES 37 | echo " * SELF_CONSISTENCY = "$IS_SC 38 | echo "=============================================================" 39 | 40 | 41 | python generative-cert.py \ 42 | --dataset $DATASET \ 43 | --mode $MODE \ 44 | --raw_dat_path $DATA_PATH \ 45 | --prob_thres $PROB_THRES \ 46 | --ent_thres $ENT_THRES \ 47 | --is_sc $IS_SC 48 | -------------------------------------------------------------------------------- /scripts/gen-cert/job_eval_llm_finegrained.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATASET_LIST="cwq grail_qa" 3 | 4 | # gpt-3.5-turbo 5 | for DATASET in $DATASET_LIST; do 6 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-hint-$DATASET-temp-0.7-p-0.9-consistency-1-is_sc_1 7 | DATA_PATH=LLMReasoningCert/LLMReasonCert/results/$DATASET/gpt-3.5-turbo/$JOB_NAME/full.jsonl 8 | python finegrained_gen_cert.py --dat_path $DATA_PATH 9 | 10 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-only-$DATASET-temp-0.7-p-0.9-consistency-1-is_sc_1 11 | DATA_PATH=LLMReasoningCert/LLMReasonCert/results/$DATASET/gpt-3.5-turbo/$JOB_NAME/full.jsonl 12 | python finegrained_gen_cert.py --dat_path $DATA_PATH 13 | 14 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-only-$DATASET-temp-0.7-p-0.9-consistency-20-is_sc_4 15 | DATA_PATH=LLMReasoningCert/LLMReasonCert/results/$DATASET/gpt-3.5-turbo/$JOB_NAME/full.jsonl 16 | python finegrained_gen_cert.py --dat_path $DATA_PATH 17 | done 18 | 19 | 20 | # others 21 | MODEL_LIST="Llama-2-70b-chat-hf Qwen-7B-Chat Qwen-14B-Chat Mistral-7B-Instruct-v0.1 vicuna-33b-v1.3" 22 | for DATASET in $DATASET_LIST; do 23 | for MODEL_NAME in $MODEL_LIST; do 24 | DATA_PATH=LLMReasoningCert/LLMReasonCert/results/$DATASET/$MODEL_NAME/cot-hint-temp-0.7-p-0.9-is_sc_1/full.jsonl 25 | python finegrained_gen_cert.py --dat_path $DATA_PATH 26 | 27 | 
DATA_PATH=LLMReasoningCert/LLMReasonCert/results/$DATASET/$MODEL_NAME/cot-temp-0.7-p-0.9-is_sc_1/full.jsonl 28 | python finegrained_gen_cert.py --dat_path $DATA_PATH 29 | 30 | DATA_PATH=LLMReasoningCert/LLMReasonCert/results/$DATASET/$MODEL_NAME/cot-temp-0.7-p-0.9-is_sc_4/full.jsonl 31 | python finegrained_gen_cert.py --dat_path $DATA_PATH 32 | done 33 | done 34 | -------------------------------------------------------------------------------- /generative_cert/scripts/generation/chatgpt_generation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llm_generate 3 | #SBATCH --time=5-00:00:00 4 | #SBATCH --nodes=1 # node count 5 | #SBATCH --ntasks=1 6 | #SBATCH --ntasks-per-node=1 # total number of tasks per node 7 | #SBATCH --mem-per-cpu=50000 8 | #SBATCH --cpus-per-task=2 9 | #SBATCH --mail-type=BEGIN,END,FAIL 10 | #SBATCH --mail-user=minhvuong160620@gmail.com 11 | #SBATCH --output=LLMReasoningCert/slum/llm_generation/%x-%j.out 12 | #SBATCH --error=LLMReasoningCert/slum/llm_generation/%x-%j.err 13 | 14 | module load python/3.8.5 15 | 16 | ROOT_DIR='LLMReasoningCert/' 17 | export HUGGINGFACE_HUB_CACHE=$ROOT_DIR/huggingface 18 | export HF_HOME=$ROOT_DIR/huggingface 19 | 20 | source $ROOT_DIR/envs/vuongntm/bin/activate 21 | cd $ROOT_DIR/LLMReasonCert 22 | 23 | MODEL_NAME=$1 24 | MODE=$2 25 | DATASET=$3 26 | N_PROCESS=$4 27 | TEMP=$5 28 | TOPP=$6 29 | N_SEQ=$7 30 | EXP_NAME=$MODEL_NAME-$MODE-$DATASET-temp-$TEMP-p-$TOPP-consistency-$N_SEQ 31 | echo "=============================================================" 32 | echo "Running generation with temprature sampling" 33 | echo " * MODEL_NAME = "$MODEL_NAME 34 | echo " * MODE = "$MODE 35 | echo " * DATASET = "$DATASET 36 | echo " * EXP_NAME = "$EXP_NAME 37 | echo " * TEMPERATURE = "$TEMP 38 | echo " * TOPP = "$TOPP 39 | echo " * NUM_RETURN_SEQUENCE = "$N_SEQ 40 | echo "=============================================================" 41 | 42 | python llm_generation.py \ 43 | --exp_name $EXP_NAME \ 44 | --model_name $MODEL_NAME \ 45 | --mode $MODE \ 46 | --dataset $DATASET \ 47 | --n $N_PROCESS \ 48 | --temperature $TEMP \ 49 | --top_p $TOPP \ 50 | --num_return_sequences $N_SEQ 51 | -------------------------------------------------------------------------------- /generative_cert/scripts/fitcluster/README.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | ```bash 3 | module load anaconda/anaconda3 4 | conda create --p env python=3.10 5 | source activate ./env 6 | conda install pip 7 | 8 | module load cudnn/8.5.0.96-11.7-gcc-8.5.0-l5kw6yn 9 | module load cuda/11.7.0-gcc-8.5.0-xcmnp4n 10 | 11 | #python3 -m venv env 12 | #source env/bin/activate 13 | pip3 install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 14 | 15 | cd LLMReasonCert; pip install -r requirements.txt 16 | ``` 17 | 18 | # Experiments 19 | Setup environment: 20 | - Add `.env` file with your OPENAI key and HF_TOKEN. 
Please check the `.env.example` for an example 21 | ## LLMs 22 | We will evaluate on the following LLMs 23 | ``` 24 | - meta-llama/Llama-2-7b-chat-hf 25 | - meta-llama/Llama-2-13b-chat-hf 26 | - meta-llama/Llama-2-70b-chat-hf 27 | - mistralai/Mistral-7B-Instruct-v0.1 28 | - Qwen/Qwen-14B-Chat 29 | - lmsys/vicuna-33b-v1.3 30 | 31 | TODO: llama2-instruct 32 | ``` 33 | 34 | ## Submit jobs on Fitcluster 35 | Note the model_name for different models 36 | - For qwen family, set `MODEL_NAME=qwen` 37 | - For mistra family, set `MODEL_NAME=mistral` 38 | - For vicuna family, set `MODEL_NAME=vicuna` 39 | - For llama family, set `MODEL_NAME=llama` 40 | 41 | ### Fewshot-CoT 42 | ```bash 43 | MODEL_NAME=llama 44 | HF_MODEL_NAME=meta-llama/Llama-2-7b-chat-hf 45 | MODEL_PREFIX=llama2-7b 46 | for DATASET in cwq grail_qa; do 47 | sbatch --job-name=$MODEL_PREFIX-$DATASET run-fewshot-cot.sh $MODEL_NAME $HF_MODEL_NAME $DATASET 48 | done 49 | ``` 50 | 51 | ### Fewshot-CoT-with-hint 52 | ```bash 53 | MODEL_NAME=llama 54 | HF_MODEL_NAME=meta-llama/Llama-2-7b-chat-hf 55 | MODEL_PREFIX=llama2-7b 56 | for DATASET in cwq grail_qa; do 57 | sbatch --job-name=hint-$MODEL_PREFIX-$DATASET run-fewshot-cot-with-hint.sh $MODEL_NAME $HF_MODEL_NAME $DATASET 58 | done 59 | ``` 60 | -------------------------------------------------------------------------------- /llms/model_adapter.py: -------------------------------------------------------------------------------- 1 | from .base_hf_causal_model import HfCausalModel 2 | from .conv_prompt import * 3 | 4 | 5 | class Llama(HfCausalModel): 6 | def __init__(self, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | 9 | def prepare_model_prompt(self, query): 10 | """ 11 | Add model-specific prompt to the input 12 | """ 13 | conv = get_conv_template("llama-2") 14 | conv.append_message(conv.roles[0], query) 15 | conv.append_message(conv.roles[1], None) 16 | 17 | return conv.get_prompt() 18 | 19 | 20 | class Mistral(HfCausalModel): 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | 24 | def prepare_model_prompt(self, query): 25 | """ 26 | Add model-specific prompt to the input 27 | """ 28 | conv = get_conv_template("mistral") 29 | conv.append_message(conv.roles[0], query) 30 | conv.append_message(conv.roles[1], None) 31 | 32 | return conv.get_prompt() 33 | 34 | 35 | class Qwen(HfCausalModel): 36 | def __init__(self, *args, **kwargs): 37 | super().__init__(*args, **kwargs) 38 | 39 | def prepare_model_prompt(self, query): 40 | """ 41 | Add model-specific prompt to the input 42 | """ 43 | conv = get_conv_template("qwen") 44 | conv.append_message(conv.roles[0], query) 45 | conv.append_message(conv.roles[1], None) 46 | 47 | return conv.get_prompt() 48 | 49 | 50 | class Vicuna(HfCausalModel): 51 | def __init__(self, *args, **kwargs): 52 | super().__init__(*args, **kwargs) 53 | 54 | def prepare_model_prompt(self, query): 55 | """ 56 | Add model-specific prompt to the input 57 | """ 58 | conv = get_conv_template("vicuna_v1.1") 59 | conv.append_message(conv.roles[0], query) 60 | conv.append_message(conv.roles[1], None) 61 | 62 | return conv.get_prompt() 63 | -------------------------------------------------------------------------------- /generative_cert/scripts/fitcluster/sampling-sensitive/run_generate_temp_sampling.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llm_generate 3 | #SBATCH --time=5-00:00:00 4 | #SBATCH --nodes=1 # node count 5 | #SBATCH --ntasks=1 6 | #SBATCH 
--ntasks-per-node=1 # total number of tasks per node 7 | #SBATCH --mem-per-cpu=50000 8 | #SBATCH --cpus-per-task=2 9 | #SBATCH --gres=gpu:1 10 | #SBATCH --mail-type=BEGIN,END,FAIL 11 | #SBATCH --mail-user=trang.vu1@monash.edu 12 | #SBATCH --output=%x-%j.out 13 | #SBATCH --error=%x-%j.err 14 | 15 | ROOT_DIR=/data/others/trang/cert-reasoning 16 | module load anaconda/anaconda3 17 | module load cudnn/8.5.0.96-11.7-gcc-8.5.0-l5kw6yn 18 | module load cuda/11.7.0-gcc-8.5.0-xcmnp4n 19 | source activate $ROOT_DIR/env 20 | 21 | export TMPDIR=$ROOT_DIR/tmp 22 | export HF_HOME=$ROOT_DIR/huggingface 23 | export HUGGINGFACE_HUB_CACHE=$ROOT_DIR/huggingface 24 | 25 | 26 | SRC_DIR=$ROOT_DIR/LLMReasonCert 27 | DATA_DIR=$ROOT_DIR/data/{}/gpt-3.5-turbo/test/splitted_ground_truth_paths.json 28 | OUT_DIR=$ROOT_DIR/output 29 | HF_MODEL_NAME=$1 30 | DATASET=$2 31 | TEMP=$3 32 | EXP_NAME=sampling-temp-$TEMP 33 | PROMT_TEMPLATE_DIR=$SRC_DIR/srcs/conf/generation 34 | DEVICE=cuda 35 | 36 | echo "=============================================================" 37 | echo "Running generation with temprature sampling" 38 | echo " * DATA_DIR = "$DATA_DIR 39 | echo " * OUT_DIR = "$OUT_DIR 40 | echo " * HF_MODEL_NAME = "$HF_MODEL_NAME 41 | echo " * PROMT_TEMPLATE_DIR = "$PROMT_TEMPLATE_DIR 42 | echo " * DEVICE = "$DEVICE 43 | echo " * EXP_NAME = "$EXP_NAME 44 | echo " * TEMPERATURE = "$TEMP 45 | echo " * TOPK = "$TOPK 46 | echo " * TOPP = "$TOPP 47 | echo "=============================================================" 48 | 49 | python $SRC_DIR/srcs/main.py \ 50 | --HF_model_name $HF_MODEL_NAME \ 51 | --dataset $DATASET \ 52 | --in_dir $DATA_DIR \ 53 | --out_dir $OUT_DIR \ 54 | --prompt_template_dir $PROMT_TEMPLATE_DIR \ 55 | --device $DEVICE \ 56 | --temperature $TEMP --exp_name $EXP_NAME 57 | -------------------------------------------------------------------------------- /generative_cert/scripts/fitcluster/sampling-sensitive/run_generate_nucleus.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llm_generate 3 | #SBATCH --time=5-00:00:00 4 | #SBATCH --nodes=1 # node count 5 | #SBATCH --ntasks=1 6 | #SBATCH --ntasks-per-node=1 # total number of tasks per node 7 | #SBATCH --mem-per-cpu=50000 8 | #SBATCH --cpus-per-task=2 9 | #SBATCH --gres=gpu:1 10 | #SBATCH --mail-type=BEGIN,END,FAIL 11 | #SBATCH --mail-user=trang.vu1@monash.edu 12 | #SBATCH --output=%x-%j.out 13 | #SBATCH --error=%x-%j.err 14 | 15 | ROOT_DIR=/data/others/trang/cert-reasoning 16 | module load anaconda/anaconda3 17 | module load cudnn/8.5.0.96-11.7-gcc-8.5.0-l5kw6yn 18 | module load cuda/11.7.0-gcc-8.5.0-xcmnp4n 19 | source activate $ROOT_DIR/env 20 | 21 | export TMPDIR=$ROOT_DIR/tmp 22 | export HF_HOME=$ROOT_DIR/huggingface 23 | export HUGGINGFACE_HUB_CACHE=$ROOT_DIR/huggingface 24 | 25 | 26 | SRC_DIR=$ROOT_DIR/LLMReasonCert 27 | DATA_DIR=$ROOT_DIR/data/{}/gpt-3.5-turbo/test/splitted_ground_truth_paths.json 28 | OUT_DIR=$ROOT_DIR/output 29 | HF_MODEL_NAME=$1 30 | DATASET=$2 31 | TEMP=1.0 32 | TOPP=$3 33 | TOPK=50 34 | PROMT_TEMPLATE_DIR=$SRC_DIR/srcs/conf/generation 35 | DEVICE=cuda 36 | EXP_NAME=sampling-topk-$TOPK-topp-$TOPP 37 | 38 | echo "=============================================================" 39 | echo "Running generation with nucleus sampling" 40 | echo " * DATA_DIR = "$DATA_DIR 41 | echo " * OUT_DIR = "$OUT_DIR 42 | echo " * HF_MODEL_NAME = "$HF_MODEL_NAME 43 | echo " * PROMT_TEMPLATE_DIR = "$PROMT_TEMPLATE_DIR 44 | echo " * DEVICE = "$DEVICE 45 | echo " * EXP_NAME = 
"$EXP_NAME 46 | echo " * TEMPERATURE = "$TEMP 47 | echo " * TOPK = "$TOPK 48 | echo " * TOPP = "$TOPP 49 | echo "=============================================================" 50 | 51 | python $SRC_DIR/srcs/main.py \ 52 | --HF_model_name $HF_MODEL_NAME \ 53 | --dataset $DATASET \ 54 | --in_dir $DATA_DIR \ 55 | --out_dir $OUT_DIR \ 56 | --prompt_template_dir $PROMT_TEMPLATE_DIR \ 57 | --device $DEVICE \ 58 | --temperature $TEMP --exp_name $EXP_NAME --top_p $TOPP --top_k $TOPK -------------------------------------------------------------------------------- /preprocess_data/cwq_graph.py: -------------------------------------------------------------------------------- 1 | import json 2 | from preprocess_data.sparql import SparQL 3 | from tqdm import tqdm 4 | from datasets import load_dataset 5 | 6 | 7 | def get_topic_entities(dataset="cwq", split="test"): 8 | path = f"LLMReasoningCert/data/{dataset}/old_data/{split}.jsonl" 9 | res = {} 10 | with open(path) as f: 11 | for line in f: 12 | data = json.loads(line) 13 | res[data["id"]] = { 14 | "q_entity": data["q_entity"], 15 | "answer": data["answer"], 16 | "a_entity": data["a_entity"], 17 | } 18 | return res 19 | 20 | 21 | def get_data(dataset="cwq", split="ComplexWebQuestions_test", short_split="test"): 22 | old_data = get_topic_entities(dataset, short_split) 23 | path = f"LLMReasoningCert/data/{dataset}/raw/{split}.json" 24 | out_path = f"LLMReasoningCert/data/{dataset}/{short_split}.jsonl" 25 | with open(path) as f: 26 | data = json.load(f) 27 | sparql = SparQL("http://localhost:3001/sparql") 28 | with open(out_path, "w") as f: 29 | out_count = 0 30 | for sample in tqdm(data): 31 | id = sample["ID"] 32 | expand_info = old_data.get(id, None) 33 | if not expand_info: 34 | continue 35 | try: 36 | rng_path, processed_rng_path = sparql.query_reasoning_path( 37 | sample["sparql"] 38 | ) 39 | except Exception as e: 40 | print(e) 41 | continue 42 | dic = { 43 | "id": id, 44 | "question": sample["question"], 45 | "q_entity": expand_info["q_entity"], 46 | "a_entity": expand_info["a_entity"], 47 | "answer": expand_info["answer"], 48 | "graph": processed_rng_path, 49 | "raw_graph": rng_path, 50 | } 51 | f.write(json.dumps(dic) + "\n") 52 | out_count += 1 53 | print(f"CWQ: Wrote: {out_count}/{len(data)} samples") 54 | 55 | 56 | if __name__ == "__main__": 57 | get_data(dataset="cwq", split="ComplexWebQuestions_test", short_split="test") 58 | -------------------------------------------------------------------------------- /generative_cert/scripts/fitcluster/sampling-sensitive/run_generate_top_k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llm_generate 3 | #SBATCH --time=5-00:00:00 4 | #SBATCH --nodes=1 # node count 5 | #SBATCH --ntasks=1 6 | #SBATCH --ntasks-per-node=1 # total number of tasks per node 7 | #SBATCH --mem-per-cpu=50000 8 | #SBATCH --cpus-per-task=2 9 | #SBATCH --gres=gpu:1 10 | #SBATCH --mail-type=BEGIN,END,FAIL 11 | # SBATCH --mail-user=trang.vu1@monash.edu 12 | #SBATCH --output=%x-%j.out 13 | #SBATCH --error=%x-%j.err 14 | 15 | # ROOT_DIR=/data/others/trang/cert-reasoning 16 | ROOT_DIR=LLMReasoningCert 17 | # module load anaconda/anaconda3 18 | # module load cudnn/8.5.0.96-11.7-gcc-8.5.0-l5kw6yn 19 | # module load cuda/11.7.0-gcc-8.5.0-xcmnp4n 20 | module load python/3.8.5 21 | source $ROOT_DIR/envs/vuongntm 22 | 23 | export TMPDIR=$ROOT_DIR/tmp 24 | export HF_HOME=$ROOT_DIR/huggingface 25 | export HUGGINGFACE_HUB_CACHE=$ROOT_DIR/huggingface 26 | 27 | 28 | 
SRC_DIR=$ROOT_DIR/LLMReasonCert 29 | DATA_DIR=$ROOT_DIR/data/{}/gpt-3.5-turbo/test/splitted_ground_truth_paths.json 30 | OUT_DIR=$ROOT_DIR/output 31 | HF_MODEL_NAME=$1 32 | DATASET=$2 33 | TEMP=1.0 34 | TOPK=$3 35 | MODE=$4 36 | PROMT_TEMPLATE_DIR=$SRC_DIR/srcs/conf/generation 37 | DEVICE=cuda 38 | EXP_NAME=sampling-topk-$TOPK 39 | 40 | echo "=============================================================" 41 | echo "Running generation with topk sampling" 42 | echo " * DATA_DIR = "$DATA_DIR 43 | echo " * OUT_DIR = "$OUT_DIR 44 | echo " * HF_MODEL_NAME = "$HF_MODEL_NAME 45 | echo " * PROMT_TEMPLATE_DIR = "$PROMT_TEMPLATE_DIR 46 | echo " * MODE = "$MODE 47 | echo " * DEVICE = "$DEVICE 48 | echo " * EXP_NAME = "$EXP_NAME 49 | echo " * TEMPERATURE = "$TEMP 50 | echo " * TOPK = "$TOPK 51 | echo " * TOPP = "$TOPP 52 | echo "=============================================================" 53 | 54 | python $SRC_DIR/generative-cert/main.py \ 55 | --model_name $HF_MODEL_NAME \ 56 | --model_path $HF_MODEL_NAME \ 57 | --mode $MODE \ 58 | --dataset $DATASET \ 59 | --in_dir $DATA_DIR \ 60 | --out_dir $OUT_DIR \ 61 | --temperature $TEMP \ 62 | --exp_name $EXP_NAME \ 63 | --top_k $TOPK 64 | -------------------------------------------------------------------------------- /generative_cert/scripts/fitcluster/run-fewshot-cot-with-hint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=fewshot-cot 3 | #SBATCH --time=5-00:00:00 4 | #SBATCH --nodes=1 # node count 5 | #SBATCH --ntasks=1 6 | #SBATCH --ntasks-per-node=1 # total number of tasks per node 7 | #SBATCH --mem-per-cpu=50000 8 | #SBATCH --cpus-per-task=2 9 | #SBATCH --gres=gpu:1 10 | #SBATCH --mail-type=BEGIN,END,FAIL 11 | #SBATCH --mail-user=trang.vu1@monash.edu 12 | #SBATCH --output=%x-%j.out 13 | #SBATCH --error=%x-%j.err 14 | 15 | ROOT_DIR=/data/others/trang/cert-reasoning 16 | module load anaconda/anaconda3 17 | module load cudnn/8.5.0.96-11.7-gcc-8.5.0-l5kw6yn 18 | module load cuda/11.7.0-gcc-8.5.0-xcmnp4n 19 | source activate $ROOT_DIR/env 20 | 21 | export TMPDIR=$ROOT_DIR/tmp 22 | export HF_HOME=$ROOT_DIR/huggingface 23 | export HUGGINGFACE_HUB_CACHE=$ROOT_DIR/huggingface 24 | #export HF_DATASETS_OFFLINE=1 25 | #export TRANSFORMERS_OFFLINE=1 26 | 27 | 28 | SRC_DIR=$ROOT_DIR/LLMReasonCert 29 | DATA_DIR=$ROOT_DIR/data/{}/gpt-3.5-turbo/test/splitted_ground_truth_paths.json 30 | OUT_DIR=$ROOT_DIR/output 31 | MODEL_NAME=$1 32 | HF_MODEL_NAME=$2 33 | DATASET=$3 34 | TEMP=0.7 35 | TOPP=0.9 36 | EXP_NAME=cot-hint-temp-$TEMP-p-$TOPP 37 | NUM_RETURN_SEQUENCE=4 38 | MODE="fewshot-cot-hint" 39 | 40 | echo "=============================================================" 41 | echo "Running generation with temprature sampling" 42 | echo " * DATA_DIR = "$DATA_DIR 43 | echo " * DATASET = "$DATASET 44 | echo " * OUT_DIR = "$OUT_DIR 45 | echo " * HF_MODEL_NAME = "$HF_MODEL_NAME 46 | echo " * EXP_NAME = "$EXP_NAME 47 | echo " * TEMPERATURE = "$TEMP 48 | echo " * TOPP = "$TOPP 49 | echo " * NUM_RETURN_SEQUENCE = "$NUM_RETURN_SEQUENCE 50 | echo " * BATCH_SIZE = "$BATCH_SIZE 51 | echo " * MODE = "$MODE 52 | echo "=============================================================" 53 | 54 | python3 $SRC_DIR/llm_generation.py \ 55 | --model_name $MODEL_NAME --model_path $HF_MODEL_NAME \ 56 | --dataset $DATASET \ 57 | --in_dir $DATA_DIR \ 58 | --out_dir $OUT_DIR --mode $MODE \ 59 | --temperature $TEMP --top_p $TOPP \ 60 | --exp_name $EXP_NAME --num_return_sequences $NUM_RETURN_SEQUENCE 61 | 
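The `MODEL_NAME` argument handed to these fitcluster scripts (`llama`, `qwen`, `mistral`, `vicuna`, ...) is resolved by substring lookup in `llms/__init__.py`; a short sketch of that lookup, as used by `llm_generation.py`:

```python
from llms import get_registed_model

# Substring match against the registed_language_models dict:
# "llama" -> Llama, "qwen" -> Qwen, "mistral" -> Mistral,
# "vicuna" -> Vicuna, "gpt-3.5-turbo"/"gpt-4" -> ChatGPT.
LLM = get_registed_model("llama")
print(LLM.__name__)  # Llama
```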
-------------------------------------------------------------------------------- /generative_cert/scripts/fitcluster/run-fewshot-cot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=fewshot-cot 3 | #SBATCH --time=5-00:00:00 4 | #SBATCH --nodes=1 # node count 5 | #SBATCH --ntasks=1 6 | #SBATCH --ntasks-per-node=1 # total number of tasks per node 7 | #SBATCH --mem-per-cpu=50000 8 | #SBATCH --cpus-per-task=2 9 | #SBATCH --gres=gpu:1 10 | #SBATCH --mail-type=BEGIN,END,FAIL 11 | #SBATCH --mail-user=trang.vu1@monash.edu 12 | #SBATCH --output=%x-%j.out 13 | #SBATCH --error=%x-%j.err 14 | 15 | ROOT_DIR=/data/others/trang/cert-reasoning 16 | module load anaconda/anaconda3 17 | module load cudnn/8.5.0.96-11.7-gcc-8.5.0-l5kw6yn 18 | module load cuda/11.7.0-gcc-8.5.0-xcmnp4n 19 | source activate $ROOT_DIR/env 20 | 21 | export TMPDIR=$ROOT_DIR/tmp 22 | #export HF_HOME=/data/others/lluo/huggingface/ 23 | export HF_HOME=$ROOT_DIR/huggingface 24 | export HUGGINGFACE_HUB_CACHE=$ROOT_DIR/huggingface 25 | #export HF_DATASETS_OFFLINE=1 26 | #export TRANSFORMERS_OFFLINE=1 27 | 28 | SRC_DIR=$ROOT_DIR/LLMReasonCert 29 | DATA_DIR=$ROOT_DIR/data/{}/gpt-3.5-turbo/test/splitted_ground_truth_paths.json 30 | OUT_DIR=$ROOT_DIR/output 31 | MODEL_NAME=$1 32 | HF_MODEL_NAME=$2 33 | DATASET=$3 34 | TEMP=0.7 35 | TOPP=0.9 36 | EXP_NAME=cot-temp-$TEMP-p-$TOPP 37 | NUM_RETURN_SEQUENCE=4 38 | MODE="fewshot-cot-only" 39 | 40 | echo "=============================================================" 41 | echo "Running generation with temprature sampling" 42 | echo " * DATA_DIR = "$DATA_DIR 43 | echo " * DATASET = "$DATASET 44 | echo " * OUT_DIR = "$OUT_DIR 45 | echo " * MODEL_NAME = "$MODEL_NAME 46 | echo " * HF_MODEL_NAME = "$HF_MODEL_NAME 47 | echo " * EXP_NAME = "$EXP_NAME 48 | echo " * TEMPERATURE = "$TEMP 49 | echo " * TOPP = "$TOPP 50 | echo " * NUM_RETURN_SEQUENCE = "$NUM_RETURN_SEQUENCE 51 | echo " * BATCH_SIZE = "$BATCH_SIZE 52 | echo " * MODE = "$MODE 53 | echo "=============================================================" 54 | 55 | python3 $SRC_DIR/llm_generation.py \ 56 | --model_name $MODEL_NAME --model_path $HF_MODEL_NAME \ 57 | --dataset $DATASET --mode $MODE \ 58 | --in_dir $DATA_DIR \ 59 | --out_dir $OUT_DIR \ 60 | --temperature $TEMP --top_p $TOPP \ 61 | --exp_name $EXP_NAME --num_return_sequences $NUM_RETURN_SEQUENCE 62 | -------------------------------------------------------------------------------- /generative_cert/scripts/generation/template.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llm_generate 3 | #SBATCH --time=2-00:00:00 4 | #SBATCH --nodes=1 # node count 5 | # SBATCH --ntasks=1 6 | # SBATCH --ntasks-per-node=1 # total number of tasks per node 7 | #SBATCH --mem-per-cpu=50000 8 | #SBATCH --cpus-per-task=2 9 | #SBATCH --partition=gpu 10 | #SBATCH --gres=gpu:1 # number of gpus 11 | # SBATCH --qos=normal 12 | #SBATCH --mail-type=BEGIN,END,FAIL 13 | # SBATCH --mail-user=minhvuong160620@gmail.com 14 | #SBATCH --output=%x-%j.out 15 | #SBATCH --error=%x-%j.err 16 | 17 | module load python/3.8.5 18 | 19 | ROOT_DIR=LLMReasoningCert/ 20 | export HUGGINGFACE_HUB_CACHE=$ROOT_DIR/huggingface 21 | export HF_HOME=$ROOT_DIR/huggingface 22 | export PYTHONPATH="${PYTHONPATH}:LLMReasoningCert/LLMReasonCert" 23 | 24 | source $ROOT_DIR/envs/vuongntm/bin/activate 25 | cd $ROOT_DIR/LLMReasonCert 26 | 27 | HF_MODEL_NAME=$1 28 | DATASET=$3 29 | 
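# Positional arguments (see generative_cert/README.md for an example invocation:
#   sbatch generative_cert/scripts/generation/template.sh mistralai/Mistral-7B-Instruct-v0.1 fewshot-cot-only cwq 50 1 4):
# $1 = HF model name, $2 = prompt mode, $3 = dataset, $4 = top-k,
# $5 = number of returned sequences, $6 = batch size.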
DATA_DIR=LLMReasoningCert/data/{}/gpt-3.5-turbo/test/splitted_ground_truth_paths.json 30 | OUT_DIR=LLMReasoningCert/data 31 | PROMT_TEMPLATE_DIR=srcs/conf/generation 32 | DEVICE=cuda 33 | TEMP=1.0 34 | TOPK=$4 35 | MODE=$2 36 | NUM_RETURN_SEQUENCE=$5 37 | BATCH_SIZE=$6 38 | EXP_NAME=sampling-topk-$TOPK 39 | 40 | 41 | echo "=============================================================" 42 | echo "Running generation with topk sampling" 43 | echo " * DATA_DIR = "$DATA_DIR 44 | echo " * OUT_DIR = "$OUT_DIR 45 | echo " * HF_MODEL_NAME = "$HF_MODEL_NAME 46 | echo " * PROMT_TEMPLATE_DIR = "$PROMT_TEMPLATE_DIR 47 | echo " * MODE = "$MODE 48 | echo " * DEVICE = "$DEVICE 49 | echo " * EXP_NAME = "$EXP_NAME 50 | echo " * TEMPERATURE = "$TEMP 51 | echo " * TOPK = "$TOPK 52 | echo " * TOPP = "$TOPP 53 | echo " * NUM_RETURN_SEQUENCE = "$NUM_RETURN_SEQUENCE 54 | echo " * BATCH_SIZE = "$BATCH_SIZE 55 | echo "=============================================================" 56 | 57 | 58 | python llm_generation.py \ 59 | --model_name $HF_MODEL_NAME \ 60 | --dataset $DATASET \ 61 | --in_dir $DATA_DIR \ 62 | --out_dir $OUT_DIR \ 63 | --temperature $TEMP \ 64 | --exp_name $EXP_NAME \ 65 | # --top_k $TOPK \ 66 | --mode $MODE \ 67 | # --model_path $HF_MODEL_NAME \ 68 | # --dtype fp16 \ 69 | # --quant none \ 70 | --num_return_sequences $NUM_RETURN_SEQUENCE \ 71 | # --batch_size $BATCH_SIZE \ 72 | --run_sample -------------------------------------------------------------------------------- /preprocess_data/splitted_ground_truth_paths.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from itertools import chain 4 | from generative_cert.utils.utils import is_unknown_ent 5 | 6 | 7 | def count_min_hop_path(ground_truth_paths): 8 | min_hop = min(map(len, ground_truth_paths)) 9 | return min_hop 10 | 11 | 12 | def rearrange_paths(inp_f, out_f): 13 | with open(out_f, "w") as fout: 14 | dic = {"min_1hop": [], "min_2hop": [], "min_multihop": []} 15 | with open(inp_f, "r") as fin: 16 | for line in fin: 17 | res = json.loads(line) 18 | if count_min_hop_path(res["ground_truth_paths"]) < 2: 19 | dic["min_1hop"].append(res) 20 | elif count_min_hop_path(res["ground_truth_paths"]) < 3: 21 | dic["min_2hop"].append(res) 22 | else: 23 | dic["min_multihop"].append(res) 24 | fout.write(json.dumps(dic, indent=4)) 25 | 26 | 27 | def count_hop_cate(inp_f, out_f): 28 | with open(out_f, "w") as fout: 29 | dic = {} 30 | with open(inp_f, "r") as fin: 31 | for line in fin: 32 | res = json.loads(line) 33 | if count_min_hop_path(res["ground_truth_paths"]) < 2: 34 | dic["min_1hop"].append(res) 35 | elif count_min_hop_path(res["ground_truth_paths"]) < 3: 36 | dic["min_2hop"].append(res) 37 | else: 38 | dic["min_multihop"].append(res) 39 | fout.write(json.dumps(dic, indent=4)) 40 | 41 | 42 | if __name__ == "__main__": 43 | d_l = ["grail_qa", "cwq"] 44 | split_l = ["test"] 45 | 46 | data_path = "LLMReasoningCert/data/" 47 | ground_truth_paths_file = os.path.join( 48 | data_path, "{}/gpt-3.5-turbo/{}/ground_truth_paths.jsonl" 49 | ) 50 | ground_truth_multi_paths_file = os.path.join( 51 | data_path, "{}/gpt-3.5-turbo/{}/splitted_ground_truth_paths.json" 52 | ) 53 | 54 | dic = {} 55 | for d in d_l: 56 | for split in split_l: 57 | print("Handling data {} and split {}: ".format(d, split)) 58 | rearrange_paths( 59 | ground_truth_paths_file.format(d, split), 60 | ground_truth_multi_paths_file.format(d, split), 61 | ) 62 | with open(ground_truth_multi_paths_file.format(d, split), "r") as fin: 
63 | res = json.load(fin) 64 | dic[f"{d}_{split}"] = { 65 | "1hop": len(res["min_1hop"]), 66 | "2hop": len(res["min_2hop"]), 67 | "multihop": len(res["min_multihop"]), 68 | } 69 | with open(os.path.join(data_path, "stat.json"), "w") as fout: 70 | fout.write(json.dumps(dic, indent=2)) 71 | -------------------------------------------------------------------------------- /generative_cert/utils/data_analytic.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import os 4 | 5 | 6 | def check_type_question(dir, filename): 7 | data = [] 8 | for path in filename: 9 | inp_path = os.path.join(dir, path) 10 | sample = json.load(inp_path) 11 | samples = sample["min_2hop"] + sample["min_multihop"] 12 | for ques in sample: 13 | data.append(ques["question"]) 14 | 15 | res_dic = { 16 | "atleast_atmost_which": [], 17 | "atleast_atmost_count": [], 18 | "more_less_which": [], 19 | "more_less_count": [], 20 | "bool": [], 21 | "min_max": [], 22 | "count": [], 23 | "logical": [], 24 | "unknown": [], 25 | } 26 | for ques in data: 27 | ques = ques.lower() 28 | if " more " in ques or " less " in ques: 29 | if "which" in ques: 30 | res_dic["more_less_which"].append(ques) 31 | elif "how many" in ques: 32 | res_dic["more_less_count"].append(ques) 33 | else: 34 | res_dic["unknown"].append(ques) 35 | elif " at least " in ques or " at most " in ques: 36 | if "which" in ques: 37 | res_dic["atleast_atmost_which"].append(ques) 38 | elif "how many" in ques: 39 | res_dic["atleast_atmost_count"].append(ques) 40 | else: 41 | res_dic["unknown"].append(ques) 42 | elif re.search("^(does |do |is |are )", ques): 43 | res_dic["bool"].append(ques) 44 | elif re.search("minimum|maximum|largest|smallest", ques): 45 | res_dic["min_max"].append(ques) 46 | elif "how many" in ques: 47 | res_dic["count"].append(ques) 48 | elif re.search("what|which", ques) in ques and re.search( 49 | " and | or | not ", ques 50 | ): 51 | res_dic["logical"].append(ques) 52 | else: 53 | res_dic["unknown"].append(ques) 54 | with open(os.path.join(dir, "check_type.json"), "w") as fout: 55 | json.dump(res_dic, fout, indent=4) 56 | 57 | 58 | def count_hop_grailQA( 59 | cache_dir="LLMReasoningCert/data", 60 | ): 61 | data = load_dataset("grail_qa", cache_dir=cache_dir) 62 | dic = {} 63 | for split in data: 64 | if split == "test": 65 | continue 66 | dic[split] = {} 67 | for v in data[split]["num_edge"]: 68 | if v not in dic[split]: 69 | dic[split][v] = 1 70 | else: 71 | dic[split][v] += 1 72 | return dic 73 | 74 | 75 | def count_hop_fbQA(dat_path): 76 | with open(dat_path) as fin: 77 | data = json.load(fin) 78 | data = data["Questions"] 79 | filter_domain = set() 80 | hop_dic = {} 81 | for ques in data: 82 | min_num_hop = min( 83 | [len(parse["InferentialChain"].split("..")) for parse in ques["Parses"]] 84 | ) 85 | if min_num_hop not in hop_dic: 86 | hop_dic[min_num_hop] = [ques["Question-ID"]] 87 | else: 88 | hop_dic[min_num_hop].append(ques["Question-ID"]) 89 | return {k: len(v) for k, v in hop_dic.items()} 90 | 91 | 92 | if __name__ == "__main__": 93 | dir = "LLMReasoningCert/data/cwq/" 94 | check_type_question( 95 | dir, 96 | filename=[ 97 | "gpt-3.5-turbo/train/splitted_ground_truth_paths.json", 98 | "gpt-3.5-turbo/test/splitted_ground_truth_paths.json", 99 | "gpt-3.5-turbo/dev/splitted_ground_truth_paths.json", 100 | ], 101 | ) 102 | -------------------------------------------------------------------------------- /discriminative_cert/scripts/summary_results.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import glob 4 | import numpy as np 5 | 6 | result_path = "new_results_w_ans" 7 | 8 | for data_path in glob.glob(os.path.join(result_path, "*")): 9 | # print(data_path) 10 | data_name = os.path.basename(data_path) 11 | eval_neg = False 12 | if "neg" in data_name: 13 | eval_neg = True 14 | print(f"{data_name} zero-shot zero-shot-cot few-shot few-shot-cot") 15 | # print(os.path.basename(data_path)) 16 | for model_path in glob.glob(os.path.join(data_path, "*")): 17 | # print(model_path) 18 | model_name = os.path.basename(model_path) 19 | result_dict = { 20 | "zero-shot": None, 21 | "zero-shot-cot": None, 22 | "few-shot": None, 23 | "few-shot-cot": None, 24 | } 25 | for prediction_path in glob.glob(os.path.join(model_path, "*.jsonl")): 26 | total = 0 27 | with open(prediction_path, "r") as f: 28 | lines = f.readlines() 29 | all_result = [] 30 | for line in lines: 31 | try: 32 | data = json.loads(line.strip()) 33 | except: 34 | # print("Error in parsing line: ", line) 35 | # print(prediction_path) 36 | continue 37 | # exit() 38 | result_list = [] 39 | total += 1 40 | for r in data["details"]: 41 | response = r["raw_response"] 42 | prediction = 0 43 | if eval_neg: 44 | if ( 45 | "NO" in response.upper() 46 | and "YES" not in response.upper() 47 | ): 48 | prediction = 1 49 | else: 50 | if ( 51 | "YES" in response.upper() 52 | and "NO" not in response.upper() 53 | ): 54 | prediction = 1 55 | result_list.append(prediction) 56 | all_result.append(np.mean(result_list)) 57 | avg_result = np.mean(all_result) 58 | for key in result_dict.keys(): 59 | result_key = ( 60 | os.path.basename(prediction_path).split("_")[1].split(".")[0] 61 | ) 62 | if key == result_key: 63 | result_dict[key] = f"{avg_result:.4f}" 64 | # result_dict[key] = f"{avg_result:.4f} ({total})" 65 | # for result_path in glob.glob(os.path.join(model_path, "*.txt")): 66 | # # print(result_path) 67 | # with open(result_path, "r") as f: 68 | # lines = f.readlines() 69 | # for line in lines: 70 | # if "Accuracy" in line: 71 | # auc = line.split(" ")[-1].strip() 72 | # for key in result_dict.keys(): 73 | # if key in result_path: 74 | # result_dict[key] = auc 75 | print( 76 | f"{model_name} {result_dict['zero-shot']} {result_dict['zero-shot-cot']} {result_dict['few-shot']} {result_dict['few-shot-cot']}" 77 | ) 78 | # print("{} {} {} {} {} {} {}".format(model_name, result_dict["zero-shot"], result_dict["zero-shot-cot"], result_dict["few-shot"], result_dict["few-shot-cot"], result_dict["neg-few-shot"], result_dict["neg-few-shot-cot"])) 79 | print("-" * 50) 80 | -------------------------------------------------------------------------------- /scripts/disc-cert/summary_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import glob 4 | import numpy as np 5 | 6 | 7 | result_path = "new_results_w_ans" 8 | 9 | for data_path in glob.glob(os.path.join(result_path, "*")): 10 | # print(data_path) 11 | data_name = os.path.basename(data_path) 12 | eval_neg = False 13 | if "neg" in data_name: 14 | eval_neg = True 15 | print(f"{data_name} zero-shot zero-shot-cot few-shot few-shot-cot") 16 | # print(os.path.basename(data_path)) 17 | for model_path in glob.glob(os.path.join(data_path, "*")): 18 | # print(model_path) 19 | model_name = os.path.basename(model_path) 20 | result_dict = { 21 | "zero-shot": None, 22 | "zero-shot-cot": None, 23 | "few-shot": None, 24 | 
"few-shot-cot": None, 25 | } 26 | for prediction_path in glob.glob(os.path.join(model_path, "*.jsonl")): 27 | total = 0 28 | with open(prediction_path, "r") as f: 29 | lines = f.readlines() 30 | all_result = [] 31 | for line in lines: 32 | try: 33 | data = json.loads(line.strip()) 34 | except: 35 | # print("Error in parsing line: ", line) 36 | # print(prediction_path) 37 | continue 38 | # exit() 39 | result_list = [] 40 | total += 1 41 | for r in data["details"]: 42 | response = r["raw_response"] 43 | prediction = 0 44 | if eval_neg: 45 | if ( 46 | "NO" in response.upper() 47 | and "YES" not in response.upper() 48 | ): 49 | prediction = 1 50 | else: 51 | if ( 52 | "YES" in response.upper() 53 | and "NO" not in response.upper() 54 | ): 55 | prediction = 1 56 | result_list.append(prediction) 57 | all_result.append(np.mean(result_list)) 58 | avg_result = np.mean(all_result) 59 | for key in result_dict.keys(): 60 | result_key = ( 61 | os.path.basename(prediction_path).split("_")[1].split(".")[0] 62 | ) 63 | if key == result_key: 64 | result_dict[key] = f"{avg_result:.4f}" 65 | # result_dict[key] = f"{avg_result:.4f} ({total})" 66 | # for result_path in glob.glob(os.path.join(model_path, "*.txt")): 67 | # # print(result_path) 68 | # with open(result_path, "r") as f: 69 | # lines = f.readlines() 70 | # for line in lines: 71 | # if "Accuracy" in line: 72 | # auc = line.split(" ")[-1].strip() 73 | # for key in result_dict.keys(): 74 | # if key in result_path: 75 | # result_dict[key] = auc 76 | print( 77 | f"{model_name} {result_dict['zero-shot']} {result_dict['zero-shot-cot']} {result_dict['few-shot']} {result_dict['few-shot-cot']}" 78 | ) 79 | # print("{} {} {} {} {} {} {}".format(model_name, result_dict["zero-shot"], result_dict["zero-shot-cot"], result_dict["few-shot"], result_dict["few-shot-cot"], result_dict["neg-few-shot"], result_dict["neg-few-shot-cot"])) 80 | print("-" * 50) 81 | -------------------------------------------------------------------------------- /llm_generation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import pandas as pd 5 | from tqdm import tqdm 6 | import torch 7 | from datasets import load_dataset, Dataset 8 | from generative_cert.prompts import * 9 | from llms import get_registed_model 10 | 11 | 12 | def prompt_builder(question, hint, mode="zero-shot"): 13 | if mode == "fewshot-cot-hint": 14 | query = FEWSHOT_COT_HINT.format(question=question) 15 | elif mode == "fewshot-cot-hint-ground": 16 | query = FEWSHOT_COT_HINT_GROUND.format(question=question, hint=hint) 17 | else: 18 | query = FEWSHOT_COT_ONLY.format(question=question) 19 | return query 20 | 21 | 22 | def load_data(path, dataset): 23 | in_path = path.format(dataset) 24 | with open(in_path, "r") as fin: 25 | results = json.load(fin) 26 | # ignore_1hop: 27 | results = results["min_2hop"] + results["min_multihop"] 28 | dat = pd.DataFrame(results)[["id", "question", "ground_truth_paths"]] 29 | return dat 30 | 31 | 32 | def write_results(out_dir, file_name): 33 | if not os.path.exists(out_dir): 34 | os.makedirs(out_dir) 35 | out_path = os.path.join(out_dir, file_name) 36 | fout = open(out_path, "w") 37 | return fout 38 | 39 | 40 | def predict(data, args, model): 41 | data_id, row = data 42 | question = row["question"] 43 | hint = " -> ".join([tri[1] for tri in row["ground_truth_paths"][0]]) 44 | 45 | query = prompt_builder(question, hint, args.mode) 46 | query = model.prepare_model_prompt(query) 47 | response = 
model.generate_sentence(query) 48 | row["reasoning_ans"] = response 49 | return row 50 | 51 | 52 | def main(args, LLM): 53 | model = LLM(args) 54 | print("Prepare pipeline for inference...") 55 | model.prepare_for_inference() 56 | if args.dataset == "all": 57 | dataset = ["cwq", "FreebaseQA"] 58 | else: 59 | dataset = [args.dataset] 60 | for d in dataset: 61 | out_dir, file_name = ( 62 | os.path.join(args.out_dir, d, args.model_name, args.exp_name), 63 | "llm_prompt_response.jsonl", 64 | ) 65 | fout = write_results(out_dir, file_name) 66 | ques_dat = load_data(args.in_dir, d) 67 | if args.run_sample: 68 | ques_dat = ques_dat.iloc[:3] 69 | with tqdm(ques_dat.iterrows(), total=len(ques_dat)) as phar: 70 | for data in phar: 71 | r = predict(data, args, model) 72 | fout.write(json.dumps(r.to_dict()) + "\n") 73 | fout.close() 74 | 75 | 76 | if __name__ == "__main__": 77 | argparser = argparse.ArgumentParser() 78 | argparser.add_argument("--model_name", help="LLM", type=str, required=True) 79 | argparser.add_argument( 80 | "--mode", 81 | help="LLM", 82 | type=str, 83 | choices=["fewshot-cot-only", "fewshot-cot-hint", "fewshot-cot-hint-ground"], 84 | default="fewshot-cot-only", 85 | required=True, 86 | ) 87 | argparser.add_argument( 88 | "--dataset", 89 | help="dataset name", 90 | type=str, 91 | choices=["cwq", "FreebaseQA", "grail_qa", "all"], 92 | required=True, 93 | ) 94 | argparser.add_argument( 95 | "--in_dir", 96 | help="directory containing question data", 97 | type=str, 98 | default="LLMReasoningCert/data/{}/gpt-3.5-turbo/test/splitted_ground_truth_paths.json", 99 | ) 100 | argparser.add_argument( 101 | "--out_dir", 102 | help="directory containing answer output data", 103 | type=str, 104 | default="LLMReasoningCert/data", 105 | ) 106 | argparser.add_argument( 107 | "--run_sample", help="only run 3 samples", action="store_true" 108 | ) 109 | argparser.add_argument("--exp_name", type=str, default="", help="Experiment name") 110 | 111 | args, _ = argparser.parse_known_args() 112 | LLM = get_registed_model(args.model_name) 113 | LLM.add_args(argparser) 114 | args = argparser.parse_args() 115 | 116 | main(args, LLM) 117 | -------------------------------------------------------------------------------- /extract_subgraph/graph_loader.py: -------------------------------------------------------------------------------- 1 | from generative_cert.utils.utils import * 2 | import networkx as nx 3 | import random 4 | 5 | PROMPT = """Given this reasoning path, do you think this is a valid path to answer the question? If yes please answer "YES", otherwise please answer "NO". 6 | 7 | Reasoning path: 8 | {path} 9 | 10 | Question: 11 | {question} 12 | """ 13 | 14 | 15 | class GraphProcess(object): 16 | def __init__(self, data, args) -> None: 17 | self.args = args 18 | self.model = args.model_name 19 | self.graph = None 20 | self.chat_log = [] 21 | self.memory = [] 22 | # Init grpah and memory 23 | self.graph = build_graph(data["graph"]) 24 | question = data["question"] 25 | if not question.endswith("?"): 26 | question += "?" 
27 | 28 | self.data = data 29 | self.question = question 30 | 31 | def log(self, query, response): 32 | self.chat_log.append({"query": query, "response": response}) 33 | 34 | def get_truth_paths(self): 35 | entities = self.data["q_entity"] 36 | answer_entities = self.data["a_entity"] 37 | 38 | # Select paths 39 | paths = [] 40 | for h in entities: 41 | if h not in self.graph: 42 | continue 43 | for t in answer_entities: 44 | if t not in self.graph: 45 | continue 46 | try: 47 | for p in nx.all_shortest_paths(self.graph, h, t): 48 | paths.append(p) 49 | except: 50 | pass 51 | # Add relation to paths 52 | result_paths = [] 53 | for p in paths: 54 | tmp = [] 55 | for i in range(len(p) - 1): 56 | u = p[i] 57 | v = p[i + 1] 58 | tmp.append((u, self.graph[u][v]["relation"], v)) 59 | result_paths.append(tmp) 60 | return result_paths 61 | 62 | # def get_neg_paths(self, truth_paths): 63 | # ''' 64 | # Get negative paths 65 | # ''' 66 | # entities = self.data['q_entity'] 67 | # answer_entities = self.data['a_entity'] 68 | # neg_paths = [] 69 | # for truth_path in truth_paths: 70 | # l = len(truth_path) 71 | # for _ in range(self.args.n_neg): 72 | 73 | # for _ in range(self.args.n_neg): 74 | # # Get random entity pair 75 | # h = random.choice(list(self.graph.nodes)) 76 | # t = random.choice(list(self.graph.nodes)) 77 | # if h == t: 78 | # continue 79 | # try: 80 | # for p in nx.all_shortest_paths(self.graph, h, t): 81 | # neg_paths.append(p) 82 | # except: 83 | # pass 84 | # # Add relation to paths 85 | # result_paths = [] 86 | # for p in neg_paths: 87 | # tmp = [] 88 | # for i in range(len(p)-1): 89 | # u = p[i] 90 | # v = p[i+1] 91 | # tmp.append((u, self.graph[u][v]['relation'], v)) 92 | # result_paths.append(tmp) 93 | # return result_paths 94 | 95 | # def get_evaluate_data(self): 96 | # truth_paths = self.get_truth_paths() 97 | # if len(truth_paths) > self.args.n_pos: 98 | # truth_paths = random.sample(truth_paths, self.args.n_pos) 99 | # neg_paths = self.get_neg_paths(truth_paths) 100 | 101 | def predict_path(self, path): 102 | """ 103 | Check path validity and return the answer 104 | 105 | Args: 106 | path : _description_ 107 | """ 108 | results = {} 109 | for p in path: 110 | path_string = path_to_string(p) 111 | prompt = PROMPT.format(path=path_string, question=self.question) 112 | response = query(prompt) 113 | if "YES" in response: 114 | results[path_string] = 1 115 | else: 116 | results[path_string] = 0 117 | return results 118 | -------------------------------------------------------------------------------- /llms/base_hf_causal_model.py: -------------------------------------------------------------------------------- 1 | from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM 2 | import torch 3 | from .base_language_model import BaseLanguageModel 4 | import os 5 | import dotenv 6 | import distutils.util 7 | 8 | dotenv.load_dotenv() 9 | 10 | HF_TOKEN = os.getenv("HF_TOKEN") 11 | 12 | 13 | class HfCausalModel(BaseLanguageModel): 14 | DTYPE = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} 15 | 16 | @staticmethod 17 | def add_args(parser): 18 | parser.add_argument( 19 | "--model_path", type=str, help="HUGGING FACE MODEL or model path" 20 | ) 21 | parser.add_argument( 22 | "--max_new_tokens", type=int, help="max length", default=512 23 | ) 24 | parser.add_argument( 25 | "--top_k", default=None, type=int, help="Generate params: top-k sampling" 26 | ) 27 | parser.add_argument( 28 | "--top_p", default=1.0, type=float, help="Generate params: nucleus sampling" 
29 | ) 30 | parser.add_argument( 31 | "--num_return_sequences", 32 | default=1, 33 | type=int, 34 | help="The number of returned sequences", 35 | ) 36 | parser.add_argument( 37 | "--batch_size", default=1, type=int, help="batch size for pipeline" 38 | ) 39 | parser.add_argument( 40 | "--temperature", 41 | type=float, 42 | default=0.7, 43 | help="Generate params: temperature", 44 | ) 45 | parser.add_argument( 46 | "--flash_atten_2", 47 | default=False, 48 | type=lambda x: bool(distutils.util.strtobool(x)), 49 | help="enable flash attention 2", 50 | ) 51 | 52 | parser.add_argument("--dtype", choices=["fp32", "fp16", "bf16"], default="fp16") 53 | parser.add_argument("--quant", choices=["none", "4bit", "8bit"], default="none") 54 | 55 | def __init__(self, args): 56 | self.args = args 57 | 58 | def prepare_for_inference(self): 59 | self.tokenizer = AutoTokenizer.from_pretrained( 60 | self.args.model_path, token=HF_TOKEN, trust_remote_code=True, use_fast=False 61 | ) 62 | model = AutoModelForCausalLM.from_pretrained( 63 | self.args.model_path, 64 | device_map="auto", 65 | token=HF_TOKEN, 66 | torch_dtype=self.DTYPE.get(self.args.dtype, None), 67 | load_in_8bit=self.args.quant == "8bit", 68 | load_in_4bit=self.args.quant == "4bit", 69 | trust_remote_code=True, 70 | use_flash_attention_2=self.args.flash_atten_2, 71 | ) 72 | self.generator = pipeline( 73 | "text-generation", model=model, tokenizer=self.tokenizer 74 | ) 75 | # self.generator.tokenizer.pad_token_id = self.generator.model.config.eos_token_id 76 | 77 | @torch.inference_mode() 78 | def generate_sentence(self, llm_input): 79 | if self.args.top_k: 80 | outputs = self.generator( 81 | llm_input, 82 | return_full_text=False, 83 | max_new_tokens=self.args.max_new_tokens, 84 | do_sample=True, 85 | temperature=self.args.temperature, 86 | top_k=self.args.top_k, 87 | top_p=self.args.top_p, 88 | num_return_sequences=self.args.num_return_sequences, 89 | pad_token_id=self.tokenizer.eos_token_id, 90 | batch_size=self.args.batch_size, 91 | ) 92 | else: 93 | outputs = self.generator( 94 | llm_input, 95 | return_full_text=False, 96 | max_new_tokens=self.args.max_new_tokens, 97 | num_return_sequences=self.args.num_return_sequences, 98 | pad_token_id=self.tokenizer.eos_token_id, 99 | batch_size=self.args.batch_size, 100 | ) 101 | if self.args.num_return_sequences == 1: 102 | return outputs[0]["generated_text"] # type: ignore 103 | return [out["generated_text"] for out in outputs] 104 | -------------------------------------------------------------------------------- /extract_subgraph/deal_cvt.py: -------------------------------------------------------------------------------- 1 | # https://github.com/microsoft/FastRDFStore/blob/master/FastRDFStore/FastRDFStore.cs 2 | # code reference 3 | from struct import * 4 | 5 | 6 | class BinaryStream: 7 | def __init__(self, base_stream): 8 | self.base_stream = base_stream 9 | 10 | def readByte(self): 11 | return self.base_stream.read(1) 12 | 13 | def readBytes(self, length): 14 | return self.base_stream.read(length) 15 | 16 | def readChar(self): 17 | return self.unpack("b") 18 | 19 | def readUChar(self): 20 | return self.unpack("B") 21 | 22 | def readBool(self): 23 | return self.unpack("?") 24 | 25 | def readInt16(self): 26 | return self.unpack("h", 2) 27 | 28 | def readUInt16(self): 29 | return self.unpack("H", 2) 30 | 31 | def readInt32(self): 32 | return self.unpack("i", 4) 33 | 34 | def readUInt32(self): 35 | return self.unpack("I", 4) 36 | 37 | def readInt64(self): 38 | return self.unpack("q", 8) 39 | 40 | def 
readUInt64(self): 41 | return self.unpack("Q", 8) 42 | 43 | def readFloat(self): 44 | return self.unpack("f", 4) 45 | 46 | def readDouble(self): 47 | return self.unpack("d", 8) 48 | 49 | def readString(self): 50 | # length = self.readUInt16() 51 | length = self.unpack(" 3: 105 | return subject[0:4] 106 | elif len(subject) > 2: 107 | return subject[0:3] 108 | else: 109 | return subject[0:2] 110 | else: 111 | if len(subject) > 1: 112 | return subject[0:2] 113 | return subject[0:1] 114 | 115 | 116 | def is_cvt(subject, cvt_nodes): 117 | tp_key = get_key(subject) 118 | if tp_key in cvt_nodes: 119 | if subject in cvt_nodes[tp_key]: 120 | return cvt_nodes[tp_key][subject] 121 | return False 122 | 123 | 124 | def load_cvt(): 125 | filename = "LLMReasoningCert/data/data/cvtnodes.bin" 126 | f = open(filename, "rb") 127 | reader = BinaryStream(f) 128 | dictionariesCount = reader.readInt32() 129 | # print(dictionariesCount) 130 | to_return = {} 131 | for i in range(0, dictionariesCount): 132 | key = bytes.decode(reader.readString()) 133 | # covert byte to string 134 | count = reader.readInt32() 135 | # print(key, count) 136 | dict_tp = {} 137 | for j in range(0, count): 138 | mid = bytes.decode(reader.readString()) 139 | isCVT = reader.readBool() 140 | dict_tp[mid] = isCVT 141 | to_return[key] = dict_tp 142 | return to_return 143 | 144 | 145 | if __name__ == "__main__": 146 | to_return = load_cvt() 147 | tp_dict = to_return["m.01"] 148 | num = 0 149 | for item in tp_dict: 150 | print(item, tp_dict[item], type(tp_dict[item])) 151 | num += 1 152 | if num == 10: 153 | exit(-1) 154 | -------------------------------------------------------------------------------- /preprocess_data/subgraph_discriminative_cert.py: -------------------------------------------------------------------------------- 1 | from generative_cert.utils.utils import build_graph, get_edge_data 2 | import random 3 | import pickle 4 | import pandas as pd 5 | import os 6 | import json 7 | from tqdm import tqdm 8 | 9 | tqdm.pandas() 10 | 11 | 12 | def create_misguide_path(G, start_node, end_nodes, num_hop): 13 | def create_random_path(G, start_node, num_hop): 14 | random_path = [] 15 | current_node = start_node 16 | for _ in range(num_hop): 17 | neighbors = list(G.neighbors(current_node)) 18 | if neighbors: 19 | next_node = random.choice(neighbors) 20 | rel = get_edge_data(G, current_node, next_node) 21 | random_path.append([current_node, rel, next_node]) 22 | current_node = next_node 23 | else: 24 | break 25 | return random_path 26 | 27 | misguide_path = None 28 | for _ in range(20): 29 | if (not misguide_path) or (misguide_path[-1][-1] in end_nodes): 30 | misguide_path = create_random_path(G, start_node, num_hop) 31 | if (misguide_path[-1][-1] in end_nodes) or len(misguide_path) < 2: 32 | raise ValueError("cannot find a misguide path") 33 | return misguide_path 34 | 35 | 36 | def misguide_path_from_ref_paths(G, ground_truth_paths): 37 | # get question entity and answer entities from ground_truth_paths 38 | ques_ent, ans_ents = None, [] 39 | if isinstance(ground_truth_paths[0][0], str): 40 | ques_ent = ground_truth_paths[0][0] 41 | ans_ents = [ground_truth_paths[-1][-1]] 42 | else: 43 | ques_ent = ground_truth_paths[0][0][0] 44 | ans_ents = [path[-1][-1] for path in ground_truth_paths] 45 | 46 | # create misguide path 47 | num_hop = random.choice([2, 3]) 48 | misguide_path = create_misguide_path(G, ques_ent, ans_ents, num_hop) 49 | for _ in range(50): 50 | if [ 51 | p 52 | for p in ground_truth_paths 53 | if "".join(["".join(t) for t in 
p]) 54 | in "".join(["".join(t) for t in misguide_path]) 55 | ]: 56 | misguide_path = create_misguide_path(G, ques_ent, ans_ents, num_hop) 57 | else: 58 | break 59 | if [ 60 | p 61 | for p in ground_truth_paths 62 | if "".join(["".join(t) for t in p]) 63 | in "".join(["".join(t) for t in misguide_path]) 64 | ]: 65 | raise ValueError("misguide_path contains ground_truth_paths") 66 | return misguide_path 67 | 68 | 69 | def create_misguide_path_for_dat(dataset, out_path): 70 | # load triplets 71 | triplets_path = ( 72 | f"LLMReasoningCert/data/db_extract/{dataset}/only_test_set/origin/triplets.pkl" 73 | ) 74 | with open(triplets_path, "rb") as f: 75 | triplets = pickle.load(f) 76 | # create graph for whole triplets 77 | G = build_graph(triplets) 78 | 79 | # load data from dataset 80 | in_path = "LLMReasoningCert/data/{}/gpt-3.5-turbo/test/splitted_ground_truth_paths.json".format( 81 | dataset 82 | ) 83 | with open(in_path, "r") as fin: 84 | results = json.load(fin) 85 | # ignore_1hop: 86 | results = results["min_2hop"] + results["min_multihop"] 87 | dat = pd.DataFrame(results)[["id", "question", "ground_truth_paths"]] 88 | 89 | # get misguide paths 90 | dat["misguide_path"] = dat["ground_truth_paths"].progress_apply( 91 | lambda p: misguide_path_from_ref_paths(G, p) 92 | ) 93 | 94 | # save 95 | dat.to_json(out_path, orient="records") 96 | print( 97 | dat.apply( 98 | lambda r: r["misguide_path"][-1][-1] 99 | in [p[-1][-1] for p in r["ground_truth_paths"]], 100 | axis=1, 101 | ).sum() 102 | / len(dat) 103 | ) 104 | 105 | 106 | if __name__ == "__main__": 107 | dataset = "grail_qa" # cwq 108 | out_path = f"LLMReasoningCert/data/{dataset}/misguide_path.jsonl" 109 | create_misguide_path_for_dat(dataset, out_path) 110 | 111 | df = pd.read_json(out_path) 112 | func = lambda ground_truth_paths, misguide_path: len( 113 | [ 114 | p 115 | for p in ground_truth_paths 116 | if "".join(["".join(t) for t in p]) 117 | in "".join(["".join(t) for t in misguide_path]) 118 | ] 119 | ) 120 | print( 121 | df.apply( 122 | lambda r: func(r["ground_truth_paths"], r["misguide_path"]), axis=1 123 | ).sum() 124 | ) 125 | -------------------------------------------------------------------------------- /llms/chatgpt.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import openai 4 | from openai import OpenAI 5 | from .base_language_model import BaseLanguageModel 6 | import dotenv 7 | import tiktoken 8 | 9 | dotenv.load_dotenv() 10 | openai.api_key = os.getenv("OPENAI_API_KEY") 11 | os.environ["TIKTOKEN_CACHE_DIR"] = "./tmp" 12 | 13 | OPENAI_MODEL = ["gpt-4", "gpt-3.5-turbo"] 14 | 15 | 16 | def get_token_limit(model="gpt-4"): 17 | """Returns the token limitation of provided model""" 18 | if model in ["gpt-4", "gpt-4-0613"]: 19 | num_tokens_limit = 8192 20 | elif model in ["gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613"]: 21 | num_tokens_limit = 16384 22 | elif model in [ 23 | "gpt-3.5-turbo", 24 | "gpt-3.5-turbo-0613", 25 | "text-davinci-003", 26 | "text-davinci-002", 27 | ]: 28 | num_tokens_limit = 4096 29 | else: 30 | raise NotImplementedError( 31 | f"""get_token_limit() is not implemented for model {model}.""" 32 | ) 33 | return num_tokens_limit 34 | 35 | 36 | PROMPT = """{instruction} 37 | 38 | {input}""" 39 | 40 | 41 | class ChatGPT(BaseLanguageModel): 42 | @staticmethod 43 | def add_args(parser): 44 | parser.add_argument("--retry", type=int, help="retry time", default=5) 45 | parser.add_argument( 46 | "--top_p", default=1.0, type=float, help="Generate 
params: nucleus sampling" 47 | ) 48 | parser.add_argument( 49 | "--num_return_sequences", 50 | default=1, 51 | type=int, 52 | help="The number of returned sequences", 53 | ) 54 | parser.add_argument( 55 | "--temperature", 56 | type=float, 57 | default=0.7, 58 | help="Generate params: temperature", 59 | ) 60 | parser.add_argument("--model_path", type=str, default="None") 61 | parser.add_argument("--quant", choices=["none", "4bit", "8bit"], default="none") 62 | 63 | def __init__(self, args): 64 | super().__init__(args) 65 | self.retry = args.retry 66 | self.model_name = args.model_name 67 | self.maximun_token = get_token_limit(self.model_name) 68 | self.redundant_tokens = 150 69 | self.args = args 70 | 71 | def tokenize(self, text): 72 | """Returns the number of tokens in the input text, plus a margin for message overhead.""" 73 | try: 74 | encoding = tiktoken.encoding_for_model(self.model_name) 75 | num_tokens = len(encoding.encode(text)) 76 | except KeyError: 77 | raise KeyError(f"Warning: model {self.model_name} not found.") 78 | return num_tokens + self.redundant_tokens 79 | 80 | def prepare_for_inference(self, model_kwargs={}): 81 | client = OpenAI( 82 | api_key=os.environ[ 83 | "OPENAI_API_KEY" 84 | ], # this is also the default, it can be omitted 85 | ) 86 | self.client = client 87 | 88 | def prepare_model_prompt(self, query): 89 | """ 90 | Add model-specific prompt to the input 91 | """ 92 | return query 93 | 94 | def generate_sentence(self, llm_input): 95 | query = [{"role": "user", "content": llm_input}] 96 | cur_retry = 0 97 | num_retry = self.retry 98 | # Check if the input is too long 99 | input_length = self.tokenize(llm_input) 100 | if input_length > self.maximun_token: 101 | print( 102 | f"Input length {input_length} is too long. The maximum token is {self.maximun_token}.\n Right-truncate the input to {self.maximun_token} tokens."
103 | ) 104 | llm_input = llm_input[: self.maximun_token] 105 | while cur_retry <= num_retry: 106 | try: 107 | response = self.client.chat.completions.create( 108 | model=self.model_name, 109 | messages=query, 110 | timeout=60, 111 | temperature=self.args.temperature, 112 | top_p=self.args.top_p, 113 | n=self.args.num_return_sequences, 114 | ) 115 | result = ( 116 | [res.message.content.strip() for res in response.choices] 117 | if len(response.choices) > 1 118 | else response.choices[0].message.content.strip() 119 | ) # type: ignore 120 | return result 121 | except Exception as e: 122 | print("Message: ", llm_input) 123 | print("Number of token: ", self.tokenize(llm_input)) 124 | print(e) 125 | time.sleep(30) 126 | cur_retry += 1 127 | continue 128 | return None 129 | -------------------------------------------------------------------------------- /preprocess_data/ground_truth_paths.py: -------------------------------------------------------------------------------- 1 | import generative_cert.utils.utils as utils 2 | import os 3 | import argparse 4 | from datasets import load_dataset 5 | import json 6 | from tqdm import tqdm 7 | import multiprocessing as mp 8 | from extract_subgraph.graph_loader import GraphProcess 9 | from functools import partial 10 | import random 11 | 12 | 13 | def get_output_file(path, force=False): 14 | if not os.path.exists(path) or force: 15 | fout = open(path, "w") 16 | return fout, [] 17 | else: 18 | with open(path, "r") as f: 19 | processed_results = [] 20 | for line in f: 21 | results = json.loads(line) 22 | processed_results.append(results["id"]) 23 | fout = open(path, "a") 24 | return fout, processed_results 25 | 26 | 27 | def process_data(data, args, processed_list): 28 | question = data["question"] 29 | id = data["id"] 30 | if id in processed_list: 31 | return None 32 | got = GraphProcess(data, args) 33 | ground_truth_paths = got.get_truth_paths() 34 | if len(ground_truth_paths) == 0: 35 | return None 36 | if len(ground_truth_paths) > args.n_pos: 37 | ground_truth_paths = random.sample(ground_truth_paths, args.n_pos) 38 | 39 | result = { 40 | "id": id, 41 | "question": question, 42 | "ground_truth_paths": ground_truth_paths, 43 | # 'min_num_hop': min(map(len, ground_truth_paths)) 44 | } 45 | return result 46 | 47 | 48 | def main(args): 49 | input_file = os.path.join(args.data_path, args.d) 50 | output_dir = os.path.join(args.predict_path, args.d, args.save_name, args.split) 51 | print("Save results to: ", output_dir) 52 | 53 | # Load dataset 54 | dataset = load_dataset( 55 | "json", data_files=f"{os.path.join(input_file,args.split)}.jsonl" 56 | )["train"] 57 | 58 | # Predict 59 | if not os.path.exists(output_dir): 60 | os.makedirs(output_dir) 61 | 62 | # Save args file 63 | with open(os.path.join(output_dir, "args.txt"), "w") as f: 64 | json.dump(args.__dict__, f, indent=2) 65 | 66 | fout, processed_list = get_output_file( 67 | os.path.join(output_dir, "ground_truth_paths.jsonl"), force=args.force 68 | ) 69 | count = 0 70 | if args.n == 1: 71 | for data in tqdm(dataset): 72 | res = process_data(data, args, processed_list) 73 | if res is None: 74 | count += 1 75 | continue 76 | fout.write(json.dumps(res) + "\n") 77 | fout.flush() 78 | else: 79 | with mp.Pool(args.n) as pool: 80 | for res in tqdm( 81 | pool.imap_unordered( 82 | partial(process_data, args=args, processed_list=processed_list), 83 | dataset, 84 | ), 85 | total=len(dataset), 86 | ): 87 | if res is None: 88 | continue 89 | fout.write(json.dumps(res) + "\n") 90 | fout.flush() 91 | fout.close() 92 
| print(f"Dont found {count}/{len(dataset)} groundtruth reasoning") 93 | 94 | 95 | if __name__ == "__main__": 96 | argparser = argparse.ArgumentParser() 97 | argparser.add_argument( 98 | "--data_path", 99 | type=str, 100 | default="LLMReasoningCert/data", 101 | ) 102 | argparser.add_argument( 103 | "--predict_path", 104 | type=str, 105 | default="LLMReasoningCert/data", 106 | ) 107 | argparser.add_argument("--save_name", "-p", type=str, help="save name for results") 108 | argparser.add_argument("--seed", type=int, default=42, help="random seed") 109 | argparser.add_argument( 110 | "--model_name", 111 | "-m", 112 | type=str, 113 | help="save_name", 114 | choices=["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4"], 115 | default="gpt-3.5-turbo", 116 | ) 117 | argparser.add_argument( 118 | "--force", "-f", action="store_true", help="force to overwrite the results" 119 | ) 120 | argparser.add_argument("-n", default=1, type=int, help="number of processes") 121 | argparser.add_argument("--n_pos", default=5, type=int, help="number of postive") 122 | argparser.add_argument( 123 | "-neg", default=1, type=int, help="number of negative samples" 124 | ) 125 | argparser.add_argument("--debug", action="store_true") 126 | 127 | args = argparser.parse_args() 128 | 129 | if args.save_name is None: 130 | args.save_name = args.model_name 131 | 132 | utils.set_seed(args.seed) 133 | 134 | for d in ["cwq", "grail_qa"]: 135 | for split in ["test"]: 136 | print("Handling data {} and split {}: ".format(d, split)) 137 | args.d = d 138 | args.split = split 139 | main(args) 140 | -------------------------------------------------------------------------------- /finegrained_gen_cert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pandas as pd 5 | import seaborn as sn 6 | from sklearn.metrics import confusion_matrix 7 | from datasets import load_dataset 8 | from minineedle import needle 9 | 10 | 11 | def prec_eval(row, prob_thres=0.7, ent_thres=0.5): 12 | # init precision value 13 | max_match = 0.0 14 | min_edit = 1.0 15 | p_incorrect_reasoning = 0 16 | 17 | # get triples 18 | ref_path = row["ground_truth_paths"] 19 | extracted_triplets = row["10_extracted_triplets"] 20 | pred_path = [ 21 | [ 22 | v["best_triple"]["triplet"] 23 | if ( 24 | v["best_triple"]["prob"] >= prob_thres 25 | and v["best_triple"]["head_score"] >= ent_thres 26 | and v["best_triple"]["tail_score"] >= ent_thres 27 | ) 28 | else ["h", "r", "t"] 29 | for v in res 30 | ] 31 | for res in extracted_triplets 32 | ] 33 | # sort and linearize to make sure the order in path 34 | ref_path = [[" ".join(sorted(v)) for v in res] for res in ref_path] 35 | pred_path = [[" ".join(sorted(v)) for v in res] for res in pred_path] 36 | 37 | for pred in pred_path: # check both self-consistency & a seq 38 | for ref in ref_path: # check each reference path 39 | alignment = needle.NeedlemanWunsch(pred, ref) 40 | alignment.align() 41 | score = alignment._score 42 | match = int( 43 | round(alignment._identity / 100 * len(alignment._alseq1), 0) 44 | ) / len(ref) # maybe exist redundant 45 | edit = 1 - alignment._identity / 100 46 | if match > max_match: 47 | max_match = match 48 | min_edit = edit 49 | elif match == max_match: 50 | if edit < min_edit: 51 | min_edit = edit 52 | 53 | p_incorrect_reasoning = 0 if match == 1 else 1 54 | p_incorrect_answer = row["short_eval"]["p_incorrect_answer"] 55 | id = row["id"] 56 | return { 57 | "id": id, 58 | "p_incorrect_answer": 
p_incorrect_answer, 59 | "p_incorrect_reasoning": p_incorrect_reasoning, 60 | "match_rate": match, 61 | "edit_rate": min_edit, 62 | } 63 | 64 | 65 | def calc_and_save_result_metrics(out_dir, dat): 66 | if not os.path.exists(out_dir): 67 | os.makedirs(out_dir, exist_ok=True) 68 | dat.to_json(os.path.join(out_dir, "data.jsonl")) 69 | 70 | ans_acc = ( 71 | round(len(dat.filter(lambda x: x["p_incorrect_answer"] == 0)) / len(dat), 4) 72 | * 100 73 | ) 74 | reasoning_acc = ( 75 | round(len(dat.filter(lambda x: x["p_incorrect_reasoning"] == 0)) / len(dat), 4) 76 | * 100 77 | ) 78 | # confusion matrix 79 | classes = [1, 0] 80 | dat = dat.to_pandas() 81 | cfm = confusion_matrix(dat["p_incorrect_answer"], dat["p_incorrect_reasoning"]) 82 | tn, fp, fn, tp = cfm.ravel() 83 | ## dic 84 | cfm_dic = { 85 | "IA_UR": str(tp), # incorrect answer + uncertified reasoning 86 | "CA_UR": str(fp), # correct answer + uncertified reasoning 87 | "IA_CR": str(fn), # incorrect answer + certified reasoning 88 | "CA_CR": str(tn), # correct answer + certified reasoning 89 | } 90 | ## image 91 | cfm_df = pd.DataFrame( 92 | cfm, 93 | index=["correct answer", "incorrect answer"], 94 | columns=["certified reasoning", "uncertified reasoning"], 95 | ) 96 | sn.heatmap(cfm_df, annot=True, fmt="d").figure.savefig( 97 | os.path.join(out_dir, "confusion_matrix.png") 98 | ) 99 | 100 | # write 101 | with open(os.path.join(out_dir, "metric.txt"), "w") as fout: 102 | fout.write(f"answer_accuracy {ans_acc}%" + "\n") 103 | fout.write(f"reasoning_accuracy {reasoning_acc}%" + "\n") 104 | fout.write(json.dumps(cfm_dic)) 105 | 106 | 107 | def main(args, prob_thres, ent_thres): 108 | # load data 109 | path = args.dat_path 110 | dat = load_dataset("json", data_files=path)["train"] 111 | # calculate precision 112 | dat = dat.map( 113 | lambda x: prec_eval(x, prob_thres, ent_thres), 114 | batched=False, 115 | num_proc=16, 116 | remove_columns=dat.column_names, 117 | ) 118 | # save 119 | out_dir = os.path.join(os.path.dirname(path), "groundtruth") 120 | calc_and_save_result_metrics(out_dir, dat) 121 | 122 | 123 | if __name__ == "__main__": 124 | argparser = argparse.ArgumentParser() 125 | argparser.add_argument( 126 | "--dat_path", 127 | type=str, 128 | required=True, 129 | help="LLMReasoningCert/data/cwq/gpt-3.5-turbo/gpt-3.5-turbo-fewshot-cot-only-cwq-temp-0.7-p-0.9-consistency-1/llm_prompt_response.jsonl", 130 | ) 131 | args = argparser.parse_args() 132 | 133 | prob_thres = 0.7 134 | ent_thres = 0.5 135 | main(args, prob_thres, ent_thres) 136 | -------------------------------------------------------------------------------- /virtuoso_db/virtuoso.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # This script provides a convenient wrapper for the Virtuoso SPARQL server. 
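# Example invocations (a sketch; the port and db-path values are only the conventions
# used elsewhere in this repo, e.g. the SPARQL endpoint on port 3001 queried by
# preprocess_data/grailqa_graph.py -- adjust them to your own setup):
#   python virtuoso_db/virtuoso.py start 3001 -d virtuoso_db/virtuoso_db   # HTTP/SPARQL on :3001, isql on :13001
#   python virtuoso_db/virtuoso.py stop 3001                               # shut the server down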
4 | # Adapted from Sempre (https://github.com/percyliang/sempre) 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import argparse 10 | 11 | virtuosoPath = "LLMReasoningCert/LLMReasonCert/virtuoso_db/virtuoso-opensource" 12 | if not os.path.exists(virtuosoPath): 13 | print(f"{virtuosoPath} does not exist") 14 | sys.exit(1) 15 | 16 | 17 | # Virtuoso has two services: the server (isql) and SPARQL endpoint 18 | def isqlPort(port): 19 | return 10000 + port 20 | 21 | 22 | def httpPort(port): 23 | return port 24 | 25 | 26 | def run(command): 27 | print(f"RUNNING: {command}") 28 | res = subprocess.run(command, shell=True, stdout=subprocess.PIPE) 29 | return res.stdout 30 | 31 | 32 | def start(dbPath, port): 33 | if not os.path.exists(dbPath): 34 | os.mkdir(dbPath) 35 | 36 | # Recommended: 70% of RAM, each buffer is 8K 37 | # Use a fraction of the free RAM. The result may vary across runs. 38 | # memFree = parseInt(`cat /proc/meminfo | grep MemFree | awk '{print $2}'`) # KB 39 | # Use a fraction of the total RAM. The result is the same across runs. 40 | memFree = int(run("cat /proc/meminfo | grep MemTotal | awk '{print $2}'")) # KB 41 | numberOfBuffers = memFree * 0.15 / 8 42 | maxDirtyBuffers = numberOfBuffers / 2 43 | print( 44 | f"{memFree} KB free, using {numberOfBuffers} buffers, {maxDirtyBuffers} dirty buffers" 45 | ) 46 | 47 | # Configuration options: 48 | # http://docs.openlinksw.com/virtuoso/dbadm.html 49 | # http://virtuoso.openlinksw.com/dataspace/doc/dav/wiki/Main/VirtConfigScale 50 | config = ( 51 | f"[Database]\n" 52 | f"DatabaseFile = {dbPath}/virtuoso.db\n" 53 | f"ErrorLogFile = {dbPath}/virtuoso.log\n" 54 | f"LockFile = {dbPath}/virtuoso.lck\n" 55 | f"TransactionFile = {dbPath}/virtuoso.trx\n" 56 | f"xa_persistent_file = {dbPath}/virtuoso.pxa\n" 57 | f"ErrorLogLevel = 7\n" 58 | f"FileExtend = 200\n" 59 | f"MaxCheckpointRemap = 2000\n" 60 | f"Striping = 0\n" 61 | f"TempStorage = TempDatabase\n" 62 | f"\n" 63 | f"[TempDatabase]\n" 64 | f"DatabaseFile = {dbPath}/virtuoso-temp.db\n" 65 | f"TransactionFile = {dbPath}/virtuoso-temp.trx\n" 66 | f"MaxCheckpointRemap = 2000\n" 67 | f"Striping = 0\n" 68 | f"\n" 69 | f"[Parameters]\n" 70 | f"ServerPort = {isqlPort(port)}\n" 71 | f"LiteMode = 0\n" 72 | f"DisableUnixSocket = 1\n" 73 | f"DisableTcpSocket = 0\n" 74 | f"ServerThreads = 100 ; increased from 20\n" 75 | f"CheckpointInterval = 60\n" 76 | f"O_DIRECT = 1 ; increased from 0\n" 77 | f"CaseMode = 2\n" 78 | f"MaxStaticCursorRows = 100000\n" 79 | f"CheckpointAuditTrail = 0\n" 80 | f"AllowOSCalls = 0\n" 81 | f"SchedulerInterval = 10\n" 82 | f"DirsAllowed = .\n" 83 | f"ThreadCleanupInterval = 0\n" 84 | f"ThreadThreshold = 10\n" 85 | f"ResourcesCleanupInterval = 0\n" 86 | f"FreeTextBatchSize = 100000\n" 87 | # f"SingleCPU = 0\n" 88 | f"PrefixResultNames = 0\n" 89 | f"RdfFreeTextRulesSize = 100\n" 90 | f"IndexTreeMaps = 256\n" 91 | f"MaxMemPoolSize = 200000000\n" 92 | f"PrefixResultNames = 0\n" 93 | f"MacSpotlight = 0\n" 94 | f"IndexTreeMaps = 64\n" 95 | f"NumberOfBuffers = {numberOfBuffers}\n" 96 | f"MaxDirtyBuffers = {maxDirtyBuffers}\n" 97 | f"\n" 98 | f"[SPARQL]\n" 99 | f"ResultSetMaxRows = 50000\n" 100 | f"MaxQueryCostEstimationTime = 600 ; in seconds (increased)\n" 101 | f"MaxQueryExecutionTime = 180; in seconds (increased)\n" 102 | f"\n" 103 | f"[HTTPServer]\n" 104 | f"ServerPort = {httpPort(port)}\n" 105 | f"Charset = UTF-8\n" 106 | f"ServerThreads = 15 ; increased from unknown\n" 107 | ) 108 | 109 | configPath = f"{dbPath}/virtuoso.ini" 110 | print(config) 111 | print() 112 | 
print(configPath) 113 | print(f"==== Starting Virtuoso server for {dbPath} on port {port}...") 114 | with open(configPath, "w") as f: 115 | f.write(config) 116 | run(f"{virtuosoPath}/bin/virtuoso-t +configfile {configPath} +wait") 117 | 118 | 119 | def stop(port): 120 | run(f"echo 'shutdown;' | {virtuosoPath}/bin/isql localhost:{isqlPort(port)}") 121 | 122 | 123 | def status(port): 124 | run(f"echo 'status();' | {virtuosoPath}/bin/isql localhost:{isqlPort(port)}") 125 | 126 | 127 | ############################################################ 128 | # Main 129 | 130 | if __name__ == "__main__": 131 | parser = argparse.ArgumentParser(description="manage Virtuoso services") 132 | parser.add_argument("action", type=str, help="start or stop") 133 | parser.add_argument("port", type=int, help="port for the SPARQL HTTP endpoint") 134 | parser.add_argument("-d", "--db-path", type=str, help="path to the db directory") 135 | 136 | args = parser.parse_args() 137 | if args.action == "start": 138 | if not args.db_path: 139 | print("please specify path to the db directory with -d") 140 | sys.exit() 141 | 142 | if not os.path.isdir(args.db_path): 143 | print("the path specified does not exist") 144 | sys.exit() 145 | 146 | start(args.db_path, args.port) 147 | elif args.action == "stop": 148 | stop(args.port) 149 | else: 150 | print(f"invalid action: ${args.action}") 151 | sys.exit() 152 | -------------------------------------------------------------------------------- /evaluate_results.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import string 6 | 7 | 8 | def normalize(s: str) -> str: 9 | """Lower text and remove punctuation, articles and extra whitespace.""" 10 | s = s.lower() 11 | exclude = set(string.punctuation) 12 | s = "".join(char for char in s if char not in exclude) 13 | s = re.sub(r"\b(a|an|the)\b", " ", s) 14 | s = re.sub(r"^\b(or|and)\b", "", s) 15 | s = re.sub(r"\"", "", s) 16 | # remove token: 17 | s = re.sub(r"\b()\b", " ", s) 18 | s = " ".join(s.split()) 19 | return s 20 | 21 | 22 | def match(s1: str, s2: str) -> bool: 23 | s1 = normalize(s1) 24 | s2 = normalize(s2) 25 | return s2 in s1 26 | 27 | 28 | def eval_acc(prediction, answer): 29 | matched = 0.0 30 | for a in answer: 31 | if match(prediction, a): 32 | matched += 1 33 | return matched / len(answer) 34 | 35 | 36 | def eval_hit(prediction, answer): 37 | if not prediction and not answer: 38 | return 1 39 | for a in answer: 40 | if match(" ".join(prediction), a): 41 | return 1 42 | for p in prediction: 43 | if match(" ".join(answer), p): 44 | return 1 45 | return 0 46 | 47 | 48 | def eval_f1(prediction, answer): 49 | if not prediction and not answer: 50 | return 1, 1, 1 51 | if len(prediction) == 0 or len(answer) == 0: 52 | return 0, 0, 0 53 | matched = 0 54 | prediction_str = " ".join(prediction) 55 | for a in answer: 56 | if match(prediction_str, a): 57 | matched += 1 58 | precision = matched / len(prediction) 59 | recall = matched / len(answer) 60 | if precision + recall == 0: 61 | return 0, precision, recall 62 | else: 63 | return 2 * precision * recall / (precision + recall), precision, recall 64 | 65 | 66 | def eval_result(predict_path, result_path, dataset, split): 67 | """ 68 | Eval final answer. 
69 | """ 70 | predict_file = os.path.join(predict_path, "llm_prompt_response.jsonl") 71 | out_file = os.path.join( 72 | result_path, f"{dataset}_{split}_evaluate_llm_prompting.jsonl" 73 | ) 74 | # Load results 75 | f1_list = [] 76 | precission_list = [] 77 | recall_list = [] 78 | hit_list = [] 79 | with open(predict_file) as fin, open(out_file, "w") as fout: 80 | # skip first line which is args 81 | first_line = True 82 | for line in fin: 83 | if first_line: 84 | first_line = False 85 | continue 86 | data = json.loads(line) 87 | if dataset == "grail_qa": 88 | groundtruth_ans = data["answer"]["entity_name"] 89 | if groundtruth_ans == [""]: 90 | groundtruth_ans = data["answer"]["answer_argument"] 91 | else: 92 | try: 93 | groundtruth_ans = list( 94 | {path[-1][-1] for path in data["ground_truth_paths"] if path} 95 | ) 96 | except: 97 | print(data["ground_truth_paths"]) 98 | raise ValueError() 99 | 100 | # note: havent yet handled \", or, and 101 | prediction_ans = re.findall( 102 | r"\nSo the answer is \((.*?)\)", data["reasoning_ans"] 103 | ) 104 | prediction_ans = prediction_ans[0].split(", ") if prediction_ans else [] 105 | # prediction_ans = [data['reasoning_ans']]#.split("\n")[-1].split(', ') 106 | f1, precision, recall = eval_f1(prediction_ans, groundtruth_ans) 107 | f1_list.append(f1) 108 | precission_list.append(precision) 109 | recall_list.append(recall) 110 | # prediction_str = ' '.join(prediction_ans) 111 | hit = eval_hit(prediction_ans, groundtruth_ans) 112 | hit_list.append(hit) 113 | fout.write( 114 | json.dumps( 115 | { 116 | "id": data["id"], 117 | "question": data["question"], 118 | "prediction": prediction_ans, 119 | "ground_truth": groundtruth_ans, 120 | "hit": hit, 121 | "f1": f1, 122 | "precision": precision, 123 | "recall": recall, 124 | } 125 | ) 126 | + "\n" 127 | ) 128 | result_str = { 129 | "Hit": str(sum(hit_list) * 100 / len(hit_list)), 130 | " F1": str(sum(f1_list) * 100 / len(f1_list)), 131 | " Precision": str(sum(precission_list) * 100 / len(precission_list)), 132 | " Recall": str(sum(recall_list) * 100 / len(recall_list)), 133 | } 134 | print(result_str) 135 | fout.write(json.dumps(result_str) + "\n") 136 | 137 | 138 | if __name__ == "__main__": 139 | argparser = argparse.ArgumentParser() 140 | argparser.add_argument( 141 | "-d_in", 142 | type=str, 143 | default="LLMReasoningCert/data/", 144 | ) 145 | argparser.add_argument( 146 | "-d_out", 147 | type=str, 148 | default="LLMReasoningCert/LLMReasonCert/experiment_results/evaluate_llm_prompt", 149 | ) 150 | argparser.add_argument("--dataset", type=str, default="cwq") # grail_qa, cwq 151 | argparser.add_argument("--split", type=str, default="test") # validation, test, dev 152 | argparser.add_argument("--model_name", type=str, default="gpt-3.5-turbo") 153 | args = argparser.parse_args() 154 | 155 | args.d_in = os.path.join(args.d_in, args.dataset, args.model_name, args.split) 156 | eval_result(args.d_in, args.d_out, args.dataset, args.split) 157 | -------------------------------------------------------------------------------- /finegrained_analysis.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | 4 | GPT_FOLDERS = [ 5 | "gpt-3.5-turbo-fewshot-cot-hint-{}-temp-0.7-p-0.9-consistency-1-is_sc_1", 6 | "gpt-3.5-turbo-fewshot-cot-only-{}-temp-0.7-p-0.9-consistency-1-is_sc_1", 7 | "gpt-3.5-turbo-fewshot-cot-only-{}-temp-0.7-p-0.9-consistency-20-is_sc_4", 8 | ] 9 | OTHER_MODELS_FOLDERS = [ 10 | "cot-hint-temp-0.7-p-0.9-is_sc_1", 11 | 
"cot-temp-0.7-p-0.9-is_sc_1", 12 | "cot-temp-0.7-p-0.9-is_sc_4", 13 | ] 14 | MODEL_LIST = [ 15 | "gpt-3.5-turbo", 16 | "Llama-2-70b-chat-hf", 17 | "Mistral-7B-Instruct-v0.1", 18 | "Qwen-7B-Chat", 19 | "Qwen-14B-Chat", 20 | "vicuna-33b-v1.3", 21 | ] 22 | DATASET_LIST = ["cwq", "grail_qa"] 23 | 24 | 25 | def calc_avg_match_edit_rate(match_path, step_path, len_dataset): 26 | with open(step_path) as f: 27 | step_dat = [json.loads(l) for l in f] 28 | step_dat = [len(min(dic["10_extracted_triplets"], key=len)) for dic in step_dat] 29 | with open(match_path) as f: 30 | match_dat = [json.loads(l) for l in f] 31 | match_dat = pd.DataFrame(match_dat) 32 | len_dataset = max(len(match_dat), len_dataset) 33 | return { 34 | # 'avg_match_rate': round(match_dat['match_rate'].sum()/len_dataset,4)*100, 35 | # 'avg_edit_rate': round((match_dat['edit_rate'].sum()+len_dataset-len(match_dat))/len_dataset,4)*100, 36 | "avg_match_rate": round(match_dat["match_rate"].mean(), 4) * 100, 37 | "avg_edit_rate": round(match_dat["edit_rate"].mean(), 4) * 100, 38 | "avg_num_steps": round(sum(step_dat) / len(step_dat), 4), 39 | } 40 | 41 | 42 | def check_match_edit_rate(): 43 | path_match_template = ( 44 | "LLMReasoningCert/LLMReasonCert/results/{}/{}/{}/groundtruth/data.jsonl" 45 | ) 46 | path_step_template = "LLMReasoningCert/LLMReasonCert/results/{}/{}/{}/full.jsonl" 47 | res_dic = {} 48 | for dataset in DATASET_LIST: 49 | with open( 50 | f"LLMReasoningCert/data/{dataset}/gpt-3.5-turbo/test/splitted_ground_truth_paths.json" 51 | ) as f: 52 | dat = json.load(f) 53 | dat = dat["min_2hop"] + dat["min_multihop"] 54 | len_dataset = len(dat) 55 | print(len_dataset) 56 | # find average of steps in groundtruthpath 57 | dat = [s["ground_truth_paths"] for s in dat] 58 | num_steps = [len(min(s, key=len)) for s in dat] 59 | print(dataset, sum(num_steps) / len(num_steps)) 60 | dic = {} 61 | for model_name in MODEL_LIST: 62 | model_dic = {} 63 | if "gpt" in model_name.lower(): 64 | for folder in GPT_FOLDERS: 65 | match_path = path_match_template.format( 66 | dataset, model_name, folder.format(dataset) 67 | ) 68 | step_path = path_step_template.format( 69 | dataset, model_name, folder.format(dataset) 70 | ) 71 | 72 | model_dic[folder] = calc_avg_match_edit_rate( 73 | match_path, step_path, len_dataset 74 | ) 75 | else: 76 | for folder in OTHER_MODELS_FOLDERS: 77 | match_path = path_match_template.format(dataset, model_name, folder) 78 | step_path = path_step_template.format(dataset, model_name, folder) 79 | 80 | model_dic[folder] = calc_avg_match_edit_rate( 81 | match_path, step_path, len_dataset 82 | ) 83 | dic[model_name] = model_dic 84 | 85 | res_dic[dataset] = dic 86 | return res_dic 87 | 88 | 89 | def calculate_answer_error_reasoning(path): 90 | with open(path) as f: 91 | dat = [json.loads(l) for l in f] 92 | dat = [dic["short_eval"] for dic in dat] 93 | reasoning_error_dat = [dic for dic in dat if dic["p_incorrect_reasoning"] == 1] 94 | fact_error_len = len( 95 | [dic for dic in reasoning_error_dat if dic["p_factual_error"] == 1] 96 | ) 97 | coherence_error_len = len( 98 | [dic for dic in reasoning_error_dat if dic["p_coherent_error"] == 1] 99 | ) 100 | answer_error_len = len( 101 | [dic for dic in reasoning_error_dat if dic["p_reasoning_ans_error"] == 1] 102 | ) 103 | reasoning_error_len = len(reasoning_error_dat) 104 | return { 105 | "fact_error": round(fact_error_len / reasoning_error_len, 4) * 100, 106 | "coherent_error": round(coherence_error_len / reasoning_error_len, 4) * 100, 107 | "reasoning_answer_error": 
round(answer_error_len / reasoning_error_len, 4) 108 | * 100, 109 | } 110 | 111 | 112 | def answer_error_reasoning(): 113 | path_step_template = ( 114 | "LLMReasoningCert/LLMReasonCert/tmp/revision/results/{}/{}/{}/full.jsonl" 115 | ) 116 | res_dic = {} 117 | for dataset in DATASET_LIST: 118 | dic = {} 119 | for model_name in MODEL_LIST: 120 | model_dic = {} 121 | if "gpt" in model_name.lower(): 122 | for folder in GPT_FOLDERS: 123 | step_path = path_step_template.format( 124 | dataset, model_name, folder.format(dataset) 125 | ) 126 | model_dic[folder] = calculate_answer_error_reasoning(step_path) 127 | else: 128 | for folder in OTHER_MODELS_FOLDERS: 129 | step_path = path_step_template.format(dataset, model_name, folder) 130 | model_dic[folder] = calculate_answer_error_reasoning(step_path) 131 | dic[model_name] = model_dic 132 | res_dic[dataset] = dic 133 | return res_dic 134 | 135 | 136 | if __name__ == "__main__": 137 | # check match, edit rate 138 | match_edit_dic = check_match_edit_rate() 139 | with open( 140 | "LLMReasoningCert/LLMReasonCert/tmp/revision/results/match_edit_rate.json", 141 | "w", 142 | ) as fout: 143 | json.dump(match_edit_dic, fout, indent=4) 144 | 145 | # check error 146 | reasoning_error_dic = answer_error_reasoning() 147 | with open( 148 | "LLMReasoningCert/LLMReasonCert/tmp/revision/results/type_error.json", 149 | "w", 150 | ) as fout: 151 | json.dump(reasoning_error_dic, fout, indent=4) 152 | -------------------------------------------------------------------------------- /generative_cert/utils/utils.py: -------------------------------------------------------------------------------- 1 | # import openai 2 | # from dotenv import load_dotenv 3 | import time 4 | import os 5 | import networkx as nx 6 | 7 | # import tiktoken 8 | import random 9 | from collections import deque 10 | import numpy as np 11 | import re 12 | import itertools 13 | import statistics as stat 14 | 15 | # load_dotenv() 16 | # openai.api_key = os.getenv("OPENAI_API_KEY") 17 | # openai.organization = os.getenv("OPENAI_ORG") 18 | # os.environ['TIKTOKEN_CACHE_DIR'] = './tmp' 19 | 20 | 21 | def set_seed(seed): 22 | random.seed(seed) 23 | np.random.seed(seed) 24 | 25 | 26 | def build_graph(graph: list) -> nx.Graph: 27 | G = nx.Graph() 28 | for triplet in graph: 29 | h, r, t = triplet 30 | G.add_edge(h, t, relation=r.strip()) 31 | return G 32 | 33 | 34 | def get_edge_data(G: nx.Graph, start_node, end_node): 35 | relation = G.get_edge_data(start_node, end_node)["relation"] 36 | return relation 37 | 38 | 39 | # function for obtaining answer using self-consistency or not 40 | def ans_by_sc(ans, is_sc): 41 | """ 42 | ans: str or list of str 43 | is_sc: int. Self-consistency if is_sc>1, otherwise NO 44 | """ 45 | if not isinstance(is_sc, int): 46 | raise ValueError("is_sc only can be INT") 47 | if isinstance(ans, str): 48 | if is_sc > 1: 49 | raise ValueError("Only 1 answer!") 50 | return [ans] 51 | if is_sc > 1: 52 | ans = ans[:is_sc] 53 | ans = filter_majority_vote_ans(ans) 54 | return ans 55 | else: 56 | return ans[:1] 57 | 58 | 59 | def preprocess_ans(str_ans): 60 | str_ans = str_ans.strip() 61 | str_ans = str_ans.replace("\n\n", "\n") 62 | str_ans = str_ans.replace("Step ", "").replace("step ", "") 63 | str_ans = str_ans[str_ans.find("1. ") :] # remove Q: A: 64 | # str_ans = str_ans[:str_ans.find('Q: ')] # remove Q: A: 65 | return str_ans.strip() 66 | 67 | 68 | def is_correct_llm_ans_format(ans): 69 | # if not all(element in ans for element in ["So the answer is (", "1. 
"]): 70 | # return False 71 | check_single = lambda ans: ( 72 | re.search("^1. ", ans) 73 | and re.findall("(t|T)he answer( to the question)? is\:? \(?(.*?)\)?.?$", ans) 74 | ) 75 | if isinstance(ans, str): 76 | return check_single(ans) 77 | return any(check_single(a) for a in ans) 78 | 79 | 80 | def is_lack_of_knowledge(ans): 81 | keywords = [ 82 | "do not have knowledge", 83 | "not have knowledge", 84 | "more information", 85 | "need more", 86 | "impossible", 87 | "not possible", 88 | "unknown", 89 | "no answer", 90 | "unable", 91 | "cannot", 92 | "sorry", 93 | "unclear", 94 | # 'depend on', 95 | # 'need to' 96 | ] 97 | check_single = lambda a: any(kw in a.lower() for kw in keywords) 98 | if isinstance(ans, str): 99 | return check_single(ans) 100 | return all(check_single(a) for a in ans) 101 | 102 | 103 | def get_final_answer(str_ans): 104 | if isinstance(str_ans, str): 105 | prediction_ans = re.findall( 106 | r"(t|T)he answer( to the question)? is\:? \(?(.*?)\)?.?$", str_ans 107 | ) 108 | if not prediction_ans: 109 | return [] 110 | else: 111 | pred_ans = prediction_ans[0][2].split(", ") 112 | return pred_ans 113 | else: 114 | raise ValueError("Get final answer: need to be a string!") 115 | 116 | 117 | def filter_majority_vote_ans(ans): 118 | """ 119 | Only select answer containing majority voted answer. 120 | """ 121 | if isinstance(ans, str): 122 | raise ValueError("Answer need to be a list of strings!") 123 | final_ans = [get_final_answer(a) for a in ans] # get final answer 124 | list_ans = list(set(itertools.chain(*final_ans))) 125 | if not list_ans: 126 | return ans # incorrect instruction 127 | major_ans = stat.mode(list_ans) 128 | filtered_ans = [a for a, fa in zip(ans, final_ans) if major_ans in fa] 129 | return filtered_ans 130 | 131 | 132 | def drop_duplicated_triplets(tri_list): 133 | return list(k for k, _ in itertools.groupby(sorted(tri_list)) if len(k) == 3) 134 | 135 | 136 | def is_unknown_ent(ent): 137 | if re.search("^[mg]\.", ent): 138 | return True 139 | return False 140 | 141 | 142 | def find_triplets_contain_unknown_ent(triplets, ent, position): 143 | # find list of triplets which unknown ent is at the position 144 | l = list(filter(lambda t: t[position] == ent, triplets)) 145 | return l 146 | 147 | 148 | def get_unknown_ent_cates(path): 149 | """Creating dict of unknow entities:0-list & 2-list...""" 150 | unknown_ent_cates = {} 151 | for tri in path: 152 | head, rel, tail = tri 153 | if is_unknown_ent(head): 154 | if head not in unknown_ent_cates: 155 | unknown_ent_cates[head] = {0: [tri]} 156 | else: 157 | unknown_ent_cates[head][0] = ( 158 | unknown_ent_cates[head][0] + [tri] 159 | if unknown_ent_cates[head].get(0, None) 160 | else [tri] 161 | ) 162 | elif is_unknown_ent(tail): 163 | if tail not in unknown_ent_cates: 164 | unknown_ent_cates[tail] = {2: [tri]} 165 | else: 166 | unknown_ent_cates[tail][2] = ( 167 | unknown_ent_cates[tail][2] + [tri] 168 | if unknown_ent_cates[tail].get(2, None) 169 | else [tri] 170 | ) 171 | return unknown_ent_cates 172 | 173 | 174 | def processed_groundtruth_path(path): 175 | """skip unknown entities.""" 176 | # find dic of unknown entities 177 | unknown_ent_cates = get_unknown_ent_cates(path) 178 | if not unknown_ent_cates: 179 | return path 180 | new_triplets = [] 181 | # merge 182 | for tri in path: 183 | head, rel, tail = tri 184 | if is_unknown_ent(head): 185 | l = unknown_ent_cates[head].get(2, []) 186 | new_list = [[t[0], "/".join([t[1], rel]), tail] for t in l if tail != t[0]] 187 | new_triplets += new_list 188 | elif 
is_unknown_ent(tail): 189 | l = unknown_ent_cates[tail].get(0, []) 190 | new_list = [[head, "/".join([rel, t[1]]), t[2]] for t in l if head != t[2]] 191 | new_triplets += new_list 192 | else: 193 | new_triplets += [tri] 194 | # drop duplicates 195 | new_triplets = drop_duplicated_triplets(new_triplets) 196 | return new_triplets 197 | -------------------------------------------------------------------------------- /preprocess_data/grailqa_graph.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import json 4 | import warnings 5 | from tqdm import tqdm 6 | from datasets import load_dataset 7 | from preprocess_data.sparql import SparQL 8 | from generative_cert.utils.utils import is_unknown_ent 9 | 10 | warnings.filterwarnings("ignore") 11 | num_proc = 16 12 | 13 | sparql = SparQL("http://localhost:3001/sparql") 14 | 15 | 16 | def get_ent_in_graphquery(dic): 17 | ids = dic["nodes"]["id"] 18 | ids = [e for e in ids if (is_unknown_ent(e) or "^^http:" in e)] 19 | ids = [e.split("^^")[0] for e in ids] 20 | res = [sparql.SQL_entity2name(e) for e in ids if e != "0"] 21 | if any([is_unknown_ent(e) for e in res]): 22 | raise ValueError() 23 | return res 24 | 25 | 26 | def remove_for_searching_subgraph(sparql_txt): 27 | sparql_txt = sparql_txt.replace("\n", "").replace("}", " }") 28 | sparql_txt = re.sub(r" ?VALUES \?\w\d { ((:[mg]\.\w+)|\".+\>) +}", "", sparql_txt) 29 | return sparql_txt 30 | 31 | 32 | def subgraph_comparative_ques(sparql_txt): 33 | # replace comparative entities with not in KG 34 | comparative_ent = re.search(r"\?x\d [<>]=? \"", sparql_txt).group()[ 35 | :3 36 | ] # return ?x\d 37 | sparql_txt1 = sparql_txt.replace( 38 | "SELECT (?x0 AS ?value) WHERE { SELECT DISTINCT ?x0", 39 | "SELECT (?x0 AS ?value) WHERE { SELECT DISTINCT ?x0".replace( 40 | "?x0", comparative_ent 41 | ), 42 | ) 43 | added_q_entities = sparql.query(sparql_txt1, variable="value")[1] 44 | # find subgraph 45 | sparql_txt = remove_for_searching_subgraph(sparql_txt) 46 | raw_subgraph, subgraph = sparql.query_reasoning_path(sparql_txt) 47 | return raw_subgraph, subgraph, added_q_entities, [] 48 | 49 | 50 | def subgraph_count_ques(sparql_txt): 51 | # search ans ents due to the final answer is COUNT=number 52 | added_a_entities = sparql.query(sparql_txt.replace("COUNT(?x0)", "?x0"))[1] 53 | # find subgraph 54 | sparql_txt = remove_for_searching_subgraph(sparql_txt) 55 | raw_subgraph, subgraph = sparql.query_reasoning_path(sparql_txt) 56 | return raw_subgraph, subgraph, [], added_a_entities 57 | 58 | 59 | def subgraph_superlative_ques(sparql_txt): 60 | sparql_txt = remove_for_searching_subgraph(sparql_txt) 61 | sparql_txt = re.sub(r"WHERE { \?y\d .+} ?\?x", "WHERE { ?x", sparql_txt) 62 | sparql_txt = re.sub(r" FILTER ( \?y.+ \?y\d )", "", sparql_txt) 63 | raw_subgraph, subgraph = sparql.query_reasoning_path(sparql_txt) 64 | return raw_subgraph, subgraph, [], [] 65 | 66 | 67 | def subgraph_none_ques(sparql_txt): 68 | sparql_txt = remove_for_searching_subgraph(sparql_txt) 69 | raw_subgraph, subgraph = sparql.query_reasoning_path(sparql_txt) 70 | return raw_subgraph, subgraph, [], [] 71 | 72 | 73 | def find_subgraph(row): 74 | sparql_txt = row["sparql_query"] 75 | # find subgraph 76 | try: 77 | if row["function"] == "none": 78 | raw_subgraph, subgraph, added_q_entities, added_a_entities = ( 79 | subgraph_none_ques(sparql_txt) 80 | ) 81 | elif row["function"] == "count": 82 | raw_subgraph, subgraph, added_q_entities, added_a_entities = ( 83 | 
subgraph_count_ques(sparql_txt) 84 | ) 85 | elif row["function"] in ["argmax", "argmin"]: 86 | raw_subgraph, subgraph, added_q_entities, added_a_entities = ( 87 | subgraph_superlative_ques(sparql_txt) 88 | ) 89 | else: # row['function'] in [>=, <=, >, <] 90 | raw_subgraph, subgraph, added_q_entities, added_a_entities = ( 91 | subgraph_comparative_ques(sparql_txt) 92 | ) 93 | except Exception as e: 94 | # print(e) 95 | with open("temp.txt", "a") as f: 96 | for k, v in row.items(): 97 | f.write(f"{k}:{v}" + "\n") 98 | f.write("\n") 99 | # return None, None 100 | raise ValueError() 101 | 102 | return raw_subgraph, subgraph, added_q_entities, added_a_entities 103 | 104 | 105 | def get_info(row, fout, processed_ids): 106 | id = row["qid"] 107 | if id in processed_ids: 108 | return {"id": id, "status": "processed"} 109 | ques = row["question"] 110 | ans_ent = ( 111 | row["answer"]["entity_name"] 112 | if row["answer"]["entity_name"] != [""] 113 | else row["answer"]["answer_argument"] 114 | ) # list of answer entities 115 | topic_ent = get_ent_in_graphquery(row["graph_query"]) # list of topic entities 116 | if not topic_ent: 117 | return {"id": id, "status": "none"} 118 | else: 119 | raw_subgraph, subgraph, added_q_entities, added_a_entities = find_subgraph(row) 120 | if not subgraph: 121 | return {"id": id, "status": "none"} 122 | topic_ent += added_q_entities 123 | ans_ent += added_a_entities 124 | res = { 125 | "id": id, 126 | "question": ques, 127 | "q_entity": topic_ent, 128 | "a_entity": ans_ent, 129 | "answer": ans_ent, 130 | "graph": subgraph, 131 | "raw_graph": raw_subgraph, 132 | "function": row["function"], 133 | } 134 | fout.write(json.dumps(res) + "\n") 135 | return {"id": id, "status": "sucessed"} 136 | 137 | 138 | def get_output_file(path, force=False): 139 | if not os.path.exists(path) or force: 140 | fout = open(path, "w") 141 | return fout, {} 142 | else: 143 | with open(path, "r") as f: 144 | processed_ids = [] 145 | for line in f: 146 | results = json.loads(line) 147 | processed_ids.append(results["id"]) 148 | fout = open(path, "a") 149 | # print('processed_ids', processed_ids) 150 | return fout, processed_ids 151 | 152 | 153 | def get_multihop(dataset="grail_qa", split="validation"): 154 | dat = load_dataset(dataset, split=split) 155 | # ignore_1hop: 156 | dat = dat.filter(lambda r: r["num_edge"] > 1) 157 | return dat 158 | 159 | 160 | def get_data(dataset="grail_qa", split="validation"): 161 | if split == "validation": 162 | convert_split = "test" 163 | out_path = f"LLMReasoningCert/data/{dataset}/{convert_split}.jsonl" 164 | print(out_path) 165 | fout, processed_ids = get_output_file(out_path) 166 | dat = get_multihop(dataset, split) 167 | dat = dat.map( 168 | lambda row: get_info(row, fout, processed_ids), 169 | num_proc=num_proc, 170 | remove_columns=dat.column_names, 171 | ) 172 | fout.close() 173 | print(dat.to_pandas().status.value_counts()) 174 | print(f"GrailQA: Done") 175 | 176 | 177 | if __name__ == "__main__": 178 | get_data() 179 | -------------------------------------------------------------------------------- /extract_subgraph/preprocess_data.py: -------------------------------------------------------------------------------- 1 | # base on this link: https://github.com/RichardHGL/WSDM2021_NSM/ 2 | import json 3 | import sys 4 | import os 5 | import time 6 | import pandas as pd 7 | from datasets import Dataset, load_dataset 8 | from LLMReasonCert.extract_subgraph.deal_cvt import load_cvt, is_cvt 9 | 10 | 11 | def get_domain_from_dataset(dat_path): 12 | with 
open(dat_path) as fin: 13 | data = json.load(fin) 14 | data = data["Questions"] 15 | filter_domain = set() 16 | for ques in data: 17 | parses = ques["Parses"] 18 | for parse in parses: 19 | rel_str = parse["InferentialChain"].split("..") 20 | filter_domain.update(rel_str) 21 | return list(filter_domain) 22 | 23 | 24 | def manual_filter_rel(filter_domain): 25 | filter_set = set(filter_domain) 26 | input = "LLMReasoningCert/data/data/fb_en.txt" 27 | output = "LLMReasoningCert/data/data/manual_fb_filter.txt" 28 | f_in = open(input) 29 | f_out = open(output, "w") 30 | num_line = 0 31 | num_reserve = 0 32 | for line in f_in: 33 | splitline = line.strip().split("\t") 34 | num_line += 1 35 | if len(splitline) < 3: 36 | continue 37 | rel = splitline[1] 38 | flag = False 39 | for domain in filter_set: 40 | if domain in rel: 41 | flag = True 42 | break 43 | if flag: 44 | continue 45 | f_out.write(line) 46 | num_reserve += 1 47 | if num_line % 1000000 == 0: 48 | print("Checked: ", num_line, "lines, Reserver:", num_reserve) 49 | f_in.close() 50 | f_out.close() 51 | print("Total: ", num_line, "lines, Reserver:", num_reserve) 52 | 53 | 54 | def is_ent(tp_str): 55 | if len(tp_str) < 3: 56 | return False 57 | if tp_str.startswith("m.") or tp_str.startswith("g."): 58 | return True 59 | return False 60 | 61 | 62 | def fetch_triple_1hop(kb_file, seed_set, cvt_nodes, cvt_hop=True): 63 | cvt_set = set() 64 | num_tot = 0 65 | num_res = 0 66 | subgraph = [] 67 | f = open(kb_file) 68 | for line in f: 69 | spline = line.strip().split("\t") 70 | num_tot += 1 71 | if spline[0] in seed_set: 72 | # Only subject is enough. 73 | if cvt_hop and spline[2] not in seed_set and is_cvt(spline[2], cvt_nodes): 74 | cvt_set.add(spline[2]) 75 | subgraph.append(spline) 76 | num_res += 1 77 | if num_tot % 1000000 == 0: 78 | print("seed-hop", num_tot, num_res) 79 | 80 | num_tot = 0 81 | num_res = 0 82 | if cvt_hop: 83 | cvt_set = cvt_set - seed_set 84 | with open(kb_file) as f: 85 | for line in f: 86 | num_tot += 1 87 | spline = line.strip().split("\t") 88 | if spline[0] in cvt_set: 89 | subgraph.append(spline) 90 | num_res += 1 91 | if num_tot % 1000000 == 0: 92 | print("seed-hop", num_tot, num_res) 93 | return subgraph 94 | 95 | 96 | def filter_ent_from_triple(subgraph): 97 | ent_set = set() 98 | for line in subgraph: 99 | if is_ent(line[0]): 100 | ent_set.add(line[0]) 101 | if is_ent(line[2]): 102 | ent_set.add(line[2]) 103 | return ent_set 104 | 105 | 106 | def get_n_hop_supgraph(seed_set, n_hop, cvt_nodes): 107 | st = time.time() 108 | kb_file = "LLMReasoningCert/data/data/manual_fb_filter.txt" 109 | 110 | for i_hop in range(1, n_hop + 1): 111 | subgraph = fetch_triple_1hop( 112 | kb_file=kb_file, seed_set=seed_set, cvt_nodes=cvt_nodes, cvt_hop=True 113 | ) 114 | print(f"\tHop {i_hop}", time.time() - st) 115 | st = time.time() 116 | 117 | if i_hop < n_hop: 118 | hop1_ent = filter_ent_from_triple(subgraph=subgraph) 119 | print("\tFetch ent from Hop 1", time.time() - st) 120 | st = time.time() 121 | seed_set = hop1_ent 122 | 123 | print("\tDone.", time.time() - st) 124 | return subgraph 125 | 126 | 127 | def freebase_supgraph(data_folder, n_hop): 128 | def add_subgraph_to_parse(parse, cvt_nodes, n_hop): 129 | for parse_i in range(len(parse)): 130 | seed_set = set( 131 | [parse[parse_i]["TopicEntityMid"]] 132 | + [ans["AnswersMid"] for ans in parse[parse_i]["Answers"]] 133 | ) 134 | parse[parse_i]["subgraph"] = get_n_hop_supgraph(seed_set, n_hop, cvt_nodes) 135 | return parse 136 | 137 | data_file = [ 138 | "FreebaseQA-dev.json", 
139 | "FreebaseQA-eval.json", 140 | "ComplexWebQuestions_train.json", 141 | ] 142 | input_dir = "raw" 143 | 144 | st = time.time() 145 | cvt_nodes = load_cvt() 146 | print("Loaded CVT", time.time() - st) 147 | 148 | for file in data_file: 149 | # print(f'Handled {count}/{len_data} in {file}') 150 | print(f"Handling {file}") 151 | output_file = os.path.join(data_folder, file) 152 | input_file = os.path.join(data_folder, input_dir, file) 153 | with open(input_file) as f_in: 154 | data = json.load(f_in) 155 | data = Dataset.from_pandas(pd.DataFrame(data=data["Questions"])) 156 | data = data.map( 157 | input_columns=["Parses"], 158 | remove_columns=["Parses"], 159 | function=lambda x: { 160 | "Parses": add_subgraph_to_parse(x, cvt_nodes, n_hop) 161 | }, 162 | num_proc=8, 163 | ) 164 | data.to_json(output_file) 165 | 166 | 167 | def grailqa_supgraph(data_dir, n_hop): 168 | def add_subgraph_to_parse(row, cvt_nodes, n_hop): 169 | seed_set = set( 170 | [row["graph_query"]["id"][-1]] + row["answer"]["answer_argument"] 171 | ) 172 | return get_n_hop_supgraph(seed_set, n_hop, cvt_nodes) 173 | 174 | st = time.time() 175 | cvt_nodes = load_cvt() 176 | print("Loaded CVT", time.time() - st) 177 | 178 | data_dic = load_dataset("grail_qa") 179 | for split in data_dic: 180 | output_file = os.path.join(data_dir, f"{split}.json") 181 | data = data_dic[split] 182 | data = data.map( 183 | lambda x: {"subgraph": add_subgraph_to_parse(x, cvt_nodes, n_hop)}, 184 | num_proc=8, 185 | ) 186 | data.to_json(output_file) 187 | 188 | 189 | if __name__ == "__main__": 190 | # step0: filter freebase: not triplet 191 | # dat_path = 'LLMReasoningCert/data/FreebaseQA/raw/' 192 | # filter_domain = get_domain_from_dataset(dat_path) 193 | # print(filter_domain[:2]) 194 | # manual_filter_rel([]) 195 | 196 | # step1: get subgraph 197 | # freebase 198 | # freebase_supgraph(data_folder='LLMReasoningCert/data/FreebaseQA', n_hop=2) 199 | grailqa_supgraph( 200 | data_dir="LLMReasoningCert/envs/huggingface", 201 | n_hop=4, 202 | ) 203 | 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [ACL'24] Direct Evaluation of CoT in Multi-hop Reasoning with Knowledge Graphs 2 | --- 3 | Official Implementation of ["Direct Evaluation of Chain-of-Thought in Multi-hop Reasoning with Knowledge Graphs"](https://aclanthology.org/2024.findings-acl.168). 4 | 5 | Accepted at [ACL2024](https://aclanthology.org/2024.findings-acl.168) Findings. 6 | 7 | 8 | 9 | Aiming to evaluate not only final answers but also the intermediate steps of LLMs' CoT reasoning in multi-hop question answering, the paper proposes 2 evaluation modules: 10 | 1. **Discriminative**: assess LLMs' knowledge of reasoning. 11 | 2. **Generative**: assess the accuracy of the generated CoT by utilizing knowledge graphs (KGs). 12 | 13 | In addition, we run ablation studies on fine-grained CoT generation to measure edit-distance & reasoning errors. 14 | 15 | ## Requirements 16 | ```sh 17 | conda create --name llm-reasoning-cert python=3.8 18 | conda activate llm-reasoning-cert 19 | ``` 20 | ```sh 21 | pip install -r requirements.txt 22 | ``` 23 | 24 | ## Datasets 25 | The paper uses 2 datasets, [CWQ](https://allenai.org/data/complexwebquestions) and [GrailQA](https://huggingface.co/datasets/grail_qa), as the initial datasets for experiments.
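For orientation, the snippet below sketches how these two sources are loaded by the preprocessing scripts in this repo (GrailQA from the HuggingFace hub, CWQ from a local JSONL dump). The CWQ file path follows the default `LLMReasoningCert/data` layout assumed by `preprocess_data/ground_truth_paths.py`; treat it as an illustrative assumption rather than a fixed location.
```python
from datasets import load_dataset

# GrailQA is pulled straight from the HuggingFace hub (as in preprocess_data/grailqa_graph.py);
# its validation split is converted into the "test" split used elsewhere in this repo.
grail_qa = load_dataset("grail_qa", split="validation")

# CWQ is read from a local JSONL export; the path below assumes the default data layout.
cwq = load_dataset("json", data_files="LLMReasoningCert/data/cwq/test.jsonl")["train"]

print(grail_qa[0]["question"])
print(cwq[0]["question"])
```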
26 | 27 | Then, extract subgraphs and ground-truth reasoning paths based on SPARQL. 28 | 29 | The final datasets used in the paper are uploaded to HuggingFace: (Note: update later) 30 | 1. [CWQ-Subgraph-Eval]() 31 | 2. [GrailQA-Subgraph-Eval]() 32 | 33 | ### Preprocessing for each dataset 34 | Aim: create subgraphs for querying ground-truth reasoning paths & creating the VectorDB 35 | #### Create subgraphs 36 | Code at `./preprocess_data` 37 | 1. Create the **subgraph** from the **raw-subgraph** following the detailed implementation in [the preprocessing readme](./preprocess_data/readme.md) 38 | 2. Get the **ground-truth reasoning path** via the **subgraph**, `answer entities` and `topic entities` 39 | ```bash 40 | python ./preprocess_data/ground_truth_paths.py 41 | ``` 42 | 3. Rearrange questions according to the number of edges in the **ground-truth reasoning path** 43 | ```bash 44 | python ./preprocess_data/splitted_ground_truth_paths.py 45 | ``` 46 | We only use questions whose ground-truth reasoning path has >=2 hops. 47 | 48 | #### Create VectorDB 49 | `FAISS` & `sentence-transformers/all-mpnet-base-v2` are used to create the VectorDB before retrieval 50 | ```bash 51 | DATASET='cwq' # or 'grail_qa' 52 | sbatch scripts/gen-cert/extract_triplet.sh $DATASET 53 | ``` 54 | You can set up additional arguments: 55 | - embed_model_name. Default is `sentence-transformers/all-mpnet-base-v2` 56 | - top_k. Default is `10` 57 | - device. Default is `cpu` 58 | 59 | Note: remember to update them accordingly in `./generative-cert.py#L228` 60 | 61 | 62 | ### Data for Discriminative 63 | Download the data [here](https://drive.google.com/file/d/1jhZ0qwg2pfuxSiBqtzoGQNjRkoHfdC99/view?usp=sharing) 64 | 65 | #### Generate negative reasoning paths 66 | - Negative generation modes: 67 | 1. `replace`: replace the entities in reasoning paths. 68 | 2. `reorder`: reorder the reasoning paths. 69 | 3. `mislead`: generate reasoning paths leading to incorrect answers. 70 | - Code: 71 | 72 | ```bash 73 | # 1. Generate subgraphs for mislead paths 74 | python preprocess_data/subgraph_discriminative_cert.py 75 | # 2. Generate negative paths: 76 | ## - CWQ dataset 77 | python gen_negative.py --data_path data/cwq_test_res.csv --kg_path data/cwq_test.jsonl_cwq_test.jsonl --mode {'mislead', 'reorder', 'replace'} 78 | ## - GrailQA dataset 79 | python gen_negative.py --data_path data/multi_hop_grailqa.csv --kg_path data/grail_w_kg.jsonl --mode {'mislead', 'reorder', 'replace'} 80 | ``` 81 | 82 | 83 | ## Framework 84 | Set your OpenAI API key & HuggingFace token (if needed) in `.env` (see `.env.example` for an example). 85 | 86 | ### Discriminative Mode 87 | - Evaluation for ground-truth reasoning paths 88 | ```bash 89 | sh scripts/disc-cert/submit_discriminative_cert.sh 90 | ``` 91 | - Evaluation for generated negative reasoning paths 92 | ```bash 93 | sh scripts/disc-cert/submit_discriminative_cert_neg.sh 94 | ``` 95 | - Get results 96 | ```bash 97 | python scripts/disc-cert/summary_results.py 98 | ``` 99 | 100 | ### Generative Mode 101 | #### Stage 1: LLM prompting for structured answers 102 | 1. ChatGPT 103 | ```bash 104 | sh scripts/gen-cert/llm_prompting.sh 105 | ``` 106 | 2. HF models: Llama2 7B/13B/70B chat-hf, Mistral-7B-Instruct-v0.1, Qwen-14B-Chat, Vicuna-33b-v1.3 107 | ```bash 108 | sh generative_cert/scripts/fitcluster/script.sh 109 | ``` 110 | 111 | #### Stage 2 & 3: Retrieval & Evaluation 112 | 1. Main results 113 | ```bash 114 | sh scripts/gen-cert/job_eval_llm.sh 115 | ``` 116 | 2. 
The fine-grained generative evaluation: edit-distance score 117 | ```bash 118 | sh scripts/gen-cert/job_eval_llm_finegrained.sh 119 | python finegrained_analysis.py 120 | ``` 121 | 3. Run the analysis for reasoning errors 122 | ```bash 123 | python finegrained_analysis.py 124 | ``` 125 | 126 | ## Results 127 | 128 | 129 | 130 | --- 131 | ## Citation 132 | If you find this paper or the repo useful for your work, please consider citing the paper 133 | ``` 134 | @misc{nguyen2024direct, 135 | title={Direct Evaluation of Chain-of-Thought in Multi-hop Reasoning with Knowledge Graphs}, 136 | author={Minh-Vuong Nguyen and Linhao Luo and Fatemeh Shiri and Dinh Phung and Yuan-Fang Li and Thuy-Trang Vu and Gholamreza Haffari}, 137 | year={2024}, 138 | eprint={2402.11199}, 139 | archivePrefix={arXiv}, 140 | primaryClass={cs.CL} 141 | } 142 | ``` 143 | or 144 | ``` 145 | @inproceedings{nguyen-etal-2024-direct, 146 | title = "Direct Evaluation of Chain-of-Thought in Multi-hop Reasoning with Knowledge Graphs", 147 | author = "Nguyen, Minh-Vuong and 148 | Luo, Linhao and 149 | Shiri, Fatemeh and 150 | Phung, Dinh and 151 | Li, Yuan-Fang and 152 | Vu, Thuy-Trang and 153 | Haffari, Gholamreza", 154 | editor = "Ku, Lun-Wei and 155 | Martins, Andre and 156 | Srikumar, Vivek", 157 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2024", 158 | month = aug, 159 | year = "2024", 160 | address = "Bangkok, Thailand", 161 | publisher = "Association for Computational Linguistics", 162 | url = "https://aclanthology.org/2024.findings-acl.168/", 163 | doi = "10.18653/v1/2024.findings-acl.168", 164 | pages = "2862--2883", 165 | abstract = "Large language models (LLMs) have demonstrated strong reasoning abilities when prompted to generate chain-of-thought (CoT) explanations alongside answers. However, previous research on evaluating LLMs has solely focused on answer accuracy, neglecting the correctness of the generated CoT. In this paper, we delve deeper into the CoT reasoning capabilities of LLMs in multi-hop question answering by utilizing knowledge graphs (KGs). We propose a novel discriminative and generative CoT evaluation paradigm to assess LLMs' knowledge of reasoning and the accuracy of the generated CoT. Through experiments conducted on 5 different families of LLMs across 2 multi-hop question-answering datasets, we find that LLMs possess sufficient knowledge to perform reasoning. However, there exists a significant disparity between answer accuracy and faithfulness of the CoT generated by LLMs, indicating that they often arrive at correct answers through incorrect reasoning." 166 | } 167 | ``` 168 | -------------------------------------------------------------------------------- /generative_cert/prompts.py: -------------------------------------------------------------------------------- 1 | FEWSHOT_COT_ONLY = """1. 2 | 2. 3 | ... 4 | So the answer is (). 5 | Make sure that the answer uses the above format and answers the question step by step. 6 | 7 | Q: when Lou Seal is the mascot for the team that last won the World Series? 8 | A: Let's work this out in a step by step way to be sure we have the right answer. 9 | 1. Lou Seal is the mascot for the San Francisco Giants. 10 | 2. The San Francisco Giants are associated with the sports championship event, the 2014 World Series. 11 | So the answer is (2014 World Series). 12 | 13 | Q: What nation has an army or more than 713480 people and borders the country of Bolivia? 
14 | A: Let's work this out in a step by step way to be sure we have the right answer. 15 | 1. Bolivia is a landlocked country located in South America. 16 | 2. Bolivia shares its borders with several countries, including Argentina, Brazil, Chile, Paraguay, and Peru. 17 | So the answer is (Brazil). 18 | 19 | Q: What movie was displayed at the 2012 Refugee Film Festival and had Angelia Jolie directing it? 20 | A: Let's work this out in a step by step way to be sure we have the right answer. 21 | 1. Angelia Jolie whose first major film as a director which named “In the Land of Blood and Honey”. 22 | 2. “In the Land of Blood and Honey” was shown at the 2012 Refugee Film Festival. 23 | So the answer is (In the Land of Blood and Honey). 24 | 25 | Q: How many Mary Mary sisters? 26 | A: Let's work this out in a step by step way to be sure we have the right answer. 27 | 1. Mary Mary is a group which has a member named Tina Campbell 28 | 2. Mary Mary is a group which has a member named Erica Campbell 29 | So the answer is (Erica Campbell, Tina Campbell). 30 | 31 | Q: Which languages are used in the location that the breed Egyptian Mau started in? 32 | A: Let's work this out in a step by step way to be sure we have the right answer. 33 | 1. The Egyptian Mau is a breed of domestic cat that is believed to have originated in Egypt. 34 | 2. In Egypt, the primary language spoken is Arabic, besides Domari or Nobiin. 35 | So the answer is (Arabic, Domari, Nobiin). 36 | 37 | Q: {question} 38 | A: Let's work this out in a step by step way to be sure we have the right answer.""" 39 | 40 | 41 | FEWSHOT_COT_HINT = """Relation path is a sequence relation that describes each step of the reasoning process. You first give a relation path as a HINT, then reason the answer step-by-step based on it. 42 | HINT: 43 | 1. 44 | 2. 45 | ... 46 | So the answer is (). 47 | Make sure that the answer uses the above format and answers the question step by step. 48 | 49 | Q: when Lou Seal is the mascot for the team that last won the World Series? 50 | A: Let's work this out in a step by step way to be sure we have the right answer. 51 | HINT: sports.sports_team.team_mascot -> sports.sports_team.championships 52 | 1. Lou Seal is the mascot for the San Francisco Giants. 53 | 2. The San Francisco Giants are associated with the sports championship event, the 2014 World Series. 54 | So the answer is (2014 World Series). 55 | 56 | Q: What nation has an army or more than 713480 people and borders the country of Bolivia? 57 | A: Let's work this out in a step by step way to be sure we have the right answer. 58 | HINT: geography.river.basin_countries -> location.location.partially_contains 59 | 1. Bolivia is a landlocked country located in South America. 60 | 2. Bolivia shares its borders with several countries, including Argentina, Brazil, Chile, Paraguay, and Peru. 61 | So the answer is (Brazil). 62 | 63 | Q: What movie was displayed at the 2012 Refugee Film Festival and had Angelia Jolie directing it? 64 | A: Let's work this out in a step by step way to be sure we have the right answer. 65 | HINT: film.director.film -> film.film_regional_release_date.film_regional_debut_venue 66 | 1. Angelia Jolie whose first major film as a director which named “In the Land of Blood and Honey”. 67 | 2. “In the Land of Blood and Honey” was shown at the 2012 Refugee Film Festival. 68 | So the answer is (In the Land of Blood and Honey). 69 | 70 | Q: How many Mary Mary sisters? 
71 | A: Let's work this out in a step by step way to be sure we have the right answer. 72 | HINT: music.group_membership.member -> music.group_membership.member 73 | 1. Mary Mary is a group which has a member named Tina Campbell 74 | 2. Mary Mary is a group which has a member named Erica Campbell 75 | So the answer is (Erica Campbell, Tina Campbell). 76 | 77 | Q: Which languages are used in the location that the breed Egyptian Mau started in? 78 | A: Let's work this out in a step by step way to be sure we have the right answer. 79 | HINT: biology.breed_origin.breeds_originating_here -> location.country.languages_spoken 80 | 1. The Egyptian Mau is a breed of domestic cat that is believed to have originated in Egypt. 81 | 2. In Egypt, the primary language spoken is Arabic, besides Domari or Nobiin. 82 | So the answer is (Arabic, Domari, Nobiin). 83 | 84 | Q: {question} 85 | A: Let's work this out in a step by step way to be sure we have the right answer.""" 86 | 87 | 88 | FEWSHOT_COT_HINT_GROUND = """Relation path is a sequence relation that describes each step of the reasoning process. I will give you a relation path as a hint. Please reason the answer step-by-step based on it. 89 | 1. 90 | 2. 91 | ... 92 | So the answer is (). 93 | Make sure that the answer uses the above format and answers the question step by step. 94 | 95 | Q: when Lou Seal is the mascot for the team that last won the World Series? 96 | A: Let's work this out in a step by step way to be sure we have the right answer. 97 | HINT: sports.sports_team.team_mascot -> sports.sports_team.championships 98 | 1. Lou Seal is the mascot for the San Francisco Giants. 99 | 2. The San Francisco Giants are associated with the sports championship event, the 2014 World Series. 100 | So the answer is (2014 World Series). 101 | 102 | Q: What nation has an army or more than 713480 people and borders the country of Bolivia? 103 | A: Let's work this out in a step by step way to be sure we have the right answer. 104 | HINT: geography.river.basin_countries -> location.location.partially_contains 105 | 1. Bolivia is a landlocked country located in South America. 106 | 2. Bolivia shares its borders with several countries, including Argentina, Brazil, Chile, Paraguay, and Peru. 107 | So the answer is (Brazil). 108 | 109 | Q: What movie was displayed at the 2012 Refugee Film Festival and had Angelia Jolie directing it? 110 | A: Let's work this out in a step by step way to be sure we have the right answer. 111 | HINT: film.director.film -> film.film_regional_release_date.film_regional_debut_venue 112 | 1. Angelia Jolie whose first major film as a director which named “In the Land of Blood and Honey”. 113 | 2. “In the Land of Blood and Honey” was shown at the 2012 Refugee Film Festival. 114 | So the answer is (In the Land of Blood and Honey). 115 | 116 | Q: How many Mary Mary sisters? 117 | A: Let's work this out in a step by step way to be sure we have the right answer. 118 | HINT: music.group_membership.member -> music.group_membership.member 119 | 1. Mary Mary is a group which has a member named Tina Campbell 120 | 2. Mary Mary is a group which has a member named Erica Campbell 121 | So the answer is (Erica Campbell, Tina Campbell). 122 | 123 | Q: Which languages are used in the location that the breed Egyptian Mau started in? 124 | A: Let's work this out in a step by step way to be sure we have the right answer. 125 | HINT: biology.breed_origin.breeds_originating_here -> location.country.languages_spoken 126 | 1. 
The Egyptian Mau is a breed of domestic cat that is believed to have originated in Egypt. 127 | 2. In Egypt, the primary language spoken is Arabic, besides Domari or Nobiin. 128 | So the answer is (Arabic, Domari, Nobiin). 129 | 130 | Q: {question} 131 | A: Let's work this out in a step by step way to be sure we have the right answer. 132 | HINT: {hint}.""" 133 | -------------------------------------------------------------------------------- /discriminative_cert/utils/utils.py: -------------------------------------------------------------------------------- 1 | import openai 2 | from dotenv import load_dotenv 3 | import time 4 | import os 5 | import networkx as nx 6 | import tiktoken 7 | import random 8 | from collections import deque 9 | import numpy as np 10 | 11 | load_dotenv() 12 | openai.api_key = os.getenv("OPENAI_API_KEY") 13 | openai.organization = os.getenv("OPENAI_ORG") 14 | os.environ["TIKTOKEN_CACHE_DIR"] = "./tmp" 15 | 16 | 17 | def set_seed(seed): 18 | random.seed(seed) 19 | np.random.seed(seed) 20 | 21 | 22 | def build_graph(graph: list) -> nx.Graph: 23 | G = nx.Graph() 24 | for triplet in graph: 25 | h, r, t = triplet 26 | G.add_edge(h, t, relation=r.strip()) 27 | return G 28 | 29 | 30 | # Define a function to perform a breadth-first search 31 | def bfs_with_rule(graph, start_node, target_rule): 32 | result_paths = [] 33 | queue = deque( 34 | [(start_node, [])] 35 | ) # Use queues to store nodes to be explored and corresponding paths 36 | while queue: 37 | current_node, current_path = queue.popleft() 38 | 39 | # If the current path matches the rules, add it to the results list 40 | if len(current_path) == len(target_rule): 41 | result_paths.append(current_path) 42 | 43 | # If the current path length is less than the rule length, continue exploring 44 | if len(current_path) < len(target_rule): 45 | if current_node not in graph: 46 | continue 47 | for neighbor in graph.neighbors(current_node): 48 | # Pruning: If the current edge type does not match the corresponding position in the rule, the path will not be explored further. 
49 | rel = graph[current_node][neighbor]["relation"] 50 | if rel != target_rule[len(current_path)] or len(current_path) > len( 51 | target_rule 52 | ): 53 | continue 54 | queue.append((neighbor, current_path + [(current_node, rel, neighbor)])) 55 | 56 | return result_paths 57 | 58 | 59 | def random_walk_edge_sampling(graph, start_node, path_length, num_paths): 60 | paths = [] 61 | 62 | for _ in range(num_paths): 63 | path = [start_node] 64 | current_node = start_node 65 | edges = [] # Used to record the edges of each path 66 | 67 | for _ in range(path_length - 1): 68 | neighbors = list(graph.neighbors(current_node)) 69 | 70 | if len(neighbors) == 0: 71 | break 72 | 73 | next_node = random.choice(neighbors) 74 | edges.append(graph[current_node][next_node]["relation"]) 75 | path.append(next_node) 76 | current_node = next_node 77 | 78 | paths.append(edges) 79 | 80 | return paths 81 | 82 | 83 | def list_to_string(l: list) -> str: 84 | prompt = '"{}"' 85 | return ", ".join([prompt.format(i) for i in l]) 86 | 87 | 88 | def rule_to_string(rule: list) -> str: 89 | return " -> ".join(rule) 90 | 91 | 92 | def rules_to_string(rules: list) -> str: 93 | prompt = [] 94 | for r in rules: 95 | rule_to_string(r) 96 | return "\n".join(prompt) 97 | 98 | 99 | def path_to_string(path: list) -> str: 100 | result = "" 101 | for i, p in enumerate(path): 102 | h, r, t = p 103 | result += f"Step {i+1}: {h} -> {r} -> {t}\n" 104 | 105 | return result.strip() 106 | 107 | 108 | def reoder_path_to_string(path: list) -> str: 109 | result = "" 110 | for i, p in enumerate(path): 111 | h, r, t = p 112 | result += f"Step {i+1}: {h} -> {r} -> {t}\n" 113 | 114 | return result.strip() 115 | 116 | 117 | def num_tokens_from_message(path_string, model): 118 | """Returns the number of tokens used by a list of messages.""" 119 | messages = [{"role": "user", "content": path_string}] 120 | try: 121 | encoding = tiktoken.encoding_for_model(model) 122 | except KeyError: 123 | print("Warning: model not found. Using cl100k_base encoding.") 124 | encoding = tiktoken.get_encoding("cl100k_base") 125 | if model in ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"]: 126 | tokens_per_message = ( 127 | 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n 128 | ) 129 | elif model == "gpt-4": 130 | tokens_per_message = 3 131 | else: 132 | raise NotImplementedError( 133 | f"num_tokens_from_messages() is not implemented for model {model}." 134 | ) 135 | num_tokens = 0 136 | for message in messages: 137 | num_tokens += tokens_per_message 138 | for key, value in message.items(): 139 | num_tokens += len(encoding.encode(value)) 140 | num_tokens += 3 # every reply is primed with <|start|>assistant<|message|> 141 | return num_tokens 142 | 143 | 144 | def get_token_limit(model="gpt-4"): 145 | """Returns the token limitation of provided model""" 146 | if model in ["gpt-4", "gpt-4-0613"]: 147 | num_tokens_limit = 8192 148 | elif model in ["gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613"]: 149 | num_tokens_limit = 16384 150 | elif model in [ 151 | "gpt-3.5-turbo", 152 | "gpt-3.5-turbo-0613", 153 | "text-davinci-003", 154 | "text-davinci-002", 155 | ]: 156 | num_tokens_limit = 4096 157 | else: 158 | raise NotImplementedError( 159 | f"""get_token_limit() is not implemented for model {model}.""" 160 | ) 161 | return num_tokens_limit 162 | 163 | 164 | def split_path_list(path_list, token_limit, model): 165 | """ 166 | Split the path list into several lists, each list can be fed into the model. 
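Paths are packed greedily: a new sub-list is started whenever adding the next path string would push the running token count (per num_tokens_from_message) over token_limit, so each sub-list normally fits into a single prompt for the given model.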
167 | """ 168 | output_list = [] 169 | current_list = [] 170 | current_token_count = 4 171 | 172 | for path in path_list: 173 | path += "\n" 174 | path_token_count = num_tokens_from_message(path, model) - 4 175 | if ( 176 | current_token_count + path_token_count > token_limit 177 | ): # If the path makes the current list exceed the token limit 178 | output_list.append(current_list) 179 | current_list = [path] # Start a new list. 180 | current_token_count = path_token_count + 4 181 | else: # The new path fits into the current list without exceeding the limit 182 | current_list.append(path) # Just add it there. 183 | current_token_count += path_token_count 184 | # Add the last list of tokens, if it's non-empty. 185 | if current_list: # The last list not exceed the limit but no more paths 186 | output_list.append(current_list) 187 | return output_list 188 | 189 | 190 | def query(message, model="gpt-3.5-turbo"): 191 | """ 192 | Query ChatGPT API 193 | :param message: 194 | :return: 195 | """ 196 | prompt_len = num_tokens_from_message(message, model) 197 | token_limitation = get_token_limit(model) 198 | if prompt_len > token_limitation: 199 | raise ValueError( 200 | f"Message length {prompt_len} exceeds token limitation {token_limitation}. Message: {message}" 201 | ) 202 | while True: 203 | try: 204 | response = openai.ChatCompletion.create( 205 | model=model, 206 | messages=[{"role": "user", "content": message}], 207 | request_timeout=60, 208 | temperature=0.0, 209 | stream=False, 210 | ) 211 | return response["choices"][0]["message"]["content"].strip() 212 | # collected_chunks = [] 213 | # collected_messages = [] 214 | # # iterate through the stream of events 215 | # for chunk in response: 216 | # collected_chunks.append(chunk) # save the event response 217 | # chunk_message = chunk['choices'][0]['delta'] # extract the message 218 | # collected_messages.append(chunk_message) # save the message 219 | # result = ''.join([m.get('content', '') for m in collected_messages]).strip() 220 | # return result 221 | except Exception as e: 222 | print(e) 223 | time.sleep(60) 224 | continue 225 | -------------------------------------------------------------------------------- /scripts/disc-cert/submit_discriminative_cert.sh: -------------------------------------------------------------------------------- 1 | # DATA_PATH="data/cwq_test_res.csv" 2 | DATA_PATH="data/multi_hop_grailqa.csv" 3 | MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 4 | 5 | 6 | MODEL_NAME="gpt-3.5-turbo" 7 | MODEL_PATH="None" 8 | N_PROCESS=5 9 | QUANT=none 10 | 11 | # MODEL_NAME="llama2-7B-chat-hf" 12 | # MODEL_PATH="meta-llama/Llama-2-7b-chat-hf" 13 | # N_PROCESS=1 14 | # QUANT=none 15 | 16 | # MODEL_NAME="llama2-13B-chat-hf" 17 | # MODEL_PATH="meta-llama/Llama-2-13b-chat-hf" 18 | # N_PROCESS=1 19 | # QUANT=none 20 | 21 | # MODEL_NAME="Mistral-7B-Instruct-v0.1" 22 | # MODEL_PATH="mistralai/Mistral-7B-Instruct-v0.1" 23 | # N_PROCESS=1 24 | # QUANT=none 25 | 26 | # MODEL_NAME="Qwen-7B-Chat" 27 | # MODEL_PATH="Qwen/Qwen-7B-Chat" 28 | # N_PROCESS=1 29 | # QUANT=none 30 | 31 | # MODEL_NAME="Qwen-14B-Chat" 32 | # MODEL_PATH="Qwen/Qwen-14B-Chat" 33 | # N_PROCESS=1 34 | # QUANT=none 35 | 36 | 37 | # MODEL_NAME="vicuna-33b-v1.3" 38 | # MODEL_PATH="lmsys/vicuna-33b-v1.3" 39 | # N_PROCESS=1 40 | # QUANT=none 41 | 42 | # MODEL_NAME="llama2-70B-chat-hf" 43 | # MODEL_PATH="meta-llama/Llama-2-70b-chat-hf" 44 | # N_PROCESS=1 45 | # QUANT=4bit 46 | 47 | 48 | for DATA in $DATA_PATH; do 49 | for MODE in $MODE_LIST; do 50 | echo "Running $DATA 
$MODEL_NAME $MODE" 51 | python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path ${MODEL_PATH} --n $N_PROCESS --data_path $DATA --quant $QUANT 52 | done 53 | done 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | # DATA_PATH=data/multi_hop_grailqa.csv 74 | # DATA_PATH=data/cwq_test_res.csv 75 | 76 | # MODEL_LIST="gpt-3.5-turbo" 77 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 78 | # N_PROCESS=5 79 | # for MODEL_NAME in $MODEL_LIST; do 80 | # FIRST=True 81 | # for MODE in $MODE_LIST; do 82 | # echo "Submitting $MODEL $MODE" 83 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_CPU_job.sh \ 84 | # "python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --n $N_PROCESS --data_path ${DATA_PATH} --wandb" 85 | # if [ "$FIRST" = True ]; then 86 | # sleep 10 87 | # FIRST=False 88 | # fi 89 | # done 90 | # done 91 | 92 | 93 | # MODEL_LIST="llama2-7B-chat-hf" 94 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 95 | # MODEL_PATH="meta-llama/Llama-2-7b-chat-hf" 96 | # for MODEL_NAME in $MODEL_LIST; do 97 | # FIRST=True 98 | # for MODE in $MODE_LIST; do 99 | # echo "Submitting $MODEL_NAME $MODE" 100 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_1-GPU_job.sh \ 101 | # "python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 102 | # if [ "$FIRST" = True ]; then 103 | # sleep 10 104 | # FIRST=False 105 | # fi 106 | # done 107 | # done 108 | 109 | # MODEL_LIST="llama2-13B-chat-hf" 110 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 111 | # MODEL_PATH="meta-llama/Llama-2-13b-chat-hf" 112 | # for MODEL_NAME in $MODEL_LIST; do 113 | # FIRST=True 114 | # for MODE in $MODE_LIST; do 115 | # echo "Submitting $MODEL_NAME $MODE" 116 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 --exclude="node[01-05]" submit_1-GPU_job.sh \ 117 | # "TMPDIR=/home/lluo/projects/LLMReasonCert/tmp TEMP=/home/lluo/projects/LLMReasonCert/tmp TMP=/home/lluo/projects/LLMReasonCert/tmp python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 118 | # if [ "$FIRST" = True ]; then 119 | # sleep 10 120 | # FIRST=False 121 | # fi 122 | # done 123 | # done 124 | 125 | # MODEL_LIST="llama2-70B-chat-hf" 126 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 127 | # MODEL_PATH="meta-llama/Llama-2-70b-chat-hf" 128 | # for MODEL_NAME in $MODEL_LIST; do 129 | # FIRST=True 130 | # for MODE in $MODE_LIST; do 131 | # echo "Submitting $MODEL_NAME $MODE" 132 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_1-A100_job.sh \ 133 | # "python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --quant 8bit --data_path ${DATA_PATH} --wandb" 134 | # if [ "$FIRST" = True ]; then 135 | # sleep 10 136 | # FIRST=False 137 | # fi 138 | # done 139 | # done 140 | 141 | 142 | # MODEL_LIST="llama2-70B-chat-hf_fp16" 143 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 144 | # MODEL_PATH="meta-llama/Llama-2-70b-chat-hf" 145 | # for MODEL_NAME in $MODEL_LIST; do 146 | # for MODE in $MODE_LIST; do 147 | # echo "Submitting $MODEL_NAME $MODE" 148 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_2-A100_job.sh \ 149 | # "python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path 
$MODEL_PATH" 150 | # done 151 | # done 152 | 153 | # MODEL_LIST="Mistral-7B-Instruct-v0.1" 154 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 155 | # MODEL_PATH="mistralai/Mistral-7B-Instruct-v0.1" 156 | # for MODEL_NAME in $MODEL_LIST; do 157 | # FIRST=True 158 | # for MODE in $MODE_LIST; do 159 | # echo "Submitting $MODEL_NAME $MODE" 160 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_1-GPU_job.sh \ 161 | # "conda activate py310_hf_350" \ 162 | # "python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 163 | # if [ "$FIRST" = True ]; then 164 | # sleep 10 165 | # FIRST=False 166 | # fi 167 | # done 168 | # done 169 | 170 | # MODEL_LIST="Qwen-7B-Chat" 171 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 172 | # MODEL_PATH="Qwen/Qwen-7B-Chat" 173 | # for MODEL_NAME in $MODEL_LIST; do 174 | # FIRST=True 175 | # for MODE in $MODE_LIST; do 176 | # echo "Submitting $MODEL_NAME $MODE" 177 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 --exclude="node[01-06]" submit_1-GPU_job.sh \ 178 | # "conda activate py310_hf_350" \ 179 | # "TMPDIR=/home/lluo/projects/LLMReasonCert/tmp python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 180 | # if [ "$FIRST" = True ]; then 181 | # sleep 10 182 | # FIRST=False 183 | # fi 184 | # done 185 | # done 186 | 187 | # MODEL_LIST="Qwen-14B-Chat" 188 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 189 | # MODEL_PATH="Qwen/Qwen-14B-Chat" 190 | # for MODEL_NAME in $MODEL_LIST; do 191 | # FIRST=True 192 | # for MODE in $MODE_LIST; do 193 | # echo "Submitting $MODEL_NAME $MODE" 194 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 --exclude="node[01-06]" submit_1-GPU_job.sh \ 195 | # "conda activate py310_hf_350" \ 196 | # "TMPDIR=/home/lluo/projects/LLMReasonCert/tmp python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 197 | # if [ "$FIRST" = True ]; then 198 | # sleep 10 199 | # FIRST=False 200 | # fi 201 | # done 202 | # done 203 | 204 | # MODEL_LIST="vicuna-33b-v1.3" 205 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 206 | # MODEL_PATH="lmsys/vicuna-33b-v1.3" 207 | # for MODEL_NAME in $MODEL_LIST; do 208 | # FIRST=True 209 | # for MODE in $MODE_LIST; do 210 | # echo "Submitting $MODEL_NAME $MODE" 211 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_1-A100_job.sh \ 212 | # "conda activate py310_hf_350" \ 213 | # "TMPDIR=/home/lluo/projects/LLMReasonCert/tmp python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 214 | # if [ "$FIRST" = True ]; then 215 | # sleep 10 216 | # FIRST=False 217 | # fi 218 | # done 219 | # done -------------------------------------------------------------------------------- /discriminative_cert/scripts/submit_discriminative_cert.sh: -------------------------------------------------------------------------------- 1 | # DATA_PATH="data/cwq_test_res.csv" 2 | DATA_PATH="data/multi_hop_grailqa.csv" 3 | MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 4 | 5 | 6 | MODEL_NAME="gpt-3.5-turbo" 7 | MODEL_PATH="None" 8 | N_PROCESS=5 9 | QUANT=none 10 | 11 | # MODEL_NAME="llama2-7B-chat-hf" 12 | # MODEL_PATH="meta-llama/Llama-2-7b-chat-hf" 13 | # N_PROCESS=1 14 | # QUANT=none 15 | 16 | # 
MODEL_NAME="llama2-13B-chat-hf" 17 | # MODEL_PATH="meta-llama/Llama-2-13b-chat-hf" 18 | # N_PROCESS=1 19 | # QUANT=none 20 | 21 | # MODEL_NAME="Mistral-7B-Instruct-v0.1" 22 | # MODEL_PATH="mistralai/Mistral-7B-Instruct-v0.1" 23 | # N_PROCESS=1 24 | # QUANT=none 25 | 26 | # MODEL_NAME="Qwen-7B-Chat" 27 | # MODEL_PATH="Qwen/Qwen-7B-Chat" 28 | # N_PROCESS=1 29 | # QUANT=none 30 | 31 | # MODEL_NAME="Qwen-14B-Chat" 32 | # MODEL_PATH="Qwen/Qwen-14B-Chat" 33 | # N_PROCESS=1 34 | # QUANT=none 35 | 36 | 37 | # MODEL_NAME="vicuna-33b-v1.3" 38 | # MODEL_PATH="lmsys/vicuna-33b-v1.3" 39 | # N_PROCESS=1 40 | # QUANT=none 41 | 42 | # MODEL_NAME="llama2-70B-chat-hf" 43 | # MODEL_PATH="meta-llama/Llama-2-70b-chat-hf" 44 | # N_PROCESS=1 45 | # QUANT=4bit 46 | 47 | 48 | for DATA in $DATA_PATH; do 49 | for MODE in $MODE_LIST; do 50 | echo "Running $DATA $MODEL_NAME $MODE" 51 | python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path ${MODEL_PATH} --n $N_PROCESS --data_path $DATA --quant $QUANT 52 | done 53 | done 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | # DATA_PATH=data/multi_hop_grailqa.csv 74 | # DATA_PATH=data/cwq_test_res.csv 75 | 76 | # MODEL_LIST="gpt-3.5-turbo" 77 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 78 | # N_PROCESS=5 79 | # for MODEL_NAME in $MODEL_LIST; do 80 | # FIRST=True 81 | # for MODE in $MODE_LIST; do 82 | # echo "Submitting $MODEL $MODE" 83 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_CPU_job.sh \ 84 | # "python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --n $N_PROCESS --data_path ${DATA_PATH} --wandb" 85 | # if [ "$FIRST" = True ]; then 86 | # sleep 10 87 | # FIRST=False 88 | # fi 89 | # done 90 | # done 91 | 92 | 93 | # MODEL_LIST="llama2-7B-chat-hf" 94 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 95 | # MODEL_PATH="meta-llama/Llama-2-7b-chat-hf" 96 | # for MODEL_NAME in $MODEL_LIST; do 97 | # FIRST=True 98 | # for MODE in $MODE_LIST; do 99 | # echo "Submitting $MODEL_NAME $MODE" 100 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_1-GPU_job.sh \ 101 | # "python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 102 | # if [ "$FIRST" = True ]; then 103 | # sleep 10 104 | # FIRST=False 105 | # fi 106 | # done 107 | # done 108 | 109 | # MODEL_LIST="llama2-13B-chat-hf" 110 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 111 | # MODEL_PATH="meta-llama/Llama-2-13b-chat-hf" 112 | # for MODEL_NAME in $MODEL_LIST; do 113 | # FIRST=True 114 | # for MODE in $MODE_LIST; do 115 | # echo "Submitting $MODEL_NAME $MODE" 116 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 --exclude="node[01-05]" submit_1-GPU_job.sh \ 117 | # "TMPDIR=/home/lluo/projects/LLMReasonCert/tmp TEMP=/home/lluo/projects/LLMReasonCert/tmp TMP=/home/lluo/projects/LLMReasonCert/tmp python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 118 | # if [ "$FIRST" = True ]; then 119 | # sleep 10 120 | # FIRST=False 121 | # fi 122 | # done 123 | # done 124 | 125 | # MODEL_LIST="llama2-70B-chat-hf" 126 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 127 | # MODEL_PATH="meta-llama/Llama-2-70b-chat-hf" 128 | # for MODEL_NAME in $MODEL_LIST; do 129 | # FIRST=True 130 | # for MODE in $MODE_LIST; do 131 | # echo "Submitting $MODEL_NAME $MODE" 132 | # 
sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_1-A100_job.sh \ 133 | # "python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --quant 8bit --data_path ${DATA_PATH} --wandb" 134 | # if [ "$FIRST" = True ]; then 135 | # sleep 10 136 | # FIRST=False 137 | # fi 138 | # done 139 | # done 140 | 141 | 142 | # MODEL_LIST="llama2-70B-chat-hf_fp16" 143 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 144 | # MODEL_PATH="meta-llama/Llama-2-70b-chat-hf" 145 | # for MODEL_NAME in $MODEL_LIST; do 146 | # for MODE in $MODE_LIST; do 147 | # echo "Submitting $MODEL_NAME $MODE" 148 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_2-A100_job.sh \ 149 | # "python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH" 150 | # done 151 | # done 152 | 153 | # MODEL_LIST="Mistral-7B-Instruct-v0.1" 154 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 155 | # MODEL_PATH="mistralai/Mistral-7B-Instruct-v0.1" 156 | # for MODEL_NAME in $MODEL_LIST; do 157 | # FIRST=True 158 | # for MODE in $MODE_LIST; do 159 | # echo "Submitting $MODEL_NAME $MODE" 160 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_1-GPU_job.sh \ 161 | # "conda activate py310_hf_350" \ 162 | # "python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 163 | # if [ "$FIRST" = True ]; then 164 | # sleep 10 165 | # FIRST=False 166 | # fi 167 | # done 168 | # done 169 | 170 | # MODEL_LIST="Qwen-7B-Chat" 171 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 172 | # MODEL_PATH="Qwen/Qwen-7B-Chat" 173 | # for MODEL_NAME in $MODEL_LIST; do 174 | # FIRST=True 175 | # for MODE in $MODE_LIST; do 176 | # echo "Submitting $MODEL_NAME $MODE" 177 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 --exclude="node[01-06]" submit_1-GPU_job.sh \ 178 | # "conda activate py310_hf_350" \ 179 | # "TMPDIR=/home/lluo/projects/LLMReasonCert/tmp python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 180 | # if [ "$FIRST" = True ]; then 181 | # sleep 10 182 | # FIRST=False 183 | # fi 184 | # done 185 | # done 186 | 187 | # MODEL_LIST="Qwen-14B-Chat" 188 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 189 | # MODEL_PATH="Qwen/Qwen-14B-Chat" 190 | # for MODEL_NAME in $MODEL_LIST; do 191 | # FIRST=True 192 | # for MODE in $MODE_LIST; do 193 | # echo "Submitting $MODEL_NAME $MODE" 194 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 --exclude="node[01-06]" submit_1-GPU_job.sh \ 195 | # "conda activate py310_hf_350" \ 196 | # "TMPDIR=/home/lluo/projects/LLMReasonCert/tmp python discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 197 | # if [ "$FIRST" = True ]; then 198 | # sleep 10 199 | # FIRST=False 200 | # fi 201 | # done 202 | # done 203 | 204 | # MODEL_LIST="vicuna-33b-v1.3" 205 | # MODE_LIST="zero-shot zero-shot-cot few-shot few-shot-cot" 206 | # MODEL_PATH="lmsys/vicuna-33b-v1.3" 207 | # for MODEL_NAME in $MODEL_LIST; do 208 | # FIRST=True 209 | # for MODE in $MODE_LIST; do 210 | # echo "Submitting $MODEL_NAME $MODE" 211 | # sbatch --job-name ${MODEL_NAME}-${MODE} --cpus-per-task 1 --ntasks 1 submit_1-A100_job.sh \ 212 | # "conda activate py310_hf_350" \ 213 | # "TMPDIR=/home/lluo/projects/LLMReasonCert/tmp python 
discriminative-cert.py --model_name $MODEL_NAME --mode $MODE --model_path $MODEL_PATH --data_path ${DATA_PATH} --wandb" 214 | # if [ "$FIRST" = True ]; then 215 | # sleep 10 216 | # FIRST=False 217 | # fi 218 | # done 219 | # done 220 | -------------------------------------------------------------------------------- /scripts/gen-cert/job_eval_llm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | DATASET_LIST="cwq grail_qa" # 3 | PROB_THRES=0.6 4 | ENT_THRES=0.5 5 | 6 | # gpt-3.5-turbo 7 | for DATASET in $DATASET_LIST; do 8 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-hint-$DATASET-temp-0.7-p-0.9-consistency-1 9 | MODE=fewshot-cot-hint 10 | IS_SC=1 11 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/gpt-3.5-turbo/$JOB_NAME/llm_prompt_response.jsonl 12 | sbatch --job-name $JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 13 | 14 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-hint-ground-$DATASET-temp-0.7-p-0.9-consistency-1 15 | MODE=fewshot-cot-hint-ground 16 | IS_SC=1 17 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/gpt-3.5-turbo/$JOB_NAME/llm_prompt_response.jsonl 18 | sbatch --job-name $JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 19 | 20 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-only-$DATASET-temp-0.0-p-1.0-consistency-1 21 | MODE=fewshot-cot-only 22 | IS_SC=1 23 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/gpt-3.5-turbo/$JOB_NAME/llm_prompt_response.jsonl 24 | sbatch --job-name $JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 25 | 26 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-only-$DATASET-temp-0.3-p-1.0-consistency-1 27 | MODE=fewshot-cot-only 28 | IS_SC=1 29 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/gpt-3.5-turbo/$JOB_NAME/llm_prompt_response.jsonl 30 | sbatch --job-name $JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 31 | 32 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-only-$DATASET-temp-0.5-p-1.0-consistency-1 33 | MODE=fewshot-cot-only 34 | IS_SC=1 35 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/gpt-3.5-turbo/$JOB_NAME/llm_prompt_response.jsonl 36 | sbatch --job-name $JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 37 | 38 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-only-$DATASET-temp-0.7-p-0.9-consistency-1 39 | MODE=fewshot-cot-only 40 | IS_SC=1 41 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/gpt-3.5-turbo/$JOB_NAME/llm_prompt_response.jsonl 42 | sbatch --job-name $JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 43 | 44 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-only-$DATASET-temp-0.7-p-0.9-consistency-20 45 | MODE=fewshot-cot-only 46 | IS_SC=4 47 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/gpt-3.5-turbo/$JOB_NAME/llm_prompt_response.jsonl 48 | sbatch --job-name $JOB_NAME-$IS_SC generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 49 | 50 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-only-$DATASET-temp-0.7-p-1.0-consistency-1 51 | MODE=fewshot-cot-only 52 | IS_SC=1 53 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/gpt-3.5-turbo/$JOB_NAME/llm_prompt_response.jsonl 54 | sbatch --job-name 
$JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 55 | 56 | JOB_NAME=gpt-3.5-turbo-fewshot-cot-only-$DATASET-temp-1.0-p-0.95-consistency-1 57 | MODE=fewshot-cot-only 58 | IS_SC=1 59 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/gpt-3.5-turbo/$JOB_NAME/llm_prompt_response.jsonl 60 | sbatch --job-name $JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 61 | done 62 | 63 | 64 | # llama 65 | for DATASET in $DATASET_LIST; do 66 | for MODEL_SIZE in "70b"; do 67 | MODEL_NAME=Llama-2-$MODEL_SIZE-chat-hf 68 | 69 | JOB_NAME=cot-hint-temp-0.7-p-0.9 70 | MODE=fewshot-cot-hint 71 | IS_SC=1 72 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 73 | sbatch --job-name $MODEL_NAME-$JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 74 | 75 | JOB_NAME=cot-temp-0.7-p-0.9 76 | MODE=fewshot-cot-only 77 | IS_SC=1 78 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 79 | sbatch --job-name $MODEL_NAME-$JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 80 | 81 | JOB_NAME=cot-temp-0.7-p-0.9 82 | MODE=fewshot-cot-only 83 | IS_SC=4 84 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 85 | sbatch --job-name $MODEL_NAME-$JOB_NAME-$IS_SC generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 86 | done 87 | done 88 | 89 | 90 | # qwen 91 | for DATASET in $DATASET_LIST; do 92 | for MODEL_SIZE in "7B" "14B"; do 93 | MODEL_NAME=Qwen-$MODEL_SIZE-Chat 94 | 95 | JOB_NAME=cot-hint-temp-0.7-p-0.9 96 | MODE=fewshot-cot-hint 97 | IS_SC=1 98 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 99 | sbatch --job-name $MODEL_NAME-$JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 100 | 101 | JOB_NAME=cot-temp-0.7-p-0.9 102 | MODE=fewshot-cot-only 103 | IS_SC=1 104 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 105 | sbatch --job-name $MODEL_NAME-$JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 106 | 107 | JOB_NAME=cot-temp-0.7-p-0.9 108 | MODE=fewshot-cot-only 109 | IS_SC=4 110 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 111 | sbatch --job-name $MODEL_NAME-$JOB_NAME-$IS_SC generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 112 | done 113 | done 114 | 115 | 116 | # mistral 117 | for DATASET in $DATASET_LIST; do 118 | MODEL_NAME=Mistral-7B-Instruct-v0.1 119 | 120 | JOB_NAME=cot-hint-temp-0.7-p-0.9 121 | MODE=fewshot-cot-hint 122 | IS_SC=1 123 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 124 | sbatch --job-name $MODEL_NAME-$JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 125 | 126 | JOB_NAME=cot-temp-0.7-p-0.9 127 | MODE=fewshot-cot-only 128 | IS_SC=1 129 | 
DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 130 | sbatch --job-name $MODEL_NAME-$JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 131 | 132 | JOB_NAME=cot-temp-0.7-p-0.9 133 | MODE=fewshot-cot-only 134 | IS_SC=4 135 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 136 | sbatch --job-name $MODEL_NAME-$JOB_NAME-$IS_SC generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 137 | done 138 | 139 | 140 | # vicuna 141 | for DATASET in $DATASET_LIST; do 142 | MODEL_NAME=vicuna-33b-v1.3 143 | 144 | JOB_NAME=cot-hint-temp-0.7-p-0.9 145 | MODE=fewshot-cot-hint 146 | IS_SC=1 147 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 148 | sbatch --job-name $MODEL_NAME-$JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 149 | 150 | JOB_NAME=cot-temp-0.7-p-0.9 151 | MODE=fewshot-cot-only 152 | IS_SC=1 153 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 154 | sbatch --job-name $MODEL_NAME-$JOB_NAME generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 155 | 156 | JOB_NAME=cot-temp-0.7-p-0.9 157 | MODE=fewshot-cot-only 158 | IS_SC=4 159 | DATA_PATH=/home/xvuthith/da33_scratch/lluo/LLMReasoningCert/data/$DATASET/$MODEL_NAME/$JOB_NAME/llm_prompt_response.jsonl 160 | sbatch --job-name $MODEL_NAME-$JOB_NAME-$IS_SC generative_cert/scripts/eval_llm.sh $DATASET $MODE $DATA_PATH $PROB_THRES $ENT_THRES $IS_SC 161 | done 162 | -------------------------------------------------------------------------------- /generative-cert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pandas as pd 5 | from extract_triplet.extract_triplet import ( 6 | load_vdb, 7 | search_triples_prompt, 8 | load_embed_model, 9 | ) 10 | from certify_fact import certify_prompt_triples 11 | from generative_cert.utils.utils import ( 12 | is_correct_llm_ans_format, 13 | is_lack_of_knowledge, 14 | ans_by_sc, 15 | get_final_answer, 16 | preprocess_ans, 17 | ) 18 | 19 | 20 | def retrieve_triplets( 21 | embed_model_name, device, top_k, db_path, reasoning_ans_dat, only_mean 22 | ): # , prob_thres=prob_thres): 23 | print("Load embedding model") 24 | embed_model = load_embed_model(embed_model_name, device) 25 | print("Starting loading VectorDB") 26 | db_index, triplets = load_vdb(db_path) 27 | print("Corpus contains {} triplets.".format(len(triplets))) 28 | 29 | reasoning_ans_dat[f"{top_k}_extracted_triplets"] = reasoning_ans_dat[ 30 | "reasoning_ans" 31 | ].progress_apply( 32 | lambda reasoning_ans: [ 33 | search_triples_prompt( 34 | embed_model, db_index, triplets, ans, top_k, only_mean=only_mean 35 | ) 36 | for ans in reasoning_ans 37 | ] 38 | ) 39 | return reasoning_ans_dat 40 | 41 | 42 | def eval(row, prob_thres, ent_thres): 43 | pred_ans = [get_final_answer(ans) for ans in row["reasoning_ans"]] 44 | ground_ans = row["groundtruth_answer"] 45 | extracted_triplets = row[f"{top_k}_extracted_triplets"] 46 | try: 47 | res = [] 48 | short_res = {} 49 | for pred, best_triples in zip(pred_ans, extracted_triplets): 50 | best_tri = [dic["best_triple"].values() for dic in best_triples] 51 | single_res 
= certify_prompt_triples( 52 | pred, ground_ans, best_tri, prob_thres, ent_thres 53 | ) 54 | res.append(single_res) 55 | if not short_res: 56 | short_res = single_res 57 | else: 58 | if ( 59 | single_res["p_incorrect_answer"] == 0 60 | and single_res["p_incorrect_reasoning"] == 0 61 | ): 62 | short_res = single_res 63 | else: 64 | if ( 65 | single_res["p_incorrect_answer"] == 0 66 | and short_res["p_incorrect_answer"] == 1 67 | ): 68 | short_res = single_res 69 | return res, short_res 70 | except: 71 | print(row) 72 | raise ValueError() 73 | 74 | 75 | def save_result_df(dat, raw_dat_path, dataset, is_sc): 76 | version = raw_dat_path.split("/")[-2] + f"-is_sc_{is_sc}" 77 | model_name = raw_dat_path.split("/")[-3] 78 | out_dir = "LLMReasoningCert/LLMReasonCert/tmp/revision/results/" 79 | out_dir = os.path.join(out_dir, dataset, model_name, version) 80 | if not os.path.exists(out_dir): 81 | os.makedirs(out_dir, exist_ok=True) 82 | out_path = os.path.join(out_dir, "full.jsonl") 83 | dat.to_json(out_path, orient="records", lines=True) 84 | return out_dir 85 | 86 | 87 | def calc_and_save_result_metrics( 88 | out_dir, dat, raw_dat_len, lack_knowledge_rate, incorrect_instruction_rate 89 | ): 90 | ans_acc = round(1 - sum(dat["p_incorrect_answer"]) / len(dat), 4) * 100 91 | reasoning_cert_ratio = ( 92 | round(1 - sum(dat["p_incorrect_reasoning"]) / len(dat), 4) * 100 93 | ) 94 | ## fine-grain uncert reasoning causes: fact or relation 95 | reasoning_uncert_by_fact_ratio = ( 96 | round(sum(dat["p_factual_error"]) / sum(dat["p_incorrect_reasoning"]), 4) * 100 97 | ) 98 | reasoning_uncert_by_relation_ratio = 100 - reasoning_uncert_by_fact_ratio 99 | 100 | # write 101 | with open(os.path.join(out_dir, "metric.txt"), "w") as fout: 102 | fout.write(f"lack_knowledge_rate {lack_knowledge_rate}%" + "\n") 103 | fout.write(f"incorrect_instruction_rate {incorrect_instruction_rate}%" + "\n") 104 | fout.write(f"answer_accuracy {ans_acc}%" + "\n") 105 | fout.write(f"reasoning_certification_ratio {reasoning_cert_ratio}%" + "\n") 106 | fout.write( 107 | f"reasoning_uncert_by_fact_ratio {reasoning_uncert_by_fact_ratio}%" + "\n" 108 | ) 109 | fout.write( 110 | f"reasoning_uncert_by_relation_ratio {reasoning_uncert_by_relation_ratio}%" 111 | + "\n" 112 | ) 113 | 114 | 115 | def load_data(raw_dat_path, dataset, is_sc): 116 | # load llm generation data 117 | with open(raw_dat_path) as f: 118 | dat = [json.loads(l) for l in f] 119 | dat = pd.DataFrame(dat) 120 | dat["raw_reasoning_ans"] = dat["reasoning_ans"] 121 | dat["reasoning_ans"] = dat["reasoning_ans"].map(lambda x: ans_by_sc(x, is_sc)) 122 | 123 | # if dataset is grail_qa, get ground_truth_paths 124 | if dataset == "grail_qa": 125 | splitted_ground_truth_paths_p = "LLMReasoningCert/data/grail_qa/gpt-3.5-turbo/test/splitted_ground_truth_paths.json" 126 | with open(splitted_ground_truth_paths_p) as f: 127 | origin_dat = json.load(f) 128 | origin_dat = origin_dat["min_2hop"] + origin_dat["min_multihop"] 129 | origin_dat = {v["id"]: v["ground_truth_paths"] for v in origin_dat} 130 | dat["ground_truth_paths"] = dat["id"].map(lambda id: origin_dat.get(id, None)) 131 | dat = dat.dropna().reset_index() 132 | 133 | print(dat.head(), "\n", len(dat)) 134 | return dat 135 | 136 | 137 | def cert_llm_reasoning( 138 | raw_dat_path, 139 | embed_model_name, 140 | device, 141 | top_k, 142 | dataset, 143 | mode, 144 | is_sc, 145 | prob_thres, 146 | ent_thres, 147 | ): 148 | # 1. 
load llm generation data 149 | print("Loading data") 150 | raw_dat = load_data(raw_dat_path, dataset, is_sc) 151 | raw_dat_len = len(raw_dat) 152 | 153 | # eval1. check incorrect_instruction & lack of knowledge 154 | ## lack of knowledge 155 | print("Checking abstention") 156 | dat1 = raw_dat[ 157 | raw_dat["reasoning_ans"].map(is_lack_of_knowledge) == False 158 | ].reset_index(drop=True) 159 | lack_knowledge_rate = round(1 - len(dat1) / raw_dat_len, 4) * 100 160 | print(dat1.head(), "\nLen:", len(dat1)) 161 | ## incorrect_instruction 162 | print("Checking incorrect_instruction") 163 | len_dat = len(dat1) 164 | dat1["reasoning_ans"] = dat1["reasoning_ans"].map( 165 | lambda ans: [preprocess_ans(x) for x in ans] 166 | if isinstance(ans, list) 167 | else preprocess_ans(ans) 168 | ) 169 | dat = dat1[dat1["reasoning_ans"].map(is_correct_llm_ans_format) == 1].reset_index( 170 | drop=True 171 | ) 172 | incorrect_instruction_rate = round(1 - len(dat) / len_dat, 4) * 100 173 | print(dat.head(), "\nLen:", len(dat)) 174 | 175 | # 2. retrieve triplets 176 | print("Retrieving triplets") 177 | db_path = f"LLMReasoningCert/data/db_extract/{dataset}/only_test_set" 178 | dat = retrieve_triplets( 179 | embed_model_name, device, top_k, db_path, dat, only_mean=True 180 | ) 181 | 182 | # 3. evaluate reasoning 183 | # a. find GroundTruth Answer 184 | dat["groundtruth_answer"] = dat["ground_truth_paths"].map( 185 | lambda v: list({i[-1][-1] for i in v if i}) 186 | ) 187 | # b. eval2. eval fact, reasoning and add them to origin data 188 | print("Evaluating fact & reasoning") 189 | dat[["raw_eval", "short_eval"]] = dat.progress_apply( 190 | lambda r: eval(r, prob_thres, ent_thres), axis=1, result_type="expand" 191 | ) 192 | dat = pd.concat([dat, pd.DataFrame(dat["short_eval"].tolist())], axis=1) 193 | 194 | # 4. save 195 | # a. save data to double check and example for writing paper 196 | print("Saving full data") 197 | out_dir = save_result_df(dat, raw_dat_path, dataset, is_sc) 198 | # b. 
calculate and save metrics 199 | print("Saving metrics at", out_dir) 200 | calc_and_save_result_metrics( 201 | out_dir, dat, raw_dat_len, lack_knowledge_rate, incorrect_instruction_rate 202 | ) 203 | print("Done!") 204 | 205 | 206 | if __name__ == "__main__": 207 | argparser = argparse.ArgumentParser() 208 | argparser.add_argument( 209 | "--dataset", type=str, default="cwq", choices=["cwq", "grail_qa"] 210 | ) 211 | argparser.add_argument( 212 | "--mode", 213 | type=str, 214 | choices=["fewshot-cot-only", "fewshot-cot-hint", "fewshot-cot-hint-ground"], 215 | ) 216 | argparser.add_argument("--prob_thres", type=float, default=0.7) 217 | argparser.add_argument("--ent_thres", type=float, default=0.5) 218 | argparser.add_argument("--is_sc", type=int, default=1) 219 | argparser.add_argument( 220 | "--raw_dat_path", 221 | type=str, 222 | required=True, 223 | help="LLMReasoningCert/data/cwq/gpt-3.5-turbo/gpt-3.5-turbo-fewshot-cot-only-cwq-temp-0.7-p-0.9-consistency-1/llm_prompt_response.jsonl", 224 | ) 225 | args = argparser.parse_args() 226 | 227 | top_k = 10 228 | embed_model_name = "sentence-transformers/all-mpnet-base-v2" 229 | device = "cpu" 230 | prob_thres, ent_thres = args.prob_thres, args.ent_thres 231 | 232 | cert_llm_reasoning( 233 | args.raw_dat_path, 234 | embed_model_name, 235 | device, 236 | top_k, 237 | args.dataset, 238 | args.mode, 239 | args.is_sc, 240 | prob_thres, 241 | ent_thres, 242 | ) 243 | -------------------------------------------------------------------------------- /discriminative-cert.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pandas as pd 3 | import os 4 | import argparse 5 | from tqdm import tqdm 6 | from multiprocessing.dummy import Pool as ThreadPool 7 | import json 8 | import discriminative_cert.utils as utils 9 | from functools import partial 10 | from discriminative_cert.discriminative_prompts import * 11 | from llms import get_registed_model 12 | import wandb 13 | import distutils.util 14 | 15 | 16 | def prompt_builder(question, answer, path, mode): 17 | if mode == "zero-shot": 18 | path_string = utils.path_to_string(path) 19 | query = ZERO_PROMPT.format(question=question, answer=answer, path=path_string) 20 | elif mode == "zero-shot-cot": 21 | path_string = utils.path_to_string(path) 22 | query = ZERO_COT_PROMPT.format( 23 | question=question, answer=answer, path=path_string 24 | ) 25 | elif mode == "few-shot": 26 | path_string = utils.path_to_string(path) 27 | query = FEWSHOT_PROMPT.format( 28 | question=question, answer=answer, path=path_string 29 | ) 30 | elif mode == "few-shot-cot": 31 | path_string = utils.path_to_string(path) 32 | query = FEWSHOT_COT_PROMPT.format( 33 | question=question, answer=answer, path=path_string 34 | ) 35 | # elif mode == "neg-few-shot": 36 | # path_string = utils.path_to_string(path) 37 | # query = NEG_FEWSHOT_PROMPT.format(question=question, path=path_string) 38 | # elif mode == "neg-few-shot-cot": 39 | # path_string = utils.path_to_string(path) 40 | # query = NEG_FEWSHOT_COT_PROMPT.format(question=question, path=path_string) 41 | # elif mode == "neg-reorder-zero-shot": 42 | # path_string = utils.reoder_path_to_string(path) 43 | # query = ZERO_PROMPT.format(question=question, path=path_string) 44 | # elif mode == "neg-reorder-few-shot": 45 | # path_string = utils.reoder_path_to_string(path) 46 | # query = NEG_REORDER_FEWSHOT_PROMPT.format(question=question, path=path_string) 47 | else: 48 | raise NotImplementedError(f"Mode {mode} is not implemented") 49 | 
return query 50 | 51 | 52 | def get_output_file(path, force=False): 53 | if not os.path.exists(path) or force: 54 | fout = open(path, "w") 55 | return fout, {} 56 | else: 57 | fout = open(path, "r+") 58 | processed_results = {} 59 | for line in fout: 60 | try: 61 | results = json.loads(line) 62 | processed_results[results["id"]] = results["acc"] 63 | except: 64 | print("Error in parsing line: ", line) 65 | fout.seek(-len(line), 1) 66 | break 67 | return fout, processed_results 68 | 69 | 70 | def predict(data, args, processed_list, model): 71 | data_id, row = data 72 | if data_id in processed_list: 73 | return None 74 | question = row["question"] 75 | answer = row["ground_answer"] 76 | answer_string = " ".join(eval(answer)) 77 | ground_truth_paths = row["ground_reasoning_path"] 78 | result_list = [] 79 | for p in ground_truth_paths: 80 | query = prompt_builder(question, answer_string, p, mode=args.mode) 81 | query = model.prepare_model_prompt(query) 82 | response = model.generate_sentence(query) 83 | if response is None: 84 | continue 85 | prediction = 0 86 | if args.eval_neg: 87 | if "NO" in response.upper() and "YES" not in response.upper(): 88 | prediction = 1 89 | else: 90 | if "YES" in response.upper() and "NO" not in response.upper(): 91 | prediction = 1 92 | result_list.append( 93 | { 94 | "path": p, 95 | "prediction": prediction, 96 | "raw_response": response, 97 | "raw_input": query, 98 | } 99 | ) 100 | avg_result = float(sum([r["prediction"] for r in result_list])) / len(result_list) 101 | result = { 102 | "id": data_id, 103 | "ground_answer": answer, 104 | "question": question, 105 | "acc": avg_result, 106 | "details": result_list, 107 | } 108 | return result 109 | 110 | 111 | def main(args, LLM): 112 | df = pd.read_csv(args.data_path) 113 | df.rename(columns={" ": "question"}, inplace=True) 114 | if args.eval_neg: 115 | df["ground_reasoning_path"] = df["negative_paths"] 116 | 117 | df["ground_reasoning_path"] = df["ground_reasoning_path"].apply(lambda x: eval(x)) 118 | # print(df.columns) 119 | 120 | input_file_name = os.path.basename(args.data_path) 121 | output_dir = os.path.join(args.output_path, input_file_name, args.model_name) 122 | 123 | while not os.path.exists(output_dir): 124 | try: 125 | os.makedirs(output_dir) 126 | except Exception as e: 127 | print(e) 128 | time.sleep(10) 129 | pass 130 | 131 | output_file_name = f"predictions_{args.mode}{args.postfix}.jsonl" 132 | output_name = os.path.join(output_dir, output_file_name) 133 | 134 | fout, processed_list = get_output_file(output_name, force=args.force) 135 | 136 | result_list = [value for value in processed_list.values()] 137 | 138 | model = LLM(args) 139 | print("Prepare pipline for inference...") 140 | model.prepare_for_inference() 141 | 142 | if args.wandb: 143 | wandb.init( 144 | config=args, 145 | project="discriminative-cert", 146 | name=f"{input_file_name}_{args.model_name}_{args.mode}{args.postfix}", 147 | ) 148 | 149 | if args.n > 1: 150 | with ThreadPool(args.n) as p: 151 | with tqdm( 152 | p.imap_unordered( 153 | partial( 154 | predict, args=args, processed_list=processed_list, model=model 155 | ), 156 | df.iterrows(), 157 | ), 158 | total=len(df), 159 | ) as phar: 160 | for r in phar: 161 | if r is None: 162 | continue 163 | fout.write(json.dumps(r) + "\n") 164 | result_list.append(r["acc"]) 165 | if args.debug: 166 | for r in r["details"]: 167 | print(f"Input: {r['raw_input']}") 168 | print(f"Response: {r['raw_response']}") 169 | print(f"Prediction: {r['prediction']}") 170 | phar.set_postfix( 171 | 
{"ACC": float(sum(result_list)) / len(result_list)} 172 | ) 173 | else: 174 | with tqdm(df.iterrows(), total=len(df)) as phar: 175 | for data in phar: 176 | r = predict(data, args, processed_list, model) 177 | if r is None: 178 | continue 179 | fout.write(json.dumps(r) + "\n") 180 | result_list.append(r["acc"]) 181 | if args.debug: 182 | for r in r["details"]: 183 | print(f"Input: {r['raw_input']}") 184 | print(f"Response: {r['raw_response']}") 185 | print(f"Prediction: {r['prediction']}") 186 | phar.set_postfix({"ACC": float(sum(result_list)) / len(result_list)}) 187 | fout.close() 188 | print("Accuracy: ", float(sum(result_list)) / len(result_list)) 189 | if args.wandb: 190 | wandb.log({"acc": float(sum(result_list)) / len(result_list)}) 191 | with open( 192 | os.path.join(output_dir, f"results_{args.mode}{args.postfix}.txt"), "w" 193 | ) as fout: 194 | fout.write(f"Accuracy: {float(sum(result_list))/len(result_list)}\n") 195 | 196 | 197 | if __name__ == "__main__": 198 | parser = argparse.ArgumentParser() 199 | parser.add_argument("--data_path", default="data/cwq_test_res.csv", type=str) 200 | parser.add_argument("--output_path", default="new_results_w_ans") 201 | parser.add_argument("--postfix", default="", type=str) 202 | parser.add_argument("--n", default=1, type=int, help="number of processes") 203 | parser.add_argument( 204 | "--model_name", "-m", type=str, help="model name", default="gpt-3.5-turbo" 205 | ) 206 | parser.add_argument( 207 | "--mode", 208 | default="zero-shot", 209 | choices=[ 210 | "zero-shot", 211 | "zero-shot-cot", 212 | "few-shot", 213 | "few-shot-cot", 214 | "neg-few-shot", 215 | "neg-few-shot-cot", 216 | "neg-reorder-zero-shot", 217 | "neg-reorder-few-shot", 218 | "neg-reorder-few-shot-cot", 219 | ], 220 | ) 221 | parser.add_argument("--debug", action="store_true") 222 | parser.add_argument("--force", action="store_true") 223 | parser.add_argument( 224 | "--wandb", 225 | default=False, 226 | type=lambda x: bool(distutils.util.strtobool(x)), 227 | help="enable wandb", 228 | ) 229 | parser.add_argument( 230 | "--eval_neg", 231 | default=False, 232 | type=lambda x: bool(distutils.util.strtobool(x)), 233 | help="enable wandb", 234 | ) 235 | args, _ = parser.parse_known_args() 236 | LLM = get_registed_model(args.model_name) 237 | LLM.add_args(parser) 238 | args = parser.parse_args() 239 | 240 | main(args, LLM) 241 | -------------------------------------------------------------------------------- /preprocess_data/sparql.py: -------------------------------------------------------------------------------- 1 | from SPARQLWrapper import SPARQLWrapper, JSON 2 | import re 3 | import json 4 | from tqdm import tqdm 5 | from generative_cert.utils.utils import processed_groundtruth_path 6 | 7 | 8 | # https://github.com/lanyunshi/Multi-hopComplexKBQA/blob/master/code/SPARQL_test.py 9 | # https://juejin.cn/post/7283690681175113740 10 | # python virtuoso.py start 3001 -d virtuoso_db 11 | # python virtuoso.py stop 3001 12 | class SparQL(object): 13 | def __init__(self, SPARQLPATH): 14 | self.SPARQLPATH = SPARQLPATH 15 | self.sparql = SPARQLWrapper( 16 | SPARQLPATH 17 | ) # PREFIX xsd: 18 | self.sparql.setReturnFormat(JSON) 19 | 20 | def query(self, sparql_txt, variable="value"): 21 | self.sparql.setQuery(sparql_txt) 22 | results = self.sparql.query().convert() 23 | id_vals = [ 24 | str(v[variable]["value"]).split("/ns/")[-1] 25 | for v in results["results"]["bindings"] 26 | ] 27 | name_vals = [self.SQL_entity2name(id) for id in id_vals] 28 | return id_vals, name_vals 29 | 30 | def 
query_object(self, subject, rel): 31 | sparql_txt = ( 32 | """PREFIX ns:\nSELECT ?t WHERE {FILTER (!isLiteral(?t) OR lang(?t) = '' OR langMatches(lang(?t), 'en'))\nns:%s ns:%s ?t}""" 33 | % (subject, rel) 34 | ) 35 | self.sparql.setQuery(sparql_txt) 36 | results = self.sparql.query().convert() 37 | try: 38 | obj_id = results["results"]["bindings"][0]["t"]["value"].split("/ns/")[-1] 39 | except: 40 | return None, None 41 | obj_name = self.SQL_entity2name(obj_id) 42 | sub_name = self.SQL_entity2name(subject) 43 | return [sub_name, rel, obj_name], obj_id 44 | 45 | def query_reasoning_path(self, sparql_txt): 46 | if "PREFIX rdf" in sparql_txt: # grail_qa 47 | # sparql_txt = sparql_txt.replace('PREFIX rdf: PREFIX rdfs: ','') 48 | sparql_txt = re.sub("SELECT .+ WHERE", "CONSTRUCT WHERE", sparql_txt) 49 | sparql_txt = sparql_txt.replace("} }", "}").replace("} }", "}") 50 | else: # cwq, webqsp 51 | sparql_txt = re.sub( 52 | "SELECT DISTINCT .+\\nWHERE", "CONSTRUCT WHERE", sparql_txt 53 | ) 54 | sparql_txt = sparql_txt.replace("#MANUAL SPARQL", "") 55 | # print('sparql_txt', sparql_txt) 56 | self.sparql.setQuery(sparql_txt) 57 | results = self.sparql.query().convert() 58 | # print(results) 59 | path = [] 60 | for e in results["results"]["bindings"]: 61 | try: 62 | l = [str(v["value"]).split("/ns/")[-1] for v in e.values()] 63 | if l[1] == "type.object.type": 64 | continue 65 | dic = [self.SQL_entity2name(id) for id in l] 66 | path.append(dic) 67 | except: 68 | print([v["value"] for v in e.values()]) 69 | raise ValueError() 70 | # print(dic) 71 | processed_path = processed_groundtruth_path(path) 72 | # for tri in processed_path: 73 | # print(tri) 74 | return path, processed_path 75 | 76 | def get_ents(self, sparql_txt): 77 | # retrieve all entities 78 | sparql_txt = sparql_txt.replace("DISTINCT ?x", "DISTINCT *") 79 | self.sparql.setQuery(sparql_txt) 80 | results = self.sparql.query().convert() 81 | # print(results) 82 | ent_l = [] 83 | for e in results["results"]["bindings"]: 84 | dic = {k: v["value"].split("/ns/")[-1] for k, v in e.items()} 85 | dic = {code: self.SQL_entity2name(id) for code, id in dic.items()} 86 | ent_l.append(dic) 87 | return ent_l 88 | 89 | def SQL_query2path(self, sparql_txt): 90 | code2id_ents = self.get_ents(sparql_txt) 91 | # print(code2id_ents) 92 | # print() 93 | query_triplets = self.extract_triples_from_query(sparql_txt) 94 | # print(query_triplets) 95 | sparql_paths = [] 96 | for dic in code2id_ents: 97 | path = [ 98 | [dic.get(t[0], t[0]), t[1], dic.get(t[2], t[2])] for t in query_triplets 99 | ] 100 | sparql_paths.append(path) 101 | return sparql_paths 102 | 103 | def triplet_reverse(self, tri): 104 | head, rel, tail = tri 105 | head_condition = re.search("^[mg]\.", head) or re.search("^\d+", head) 106 | tail_condition = re.search("^[mg]\.", tail) or re.search("^\d+", tail) 107 | if head_condition and tail_condition: 108 | return tri 109 | if not head_condition: 110 | condition = "?t ns:%s ns:%s" % (rel, tail) 111 | else: 112 | condition = "ns:%s ns:%s ?t" % (head, rel) 113 | # print(condition) 114 | self.sparql.setQuery( 115 | """PREFIX ns:\nSELECT ?t WHERE {FILTER (!isLiteral(?t) OR lang(?t) = '' OR langMatches(lang(?t), 'en'))\n%s}""" 116 | % condition 117 | ) 118 | results = self.sparql.query().convert() 119 | # print(results) 120 | name = results["results"]["bindings"][0]["t"]["value"].split("/ns/")[-1] 121 | if not head_condition: 122 | return name, rel, tail 123 | else: 124 | return head, rel, name 125 | 126 | def SQL_entity2name(self, e): 127 | if not 
re.search("^[mg]\.", e): 128 | return e 129 | self.sparql.setQuery( 130 | """PREFIX ns:\nSELECT ?t WHERE {FILTER (!isLiteral(?t) OR lang(?t) = '' OR langMatches(lang(?t), 'en'))\nns:%s ns:type.object.name ?t.}""" 131 | % (e) 132 | ) 133 | try: 134 | results = self.sparql.query().convert() 135 | # print(results) 136 | name = ( 137 | results["results"]["bindings"][0]["t"]["value"] 138 | if results["results"]["bindings"] 139 | else e 140 | ) 141 | except: 142 | name = e 143 | return name 144 | 145 | def extract_triples_from_query(self, query_str): 146 | pattern = r"(ns:|\?).+ ns:.+ (ns:|\?).+ " 147 | query_l = query_str.split("\n") 148 | triplets = [re.search(pattern, q) for q in query_l] 149 | triplets = [tri.group().strip() for tri in triplets if tri] 150 | triplets = [tri.replace("ns:", "").replace("?", "") for tri in triplets] 151 | triplets = [tri.split(" ") for tri in triplets] 152 | # find id2name in case: not code 153 | triplets = [ 154 | [self.SQL_entity2name(t[0]), t[1], self.SQL_entity2name(t[2])] 155 | for t in triplets 156 | ] 157 | return triplets 158 | 159 | 160 | if __name__ == "__main__": 161 | SPARQLPATH = "http://localhost:3001/sparql" 162 | sparql = SparQL(SPARQLPATH) 163 | 164 | sparql_txt = 'PREFIX ns: \nSELECT DISTINCT ?x\nWHERE {\nFILTER (?x != ns:m.05kkh)\nFILTER (!isLiteral(?x) OR lang(?x) = \'\' OR langMatches(lang(?x), \'en\'))\nns:m.05kkh ns:government.governmental_jurisdiction.governing_officials ?y .\n?y ns:government.government_position_held.office_holder ?x .\n?y ns:government.government_position_held.basic_title ns:m.0fkvn .\nFILTER(NOT EXISTS {?y ns:government.government_position_held.from ?sk0} || \nEXISTS {?y ns:government.government_position_held.from ?sk1 . \nFILTER(xsd:datetime(?sk1) <= "2011-12-31"^^xsd:dateTime) })\nFILTER(NOT EXISTS {?y ns:government.government_position_held.to ?sk2} || \nEXISTS {?y ns:government.government_position_held.to ?sk3 . \nFILTER(xsd:datetime(?sk3) >= "2011-01-01"^^xsd:dateTime) })\n?x ns:government.politician.government_positions_held ?c .\n?c ns:government.government_position_held.from ?num .\nFILTER (?num < "1983-01-03"^^xsd:dateTime) . \n}' 165 | print(sparql.SQL_query2path(sparql_txt)) 166 | print( 167 | sparql.triplet_reverse( 168 | ["m.0340r0", "government.politician.government_positions_held", "sk2"] 169 | ) 170 | ) #'m.0mth_2g','government.government_position_held.to' 171 | print(sparql.SQL_entity2name("m.0bfmhy4")) 172 | print(sparql.query_reasoning_path(sparql_txt)) 173 | # print() 174 | sparql_txt = "#MANUAL SPARQL\nPREFIX ns: \nSELECT DISTINCT ?x\nWHERE {\n\tFILTER (?x != ns:m.0d05w3)\n\t{ \n\t ns:m.0d05w3 ns:location.statistical_region.places_exported_to ?y .\n\t ?y ns:location.imports_and_exports.exported_to ?x . \n\t}\n\tUNION\n\t{\n\t ns:m.0d05w3 ns:location.statistical_region.places_imported_from ?y .\n\t ?y ns:location.imports_and_exports.imported_from ?x . \n\t}?x ns:location.location.time_zones ns:m.03bdv . \n}" 175 | paths = sparql.query_reasoning_path(sparql_txt)[1] 176 | print(paths, len(paths)) 177 | 178 | sparql_txt = "PREFIX rdf: PREFIX rdfs: PREFIX : SELECT (?x0 AS ?value) WHERE { SELECT DISTINCT ?x0 WHERE { ?x0 :type.object.type :language.language_writing_type . ?x1 :type.object.type :language.language_writing_system . {SELECT (MAX(?y2) AS ?x2) WHERE { ?x0 :language.language_writing_type.writing_systems ?x1 . ?x1 :language.language_writing_system.used_from ?x2 . 
FILTER ( ?x0 != ?x1 && ?x0 != ?x2 && ?x1 != ?x2 ) } }" 179 | # sparql_txt = re.sub(r' FILTER \( \?y.+ \?y\d *\)','', sparql_txt) 180 | paths = sparql.query_reasoning_path(sparql_txt)[1] 181 | print(paths, len(paths)) 182 | 183 | print(sparql.query(sparql_txt, variable="value")) 184 | print( 185 | sparql.query_object( 186 | "m.02zb0l", "user.patrick.default_domain.warship_v1_1.commissioned" 187 | ) 188 | ) 189 | -------------------------------------------------------------------------------- /discriminative_cert/discriminative_prompts.py: -------------------------------------------------------------------------------- 1 | ZERO_PROMPT = """A reasoning path is a sequence of triples that can be used to derive the answer of given question. A valid reasoning path should follow these rules: 2 | 1. No factual errors: Each triple in the reasoning path should adhere to real-world factual knowledge. 3 | 2. Coherence: The tail entity of the previous triple should be the head entity of the next triple. 4 | 3. Correctness: The reasoning path should lead to the correct answer at the last tail entity. 5 | 6 | Given this reasoning path, do you think this is a valid path to derive the answer of given question? If yes please answer "YES", otherwise please answer "NO". 7 | 8 | Question: 9 | {question} 10 | 11 | Answer: 12 | {answer} 13 | 14 | Reasoning path: 15 | {path} 16 | """ 17 | 18 | ZERO_COT_PROMPT = """A reasoning path is a sequence of triples that can be used to derive the answer of given question. A valid reasoning path should follow these rules: 19 | 1. No factual errors: Each triple in the reasoning path should adhere to real-world factual knowledge. 20 | 2. Coherence: The tail entity of the previous triple should be the head entity of the next triple. 21 | 3. Correctness: The reasoning path should lead to the correct answer at the last tail entity. 22 | 23 | Given this reasoning path, do you think this is a valid path to derive the answer of given question? If yes please answer "YES", otherwise please answer "NO". Let's think it step by step. 24 | 25 | Question: 26 | {question} 27 | 28 | Answer: 29 | {answer} 30 | 31 | Reasoning path: 32 | {path} 33 | """ 34 | 35 | 36 | FEWSHOT_PROMPT = """A reasoning path is a sequence of triples that can be used to derive the answer of given question. A valid reasoning path should follow these rules: 37 | 1. No factual errors: Each triple in the reasoning path should adhere to real-world factual knowledge. 38 | 2. Coherence: The tail entity of the previous triple should be the head entity of the next triple. 39 | 3. Correctness: The reasoning path should lead to the correct answer at the last tail entity. 40 | 41 | Given this reasoning path, do you think this is a valid path to derive the answer of given question? If yes please answer "YES", otherwise please answer "NO". Here are some examples: 42 | 43 | ## Input: 44 | Question: 45 | What type of government is used in the country with Northern District? 46 | 47 | Answer: 48 | Parliamentary system 49 | 50 | Reasoning Paths: 51 | Step 1: Northern District -> location.administrative_division.first_level_division_of -> Israel 52 | Step 2: Israel -> government.form_of_government.countries -> Parliamentary system 53 | 54 | ## Output: 55 | YES 56 | 57 | ## Input: 58 | Question: 59 | Where is the home stadium of the team who won the 1946 World Series championship? 60 | 61 | Answer: 62 | Busch Stadium 63 | 64 | Reasoning Paths: 65 | Step 1: 1946 World Series -> sports.sports_team.championships -> St. 
Louis Cardinals 66 | Step 2: St. Louis Cardinals -> sports.sports_team.arena_stadium -> Roger Dean Stadium 67 | 68 | ## Output: 69 | NO 70 | 71 | ## Input: 72 | Question: 73 | In which American Southern City did the ""Downs"" composer die? 74 | 75 | Answer: 76 | New Orleans 77 | 78 | Reasoning Paths: 79 | Step 1: Alex Chilton -> people.deceased_person.place_of_death -> New Orleans 80 | Step 2: Downs -> music.composer.compositions -> Alex Chilton 81 | 82 | ## Output: 83 | NO 84 | 85 | ## Input: 86 | Question: 87 | Where was the main artist featured in the Rihanna: Live in Concert Tour raised? 88 | 89 | Answer: 90 | Saint Michael Parish 91 | 92 | Reasoning Paths: 93 | Step 1: Rihanna: Live in Concert Tour -> music.concert_tour.artist -> Rihanna 94 | Step 2: Rihanna -> music.track_contribution.role -> Vocals 95 | 96 | ## Output: 97 | NO 98 | 99 | ## Input: 100 | Question: 101 | {question} 102 | 103 | Answer: 104 | {answer} 105 | 106 | Reasoning path: 107 | {path} 108 | 109 | ## Output: 110 | """ 111 | 112 | FEWSHOT_COT_PROMPT = """A reasoning path is a sequence of triples that can be used to derive the answer of given question. A valid reasoning path should follow these rules: 113 | 1. No factual errors: Each triple in the reasoning path should adhere to real-world factual knowledge. 114 | 2. Coherence: The tail entity of the previous triple should be the head entity of the next triple. 115 | 3. Correctness: The reasoning path should lead to the correct answer at the last tail entity. 116 | 117 | Given this reasoning path, do you think this is a valid path to derive the answer of given question? If yes please answer "YES", otherwise please answer "NO". Here are some examples: 118 | 119 | ## Input: 120 | Question: 121 | What type of government is used in the country with Northern District? 122 | Answer: 123 | Parliamentary system 124 | 125 | Reasoning Paths: 126 | Step 1: Northern District -> location.administrative_division.first_level_division_of -> Israel 127 | Step 2: Israel -> government.form_of_government.countries -> Parliamentary system 128 | 129 | ## Output: 130 | This reasoning path indicates that: 131 | 1. "Northern District" is a location within some country. 132 | 2. The reasoning path mentions "Northern District -> location.administrative_division.first_level_division_of -> Israel," indicating that the Northern District is part of Israel. 133 | 3. It further states "Israel -> government.form_of_government.countries," suggesting that Israel's form of government is being discussed. 134 | 4. The last part of the reasoning path indicates that Israel has a "Parliamentary system." 135 | 136 | Based on the provided reasoning paths, it can be concluded that the type of government used in the country with the Northern District (Israel) is a Parliamentary system. Therefore, the path adheres to the definition of a valid reasoning path. The answer is "YES" 137 | 138 | ## Input: 139 | Question: 140 | Where is the home stadium of the team who won the 1946 World Series championship? 141 | 142 | Answer: 143 | Busch Stadium 144 | 145 | Reasoning Paths: 146 | Step 1: 1946 World Series -> sports.sports_team.championships -> St. Louis Cardinals 147 | Step 2: St. Louis Cardinals -> sports.sports_team.arena_stadium -> Roger Dean Stadium 148 | 149 | ## Output: 150 | This reasoning path indicates that: 151 | 1. St. Louis Cardinals is the team that won the 1946 World Series 152 | 2. Roger Dean Stadium is the stadium associated with the St. Louis Cardinals. 153 | 154 | However, the home stadium of St. 
Louis Cardinals is Busch Stadium, not Roger Dean Stadium. Therefore, this path contains factual errors. The answer is "NO". 155 | 156 | ## Input: 157 | Question: 158 | In which American Southern City did the ""Downs"" composer die? 159 | 160 | Answer: 161 | New Orleans 162 | 163 | Reasoning Paths: 164 | Step 1: Alex Chilton -> people.deceased_person.place_of_death -> New Orleans 165 | Step 2: Downs -> music.composer.compositions -> Alex Chilton 166 | 167 | ## Output: 168 | This reasoning path indicates that: 169 | 1. Alex Chilton died in New Orleans 170 | 2. The composition of Downs is Alex Chilton 171 | 172 | Even though the first step of the reasoning path leads to the correct answer, which is New Orleans, the reasoning path is not coherent. To answer the question, we need to first find the composition of Downs which is Alex Chilton, then we find the death place of Alex Chilton which is New Orleans. Therefore, the answer is "NO". 173 | 174 | ## Input: 175 | Question: 176 | Where was the main artist featured in the Rihanna: Live in Concert Tour raised? 177 | 178 | Answer: 179 | Saint Michael Parish 180 | 181 | Reasoning Paths: 182 | Step 1: Rihanna: Live in Concert Tour -> music.concert_tour.artist -> Rihanna 183 | Step 2: Rihanna -> music.track_contribution.role -> Vocals 184 | 185 | ## Output: 186 | This reasoning path indicates that: 187 | 1. The artist of Rihanna: Live in Concert Tour is Rihanna 188 | 2. Rihanna is a vocal artist 189 | 190 | Even though there are no factual errors and the reasoning path is coherent, the reasoning path does not lead to the correct answer. The question asks for the birth place of the main artist, not the role of the artist. Therefore, the answer is "NO". 191 | 192 | ## Input: 193 | Question: 194 | {question} 195 | 196 | Answer: 197 | {answer} 198 | 199 | Reasoning path: 200 | {path} 201 | 202 | ## Output: 203 | """ 204 | 205 | NEG_FEWSHOT_COT_PROMPT = """ 206 | Given this reasoning path, do you think this is a valid path to answer the question? If yes please answer "YES", otherwise please answer "NO". Let's think it step by step. Here are some examples: 207 | 208 | ## Input: 209 | Question: 210 | What type of government is used in the country with Northern District? 211 | 212 | Reasoning Paths: 213 | Step 1: Northern District -> location.administrative_division.first_level_division_of -> Israel 214 | Step 2: Israel -> government.form_of_government.countries -> Parliamentary system 215 | 216 | ## Output: 217 | This reasoning path indicates that: 218 | 1. "Northern District" is a location within some country. 219 | 2. The reasoning path mentions "Northern District -> location.administrative_division.first_level_division_of -> Israel," indicating that the Northern District is part of Israel. 220 | 3. It further states "Israel -> government.form_of_government.countries," suggesting that Israel's form of government is being discussed. 221 | 4. The last part of the reasoning path indicates that Israel has a "Parliamentary system." 222 | 223 | Therefore, based on the provided reasoning paths, it can be concluded that the type of government used in the country with the Northern District (Israel) is a Parliamentary system. The answer is "YES" 224 | 225 | ## Input: 226 | Question: 227 | Where is the home stadium of the team who won the 1946 World Series championship? 228 | 229 | Reasoning Paths: 230 | 1946 World Series -> sports.sports_team.championships -> St. 
Louis Cardinals -> sports.sports_team.arena_stadium -> Roger Dean Stadium 231 | 232 | ## Output: 233 | This reasoning path indicates that: 234 | 1. St. Louis Cardinals is the team that won the 1946 World Series 235 | 2. Roger Dean Stadium is the stadium associated with the St. Louis Cardinals. 236 | 237 | However, the home stadium of St. Louis Cardinals is Busch Stadium, not Roger Dean Stadium. Therefore, the answer is "NO" 238 | 239 | ## Input: 240 | Question: 241 | In which American Southern City did the ""Downs"" composer die? 242 | 243 | Reasoning Paths: 244 | Downs -> music.composer.compositions -> Judæo-Persian dialects -> people.deceased_person.place_of_death -> New Orleans 245 | 246 | ## Output: 247 | This reasoning path indicates that: 248 | 1. The composer of "Downs" is Judæo-Persian dialects. 249 | 2. Judæo-Persian dialects died in New Orleans 250 | 251 | However, the composer of Downs is Alex Chilton, not Judæo-Persian dialects. Therefore, the answer is "NO" 252 | 253 | ## Input: 254 | Question: 255 | {question} 256 | 257 | Reasoning path: 258 | {path} 259 | 260 | ## Output: 261 | """ 262 | 263 | NEG_FEWSHOT_PROMPT = """ 264 | Given this reasoning path, do you think this is a valid path to answer the question? If yes please answer "YES", otherwise please answer "NO". Let's think it step by step. Here are some examples: 265 | 266 | ## Input: 267 | Question: 268 | What type of government is used in the country with Northern District? 269 | 270 | Reasoning Paths: 271 | Northern District -> location.administrative_division.first_level_division_of -> Israel -> government.form_of_government.countries -> Parliamentary system 272 | 273 | ## Output: 274 | YES 275 | 276 | ## Input: 277 | Question: 278 | Where is the home stadium of the team who won the 1946 World Series championship? 279 | 280 | Reasoning Paths: 281 | 1946 World Series -> sports.sports_team.championships -> St. Louis Cardinals -> sports.sports_team.arena_stadium -> Roger Dean Stadium 282 | 283 | ## Output: 284 | NO 285 | 286 | ## Input: 287 | Question: 288 | In which American Southern City did the ""Downs"" composer die? 289 | 290 | Reasoning Paths: 291 | Downs -> music.composer.compositions -> Judæo-Persian dialects -> people.deceased_person.place_of_death -> New Orleans 292 | 293 | ## Output: 294 | NO 295 | 296 | ## Input: 297 | Question: 298 | {question} 299 | 300 | Reasoning path: 301 | {path} 302 | 303 | ## Output: 304 | """ 305 | 306 | 307 | NEG_REORDER_FEWSHOT_PROMPT = """ 308 | Given this reasoning path, do you think this is a valid path to answer the question? If yes please answer "YES", otherwise please answer "NO". Let's think it step by step. Here are some examples: 309 | 310 | ## Input: 311 | Question: 312 | What type of government is used in the country with Northern District? 313 | 314 | Reasoning Paths: 315 | Step 1: Northern District -> location.administrative_division.first_level_division_of -> Israel 316 | Step 2: Israel -> government.form_of_government.countries -> Parliamentary system 317 | 318 | ## Output: 319 | YES 320 | 321 | ## Input: 322 | Question: 323 | Where is the home stadium of the team who won the 1946 World Series championship? 324 | 325 | Reasoning Paths: 326 | Step 1: St. Louis Cardinals -> sports.sports_team.arena_stadium -> Busch Stadium 327 | Step 2: 1946 World Series -> sports.sports_team.championships -> St. Louis Cardinals 328 | 329 | ## Output: 330 | NO 331 | 332 | ## Input: 333 | Question: 334 | In which American Southern City did the ""Downs"" composer die? 
335 | 336 | Reasoning Paths: 337 | Step 1: Alex Chilton -> people.deceased_person.place_of_death -> New Orleans 338 | Step 2: Downs -> music.composer.compositions -> Alex Chilton 339 | 340 | ## Output: 341 | NO 342 | 343 | ## Input: 344 | Question: 345 | {question} 346 | 347 | Reasoning path: 348 | {path} 349 | 350 | ## Output: 351 | """ 352 | -------------------------------------------------------------------------------- /certify_fact.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | from tqdm import tqdm 4 | import argparse 5 | import os 6 | import re 7 | from sklearn.metrics import classification_report 8 | import collections 9 | import networkx as nx 10 | from generative_cert.utils.utils import build_graph 11 | from evaluate_results import eval_hit 12 | 13 | tqdm.pandas() 14 | 15 | 16 | class GraphProcess(object): 17 | def __init__(self, path) -> None: 18 | self.G = build_graph(path) 19 | 20 | def print(self): 21 | print(nx.draw_shell(self.G)) 22 | 23 | def is_connected_graph(self): 24 | return nx.is_connected(self.G) 25 | 26 | def find_leaf_nodes(self): 27 | return [n for n in self.G.nodes() if self.G.degree[n] == 1] 28 | 29 | def test(self): 30 | check_path = [ 31 | ["Milwaukee Brewers", "sports.sports_team.team_mascot", "Hank"], 32 | [ 33 | "Miller Park", 34 | "sports.sports_facility.home_venue_for", 35 | "Milwaukee Brewers", 36 | ], 37 | ] 38 | print(self.is_connected_graph()) 39 | print(self.find_leaf_nodes()) 40 | 41 | check_path = [ 42 | ["Lauren Conrad", "people.person.profession", "Fashion designer"], 43 | ["Lauren Conrad", "people.place_lived.location", "Los Angeles"], 44 | [ 45 | "Willie Brown", 46 | "government.politician.government_positions_held", 47 | "San Francisco", 48 | ], 49 | ] 50 | print(self.is_connected_graph()) 51 | print(self.find_leaf_nodes()) 52 | 53 | check_path = [ 54 | ["Kiribati", "location.country.capital", "South Tarawa"], 55 | ["Kiribati", "location.location.containedby", "Pacific Ocean"], 56 | ["Japan", "location.statistical_region.places_exported_to", "Japan"], 57 | ] 58 | print(self.is_connected_graph()) 59 | print(self.find_leaf_nodes()) 60 | 61 | 62 | def find_best_triplet(topk_extracted_triplets, prob_thres, ent_thres): 63 | """ 64 | check if topk extracted triplets is valid 65 | Valid condition: exist at least 1 extracted triplet 66 | whose prob in sententence & entity all are not less than threshold 67 | """ 68 | twocheck_max = 0 69 | twocheck_res = None 70 | sumcheck_max = 0 71 | sumcheck_res = None 72 | for triplet_prob in topk_extracted_triplets: 73 | _, prob, head_score, tail_score = triplet_prob.values() 74 | head_score = head_score / 100 if isinstance(head_score, int) else head_score 75 | tail_score = tail_score / 100 if isinstance(tail_score, int) else tail_score 76 | score = (prob + head_score + tail_score) / 3 77 | # skip underthreshold triplet 78 | if not ( 79 | prob >= prob_thres and head_score >= ent_thres and tail_score >= ent_thres 80 | ): 81 | if score > sumcheck_max: 82 | sumcheck_res = triplet_prob 83 | sumcheck_max = score 84 | else: 85 | if score > twocheck_max: 86 | twocheck_res = triplet_prob 87 | twocheck_max = score 88 | return (twocheck_res, twocheck_max) if twocheck_max else (sumcheck_res, 0) 89 | 90 | 91 | def eval_reasoning_path(path): 92 | if not path: 93 | return False 94 | G = GraphProcess(path) 95 | return True if G.is_connected_graph() and len(G.G.nodes()) > 2 else False 96 | 97 | 98 | def certify_prompt_triples(pred_ans, ground_ans, 
best_triples, prob_thres, ent_thres): 99 | is_correct_ans = eval_hit(pred_ans, ground_ans) 100 | res = { 101 | "p_incorrect_answer": int(not is_correct_ans), 102 | "p_incorrect_reasoning": 0, 103 | "p_factual_error": 0, 104 | "p_coherent_error": 0, 105 | "p_reasoning_ans_error": 0, 106 | "p_steps_error": None, 107 | } 108 | steps_error = [] 109 | only_tris = [] 110 | order_error = False 111 | last_tail = None 112 | # check factual_error 113 | for i, best_triple in enumerate(best_triples): 114 | tri, prob, head_score, tail_score = best_triple 115 | if prob < prob_thres: 116 | res["p_factual_error"] = 1 117 | steps_error.append(i + 1) # due to index start from 0 118 | else: 119 | if not last_tail: 120 | last_tail = [tri[0], tri[2]] 121 | else: 122 | if tri[0] in last_tail: 123 | last_tail = tri[2] 124 | elif tri[2] in last_tail: 125 | last_tail = tri[0] 126 | else: 127 | order_error = True 128 | only_tris.append(tri) 129 | if not only_tris: 130 | res["p_factual_error"] = 1 131 | only_tris = [["", "", ""]] 132 | 133 | res["p_steps_error"] = steps_error 134 | 135 | # check coherent error 136 | res["p_coherent_error"] = (not res["p_factual_error"]) and order_error 137 | 138 | # check answer error 139 | res["p_reasoning_ans_error"] = ( 140 | (not res["p_factual_error"]) 141 | and (not order_error) 142 | and (only_tris[-1][-1] not in ground_ans) 143 | ) 144 | 145 | # check reasoning error 146 | if ( 147 | res["p_factual_error"] 148 | or res["p_coherent_error"] 149 | or res["p_reasoning_ans_error"] 150 | ): # old version: res['p_incorrect_answer'] or res['p_factual_error'] 151 | res["p_incorrect_reasoning"] = 1 152 | else: 153 | res["p_incorrect_reasoning"] = 0 # int(not eval_reasoning_path(only_tris)) 154 | return res 155 | 156 | 157 | def main(args): 158 | eval_path = os.path.join( 159 | args.performance_dir, 160 | f"{args.dataset}_{args.split}_evaluate_llm_prompting.jsonl", 161 | ) 162 | retrieval_path = os.path.join( 163 | args.retrieval_dir, 164 | args.dataset, 165 | args.model_name, 166 | args.split, 167 | "only_test_extract_triplet_skip_unknown_ent.jsonl", 168 | ) 169 | topk = args.top_k 170 | prob_thres = args.prob_thres / 100 171 | ent_thres = args.ent_thres / 100 172 | check_valid_list = [] # contain dic: {'id','is_valid','unvalid_step'} 173 | 174 | # check valid in extracted triplet file 175 | first_line = True 176 | with open(retrieval_path) as fin: 177 | for line in tqdm(fin): 178 | if first_line: 179 | first_line = False 180 | continue 181 | data = json.loads(line) 182 | topk_extracted_triplets = data[f"{topk}_extracted_triplets"] 183 | info = { 184 | "id": data["id"], 185 | "question": None, 186 | "prediction": None, # final answer 187 | "ground_answer": None, # groundtruth answer 188 | "pred_num_steps": None, # the number of reasoning steps in the answer 189 | "sent_anwser": None, # list of sentences in the answer, excluding the final answer 190 | "prediction_reasoning_path": None, 191 | "ground_reasoning_path": None, 192 | "graph": None, 193 | "raw_graph": None, 194 | "hit": None, 195 | "is_valid": True, # default is True 196 | "unvalid_step": -1, # -1 if valid, otherwise is the index of unvalid step 197 | } 198 | if not topk_extracted_triplets: 199 | info = {} 200 | check_valid_list.append(info) 201 | continue 202 | # loop for each step in answer 203 | sent_anwser = [] 204 | prediction_reasoning_path = [] 205 | for step, sent_k_tri in enumerate(topk_extracted_triplets): 206 | sent = sent_k_tri["sentence"] 207 | # skip no meaning sentence 208 | if "we need" in sent: 209 | continue 
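# NOTE: find_best_triplet returns the top-scoring candidate for this sentence, but
# the returned score is 0 whenever no candidate clears both the sentence-probability
# and head/tail entity-match thresholds, so the `score < prob_thres` check below
# marks such a step as uncertified.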
210 | triplets = sent_k_tri["triplets"] 211 | triplet, score = find_best_triplet(triplets, prob_thres, ent_thres) 212 | # check if extracted triplet is certified 213 | # certified condition: score>=prob_thres 214 | if score < prob_thres and info["is_valid"]: 215 | info["is_valid"] = False 216 | # if uninvalid, point out invalid step 217 | info["unvalid_step"] = step 218 | sent_anwser.append(sent) 219 | prediction_reasoning_path.append(triplet["triplet"]) 220 | info["sent_anwser"] = sent_anwser 221 | info["prediction_reasoning_path"] = prediction_reasoning_path 222 | info["pred_num_steps"] = len(sent_anwser) 223 | check_valid_list.append(info) 224 | 225 | last_line = len(check_valid_list) 226 | with open(eval_path) as fin: 227 | for i, line in enumerate(fin): 228 | if i == last_line: 229 | break 230 | data = json.loads(line) 231 | if ( 232 | check_valid_list[i] 233 | and data.get("prediction", None) 234 | and check_valid_list[i]["id"] == data["id"] 235 | ): # skip wrong format 236 | check_valid_list[i]["question"] = data["question"] 237 | check_valid_list[i]["prediction"] = data["prediction"] 238 | check_valid_list[i]["ground_answer"] = data["ground_truth"] 239 | check_valid_list[i]["hit"] = data["hit"] 240 | if check_valid_list[i]["is_valid"]: # only check if certifying fact 241 | if not eval_reasoning_path( 242 | check_valid_list[i]["prediction_reasoning_path"] 243 | ): 244 | check_valid_list[i]["is_valid"] = False 245 | # if uninvalid, point out invalid step 246 | check_valid_list[i]["unvalid_step"] = "reasoning" 247 | # check_valid_list[i]['hit'] = 0 248 | 249 | # add ground_reasoning_path & answer sequence 250 | groundtruthpath_path = ( 251 | f"../data/{args.dataset}/gpt-3.5-turbo/{args.split}/llm_prompt_response.jsonl" 252 | ) 253 | with open(groundtruthpath_path) as fin: 254 | first_line = True 255 | for i, line in enumerate(fin): 256 | if first_line: 257 | first_line = False 258 | continue 259 | data = json.loads(line) 260 | i -= 1 261 | if check_valid_list[i] and check_valid_list[i]["id"] == data["id"]: 262 | check_valid_list[i]["ground_reasoning_path"] = data[ 263 | "ground_truth_paths" 264 | ] 265 | check_valid_list[i]["sent_anwser"] = [data["reasoning_ans"]] + [ 266 | check_valid_list[i]["sent_anwser"] 267 | ] 268 | 269 | # add graph 270 | test_path = f"../data/{args.dataset}/{args.split}.jsonl" 271 | graph_dic = {} 272 | with open(test_path) as fin: 273 | for line in fin: 274 | data = json.loads(line) 275 | graph_dic[data["id"]] = { 276 | "raw_graph": data["raw_graph"], 277 | "graph": data["graph"], 278 | } 279 | for i in range(len(check_valid_list)): 280 | if check_valid_list[i]: 281 | if not graph_dic.get(check_valid_list[i]["id"], None): 282 | check_valid_list[i] = {} # cannot find subgraph 283 | else: 284 | check_valid_list[i]["raw_graph"] = graph_dic[check_valid_list[i]["id"]][ 285 | "raw_graph" 286 | ] 287 | check_valid_list[i]["graph"] = graph_dic[check_valid_list[i]["id"]][ 288 | "graph" 289 | ] 290 | 291 | # check f1 292 | prediction = [v["is_valid"] for v in check_valid_list if v] 293 | groundtruth = [v["hit"] for v in check_valid_list if v] 294 | num_step = [v["pred_num_steps"] for v in check_valid_list if v] 295 | result = classification_report( 296 | groundtruth, prediction, target_names=["invalid", "valid"], output_dict=True 297 | ) 298 | result_str = classification_report( 299 | groundtruth, prediction, target_names=["invalid", "valid"], output_dict=False 300 | ) 301 | 302 | # print('Hit:', sum(groundtruth)/len(groundtruth)) 303 | # print('Mean of the number of 
reasoning steps:', sum(num_step)/len(num_step)) 304 | print("The number of reasoning steps:", dict(collections.Counter(num_step))) 305 | print("Result: \n", result_str) 306 | with open( 307 | os.path.join( 308 | args.out_dir, 309 | f"{args.dataset}_{args.split}_retrive_fact_certification.jsonl", 310 | ), 311 | "w", 312 | ) as fout: 313 | fout.write(json.dumps({"args": args.__dict__}) + "\n") 314 | fout.write(json.dumps(result) + "\n") 315 | for v in check_valid_list: 316 | fout.write(json.dumps(v) + "\n") 317 | 318 | 319 | def replace_sparql(id, old_value, df1): 320 | value = df1.loc[df1["id"] == id, "groundtruth_paths"] 321 | if len(value) > 0: 322 | return value.iloc[0] 323 | # print('NotFound') 324 | return old_value 325 | 326 | 327 | def clear_reasoning_path(paths): 328 | if not paths: 329 | return False 330 | for path in paths: 331 | for tri in path: 332 | if re.search("^[mg]\.", tri[0]) or re.search("^[mg]\.", tri[2]): 333 | return False 334 | return True 335 | 336 | 337 | if __name__ == "__main__": 338 | argparser = argparse.ArgumentParser() 339 | argparser.add_argument( 340 | "--performance_dir", type=str, default="experiment_results/evaluate_llm_prompt" 341 | ) 342 | argparser.add_argument("--retrieval_dir", type=str, default="../data") 343 | argparser.add_argument( 344 | "--out_dir", type=str, default="experiment_results/retrive_fact_cert" 345 | ) 346 | argparser.add_argument("--dataset", type=str, default="cwq") # cwq, grailqa 347 | argparser.add_argument("--split", type=str, default="test") 348 | argparser.add_argument("--model_name", type=str, default="gpt-3.5-turbo") 349 | argparser.add_argument("--top_k", type=int, default=10) 350 | argparser.add_argument("--prob_thres", type=int, default=70) 351 | argparser.add_argument("--ent_thres", type=int, default=50) 352 | args = argparser.parse_args() 353 | 354 | main(args) 355 | dic_list = [] 356 | # with open(f'LLMReasoningCert/data/{args.dataset}/gpt-3.5-turbo/{args.split}/sparql_ground_truth_paths.jsonl') as fin: 357 | # for i, line in enumerate(fin): 358 | # data = json.loads(line) 359 | # dic_list.append(data) 360 | # df1 = pd.DataFrame(dic_list) 361 | 362 | dic_list = [] 363 | with open( 364 | f"experiment_results/retrive_fact_cert/{args.dataset}_{args.split}_retrive_fact_certification.jsonl" 365 | ) as fin: 366 | for i, line in enumerate(fin): 367 | data = json.loads(line) 368 | dic_list.append(data) 369 | dic_list = dic_list[2:] 370 | df = pd.DataFrame(dic_list) 371 | # df['ground_reasoning_path'] = df.progress_apply(lambda r: replace_sparql(r['id'],r['ground_reasoning_path'], df1), axis=1) 372 | df.dropna(inplace=True) 373 | df["sent_anwser"] = df["sent_anwser"].map(lambda x: x[0]) 374 | df["eval"] = df.apply( 375 | lambda r: f""""hit": {r['hit']}, "is_valid": {r['is_valid']}, "unvalid_step": {r['unvalid_step']}""", 376 | axis=1, 377 | ) 378 | df = df[ 379 | [ 380 | "question", 381 | "ground_answer", 382 | "sent_anwser", 383 | "prediction_reasoning_path", 384 | "ground_reasoning_path", 385 | "graph", 386 | "raw_graph", 387 | "eval", 388 | ] 389 | ] 390 | 391 | # df = df[df.ground_reasoning_path.map(clear_reasoning_path)] 392 | df.to_csv(f"tmp/{args.dataset}_{args.split}_res.csv", index=False) 393 | --------------------------------------------------------------------------------
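A minimal usage sketch of the acceptance rule in certify_fact.find_best_triplet (illustrative only, not part of the repository). It assumes the repository root is on PYTHONPATH with its dependencies installed; the candidate-dict key names other than "triplet" (the one main() reads back) are placeholder names assumed here, since the function only unpacks the four dict values in (triplet, prob, head_score, tail_score) order.

from certify_fact import find_best_triplet

# Toy top-k candidates for one reasoning sentence. Integer entity scores are
# treated as percentages (divided by 100) inside find_best_triplet.
candidates = [
    {"triplet": ["Kiribati", "location.country.capital", "South Tarawa"],
     "prob": 0.92, "head_score": 88, "tail_score": 81},  # clears prob >= 0.7 and both entity scores >= 0.5
    {"triplet": ["Kiribati", "location.location.containedby", "Pacific Ocean"],
     "prob": 0.55, "head_score": 90, "tail_score": 40},  # fails the thresholds
]

best, score = find_best_triplet(candidates, prob_thres=0.7, ent_thres=0.5)
# score is the mean of (prob, head, tail) for the best candidate that clears every
# threshold, or 0 if none does; main() then treats score < prob_thres as uncertified.
print(best["triplet"], round(score, 3))  # ['Kiribati', 'location.country.capital', 'South Tarawa'] 0.87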