├── Freebase-Setup ├── fix_freebase_literal_format.py ├── numeric_properties.txt └── virtuoso.py ├── LLMs ├── LLaMA │ └── src │ │ ├── beam_output.py │ │ ├── beam_output_eva.py │ │ ├── cli_demo.py │ │ ├── llmtuner │ │ ├── __init__.py │ │ ├── api │ │ │ ├── __init__.py │ │ │ ├── app.py │ │ │ └── protocol.py │ │ ├── chat │ │ │ ├── __init__.py │ │ │ └── stream_chat.py │ │ ├── dsets │ │ │ ├── __init__.py │ │ │ ├── loader.py │ │ │ ├── preprocess.py │ │ │ └── utils.py │ │ ├── extras │ │ │ ├── __init__.py │ │ │ ├── callbacks.py │ │ │ ├── constants.py │ │ │ ├── logging.py │ │ │ ├── misc.py │ │ │ ├── patches │ │ │ │ ├── __init__.py │ │ │ │ └── llama_patch.py │ │ │ ├── ploting.py │ │ │ ├── save_and_load.py │ │ │ └── template.py │ │ ├── hparams │ │ │ ├── __init__.py │ │ │ ├── data_args.py │ │ │ ├── finetuning_args.py │ │ │ ├── general_args.py │ │ │ ├── generating_args.py │ │ │ └── model_args.py │ │ ├── tuner │ │ │ ├── __init__.py │ │ │ ├── core │ │ │ │ ├── __init__.py │ │ │ │ ├── adapter.py │ │ │ │ ├── loader.py │ │ │ │ ├── parser.py │ │ │ │ └── utils.py │ │ │ ├── dpo │ │ │ │ ├── __init__.py │ │ │ │ ├── collator.py │ │ │ │ ├── trainer.py │ │ │ │ └── workflow.py │ │ │ ├── ppo │ │ │ │ ├── __init__.py │ │ │ │ ├── trainer.py │ │ │ │ ├── utils.py │ │ │ │ └── workflow.py │ │ │ ├── pt │ │ │ │ ├── __init__.py │ │ │ │ └── workflow.py │ │ │ ├── rm │ │ │ │ ├── __init__.py │ │ │ │ ├── collator.py │ │ │ │ ├── metric.py │ │ │ │ ├── trainer.py │ │ │ │ └── workflow.py │ │ │ ├── sft │ │ │ │ ├── __init__.py │ │ │ │ ├── metric.py │ │ │ │ ├── trainer.py │ │ │ │ └── workflow.py │ │ │ └── tune.py │ │ └── webui │ │ │ ├── __init__.py │ │ │ ├── chat.py │ │ │ ├── common.py │ │ │ ├── components │ │ │ ├── __init__.py │ │ │ ├── chatbot.py │ │ │ ├── data.py │ │ │ ├── eval.py │ │ │ ├── export.py │ │ │ ├── infer.py │ │ │ ├── top.py │ │ │ └── train.py │ │ │ ├── css.py │ │ │ ├── interface.py │ │ │ ├── locales.py │ │ │ ├── manager.py │ │ │ ├── runner.py │ │ │ └── utils.py │ │ ├── pro_model │ │ ├── cross_token.py │ │ ├── gate.py │ │ ├── map_layer.py │ │ ├── pm.py │ │ └── totoken.py │ │ └── train_bash.py └── data_id │ ├── CWQ_Freebase_NQ_test │ └── examples.json │ ├── CWQ_Freebase_NQ_train │ └── examples.json │ ├── WebQSP_Freebase_NQ_test │ └── examples.json │ ├── WebQSP_Freebase_NQ_train │ └── examples.json │ └── dataset_info.json ├── README.md ├── components ├── dataset_utils.py ├── expr_parser.py └── utils.py ├── config.py ├── data_process.py ├── entity_retrieval ├── aqqu_entity_linker.py ├── aqqu_util.py └── surface_index_memory.py ├── eval_final.py ├── executor ├── logic_form_util.py ├── logic_form_util_cwq.py └── sparql_executor.py ├── generation ├── cwq_evaluate.py └── webqsp_evaluate_offcial.py ├── lib └── virtodbc.so ├── ontology ├── README.md ├── domain_dict ├── domain_info ├── fb_roles ├── fb_types ├── full_reverse_properties.json └── reverse_properties ├── parse_sparql_cwq.py ├── parse_sparql_webqsp.py ├── process_NQ.py ├── requirements.txt ├── run_all.sh ├── run_ft.sh └── run_generator_final.py /Freebase-Setup/fix_freebase_literal_format.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | 3 | in_file = "freebase-rdf-latest.gz" 4 | out_file = "freebase-rdf-latest-literal_fixed.gz" 5 | 6 | # datatype strings 7 | datatype_string = {} 8 | datatype_string["type.int"] = "" 9 | datatype_string["type.float"] = "" 10 | datatype_string["type.boolean"] = "" 11 | 12 | # get the properties with literal object value 13 | type_map = {} 14 | with open("numeric_properties.txt", "r") as 
f_in: 15 | for line in f_in: 16 | line = line.strip() 17 | pred, type = line.split("\t") 18 | type_map[pred] = datatype_string[type] 19 | 20 | # update literal type line by line 21 | f_in = gzip.open(in_file, "r") 22 | f_out = gzip.open(out_file, "w") 23 | line_num = 0 24 | for line in f_in: 25 | line_num += 1 26 | if not line: 27 | continue 28 | subj, pred, obj, rest = line.split("\t") 29 | pred_t = pred[pred.rfind("/")+1:len(pred)-1] 30 | try: 31 | datatype_string = type_map[pred_t] 32 | if "^^" in obj: 33 | pass 34 | else: 35 | if "\"" in obj: 36 | obj = obj + "^^" + datatype_string 37 | else: 38 | obj = "\"" + obj + "\"^^" + datatype_string 39 | line = "\t".join([subj, pred, obj, rest]) 40 | except: 41 | pass 42 | f_out.write(line) 43 | if line_num % 1000000 == 0: 44 | print(line_num) 45 | 46 | f_in.close() 47 | f_out.close() 48 | -------------------------------------------------------------------------------- /Freebase-Setup/virtuoso.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # This script provides a convenient wrapper for the Virtuoso SPARQL server. 4 | # Adapted from Sempre (https://github.com/percyliang/sempre) 5 | 6 | import os 7 | import sys 8 | import subprocess 9 | import argparse 10 | 11 | virtuosoPath = "../virtuoso-opensource" 12 | if not os.path.exists(virtuosoPath): 13 | print(f"{virtuosoPath} does not exist") 14 | sys.exit(1) 15 | 16 | # Virtuoso has two services: the server (isql) and SPARQL endpoint 17 | def isqlPort(port): return 10000 + port 18 | def httpPort(port): return port 19 | 20 | def run(command): 21 | print(f"RUNNING: {command}") 22 | res = subprocess.run(command, shell=True, stdout=subprocess.PIPE) 23 | return res.stdout 24 | 25 | def start(dbPath, port): 26 | 27 | if not os.path.exists(dbPath): 28 | os.mkdir(dbPath) 29 | 30 | # Recommended: 70% of RAM, each buffer is 8K 31 | # Use a fraction of the free RAM. The result may vary across runs. 32 | # memFree = parseInt(`cat /proc/meminfo | grep MemFree | awk '{print $2}'`) # KB 33 | # Use a fraction of the total RAM. The result is the same across runs. 
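# Note: despite the "70% of RAM" recommendation above, the lines below size the
# buffer pool at 15% of MemTotal. A worked example, assuming a hypothetical
# 32 GB machine (MemTotal = 33554432 KB):
#   numberOfBuffers = 33554432 * 0.15 / 8 = 629145.6   (8 KB per buffer)
#   maxDirtyBuffers = 629145.6 / 2        = 314572.8
# Both values are interpolated into virtuoso.ini as floats; wrapping them in
# int() keeps NumberOfBuffers / MaxDirtyBuffers integral, which is the form the
# Virtuoso INI settings normally take.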
34 | memFree = int(run("cat /proc/meminfo | grep MemTotal | awk '{print $2}'")) # KB 35 | numberOfBuffers = memFree * 0.15 / 8 36 | maxDirtyBuffers = numberOfBuffers / 2 37 | print(f"{memFree} KB free, using {numberOfBuffers} buffers, {maxDirtyBuffers} dirty buffers") 38 | 39 | # Configuration options: 40 | # http://docs.openlinksw.com/virtuoso/dbadm.html 41 | # http://virtuoso.openlinksw.com/dataspace/doc/dav/wiki/Main/VirtConfigScale 42 | config = ( 43 | f"[Database]\n" 44 | f"DatabaseFile = {dbPath}/virtuoso.db\n" 45 | f"ErrorLogFile = {dbPath}/virtuoso.log\n" 46 | f"LockFile = {dbPath}/virtuoso.lck\n" 47 | f"TransactionFile = {dbPath}/virtuoso.trx\n" 48 | f"xa_persistent_file = {dbPath}/virtuoso.pxa\n" 49 | f"ErrorLogLevel = 7\n" 50 | f"FileExtend = 200\n" 51 | f"MaxCheckpointRemap = 2000\n" 52 | f"Striping = 0\n" 53 | f"TempStorage = TempDatabase\n" 54 | f"\n" 55 | f"[TempDatabase]\n" 56 | f"DatabaseFile = {dbPath}/virtuoso-temp.db\n" 57 | f"TransactionFile = {dbPath}/virtuoso-temp.trx\n" 58 | f"MaxCheckpointRemap = 2000\n" 59 | f"Striping = 0\n" 60 | f"\n" 61 | f"[Parameters]\n" 62 | f"ServerPort = {isqlPort(port)}\n" 63 | f"LiteMode = 0\n" 64 | f"DisableUnixSocket = 1\n" 65 | f"DisableTcpSocket = 0\n" 66 | f"ServerThreads = 100 ; increased from 20\n" 67 | f"CheckpointInterval = 60\n" 68 | f"O_DIRECT = 1 ; increased from 0\n" 69 | f"CaseMode = 2\n" 70 | f"MaxStaticCursorRows = 100000\n" 71 | f"CheckpointAuditTrail = 0\n" 72 | f"AllowOSCalls = 0\n" 73 | f"SchedulerInterval = 10\n" 74 | f"DirsAllowed = .\n" 75 | f"ThreadCleanupInterval = 0\n" 76 | f"ThreadThreshold = 10\n" 77 | f"ResourcesCleanupInterval = 0\n" 78 | f"FreeTextBatchSize = 100000\n" 79 | # f"SingleCPU = 0\n" 80 | f"PrefixResultNames = 0\n" 81 | f"RdfFreeTextRulesSize = 100\n" 82 | f"IndexTreeMaps = 256\n" 83 | f"MaxMemPoolSize = 200000000\n" 84 | f"PrefixResultNames = 0\n" 85 | f"MacSpotlight = 0\n" 86 | f"IndexTreeMaps = 64\n" 87 | f"NumberOfBuffers = {numberOfBuffers}\n" 88 | f"MaxDirtyBuffers = {maxDirtyBuffers}\n" 89 | f"\n" 90 | f"[SPARQL]\n" 91 | f"ResultSetMaxRows = 50000\n" 92 | f"MaxQueryCostEstimationTime = 600 ; in seconds (increased)\n" 93 | f"MaxQueryExecutionTime = 180; in seconds (increased)\n" 94 | f"\n" 95 | f"[HTTPServer]\n" 96 | f"ServerPort = {httpPort(port)}\n" 97 | f"Charset = UTF-8\n" 98 | f"ServerThreads = 15 ; increased from unknown\n" 99 | ) 100 | 101 | configPath = f"{dbPath}/virtuoso.ini" 102 | print(config) 103 | print() 104 | print(configPath) 105 | print(f"==== Starting Virtuoso server for {dbPath} on port {port}...") 106 | with open(configPath, 'w') as f: 107 | f.write(config) 108 | run(f"{virtuosoPath}/bin/virtuoso-t +configfile {configPath} +wait") 109 | 110 | def stop(port): 111 | run(f"echo 'shutdown;' | {virtuosoPath}/bin/isql localhost:{isqlPort(port)}") 112 | 113 | def status(port): 114 | run(f"echo 'status();' | {virtuosoPath}/bin/isql localhost:{isqlPort(port)}") 115 | 116 | ############################################################ 117 | # Main 118 | 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser(description="manage Virtuoso services") 121 | parser.add_argument("action", type=str, help="start or stop") 122 | parser.add_argument("port", type=int, help="port for the SPARQL HTTP endpoint") 123 | parser.add_argument("-d", "--db-path", type=str, help="path to the db directory") 124 | 125 | args = parser.parse_args() 126 | if args.action == "start": 127 | if not args.db_path: 128 | print("please specify path to the db directory with -d") 129 | sys.exit() 130 
| 131 | if not os.path.isdir(args.db_path): 132 | print("the path specified does not exist") 133 | sys.exit() 134 | 135 | start(args.db_path, args.port) 136 | elif args.action == "stop": 137 | stop(args.port) 138 | else: 139 | print(f"invalid action: ${args.action}") 140 | sys.exit() 141 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/beam_output.py: -------------------------------------------------------------------------------- 1 | from llmtuner import ChatModel 2 | 3 | 4 | def main(): 5 | chat_model = ChatModel() 6 | query = "Generate a Logical Form query that retrieves the information corresponding to the given question. \nQuestion: { what does jamaican people speak }" 7 | output = chat_model.chat_beam(query) 8 | print(output) 9 | 10 | 11 | if __name__ == "__main__": 12 | main() 13 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/beam_output_eva.py: -------------------------------------------------------------------------------- 1 | from llmtuner import ChatModel 2 | import json 3 | from tqdm import tqdm 4 | import random 5 | import re 6 | import os 7 | from llmtuner.tuner.core import get_infer_args 8 | from pro_model.totoken import data_load_retrieval, get_extra_input_ids 9 | 10 | def dump_json(obj, fname, indent=4, mode='w' ,encoding="utf8", ensure_ascii=False): 11 | if "b" in mode: 12 | encoding = None 13 | with open(fname, "w", encoding=encoding) as f: 14 | return json.dump(obj, f, indent=indent, ensure_ascii=ensure_ascii) 15 | 16 | def main(): 17 | model_args, data_args, _, _ = get_infer_args() 18 | id2rel, id2ent, id2sub = data_load_retrieval(data_args) 19 | chat_model = ChatModel() 20 | output_data = [] 21 | with open(os.path.join(data_args.dataset_dir,data_args.dataset,'examples.json'), 'r', encoding='utf-8') as f: 22 | json_data = json.load(f) 23 | # random.shuffle(json_data) 24 | total_lines = 0 25 | matched_lines = 0 26 | will_matched_lines = 0 27 | 28 | # 2. 读取每一行 29 | for data in tqdm(json_data): 30 | total_lines += 1 31 | query = data['instruction']+data['input'] 32 | id = data['ID'] 33 | entity = id2ent[id] 34 | relation = id2rel[id] 35 | subgraph = id2sub[id] 36 | predict = chat_model.chat_beam(query,entity,relation,subgraph) 37 | predict = [p[0] for p in predict] 38 | output_data.append({'label':data['output'],'predict':predict}) 39 | for p in predict: 40 | # 4. 检查"label"和"predict"的值是否相等 41 | if data['output'] == p: 42 | matched_lines += 1 43 | break 44 | for p in predict: 45 | # 4. 检查"label"和"predict"的值是否相等 46 | if re.sub(r'\[.*?\]', '', data['output']) == re.sub(r'\[.*?\]', '', p): 47 | will_matched_lines += 1 48 | break 49 | 50 | 51 | # 5. 计算相等的行的数量 52 | print(f"Total lines: {total_lines}") 53 | print(f"Matched lines: {matched_lines}") 54 | print(f"Will Matched lines: {will_matched_lines}") 55 | 56 | # 6. 计算相等行的占比 57 | percentage = (matched_lines / total_lines) * 100 58 | print(f"Percentage of matched lines: {percentage:.2f}%") 59 | # 6. 
计算相等行的占比 60 | will_percentage = (will_matched_lines / total_lines) * 100 61 | print(f"Percentage of will matched lines: {will_percentage:.2f}%") 62 | 63 | 64 | output_dir = os.path.join(os.path.dirname(model_args.checkpoint_dir[0]),'evaluation_beam/generated_predictions.jsonl') 65 | if not os.path.exists(os.path.dirname(output_dir)): 66 | os.makedirs(os.path.dirname(output_dir)) 67 | # with open(output_dir, 'w') as f: 68 | # for item in output_data: 69 | # json_string = json.dumps(item) 70 | # f.write(json_string + '\n') 71 | run_prediction(output_data,os.path.dirname(output_dir),output_predictions=True) 72 | 73 | def run_prediction(output_data,output_dir,output_predictions=True): 74 | print() 75 | print('Start predicting ') 76 | 77 | ex_cnt = 0 78 | contains_ex_cnt = 0 79 | output_list = [] 80 | real_total = 0 81 | for i,pred in enumerate(output_data): 82 | predictions = pred['predict'] 83 | gen_label = pred['label'] 84 | 85 | output_list.append({ 86 | 'predictions':predictions, 87 | 'gen_label':gen_label, 88 | }) 89 | 90 | if predictions[0].lower()==gen_label.lower(): 91 | ex_cnt+=1 92 | 93 | if any([x.lower()==gen_label.lower() for x in predictions]): 94 | contains_ex_cnt+=1 95 | 96 | if gen_label.lower()!='null': 97 | real_total+=1 98 | 99 | 100 | print(f"""total:{len(output_list)}, 101 | ex_cnt:{ex_cnt}, 102 | ex_rate:{ex_cnt/len(output_list)}, 103 | real_ex_rate:{ex_cnt/real_total}, 104 | contains_ex_cnt:{contains_ex_cnt}, 105 | contains_ex_rate:{contains_ex_cnt/len(output_list)} 106 | real_contains_ex_rate:{contains_ex_cnt/real_total} 107 | """) 108 | 109 | 110 | if output_predictions: 111 | file_path = os.path.join(output_dir,f'beam_test_top_k_predictions.json') 112 | 113 | gen_statistics_file_path = os.path.join(output_dir,f'beam_test_gen_statistics.json') 114 | gen_statistics = { 115 | 'total':len(output_list), 116 | 'exmatch_num': ex_cnt, 117 | 'exmatch_rate': ex_cnt/len(output_list), 118 | 'real_exmatch_rate':ex_cnt/real_total, 119 | 'contains_ex_num':contains_ex_cnt, 120 | 'contains_ex_rate':contains_ex_cnt/len(output_list), 121 | 'real_contains_ex_rate':contains_ex_cnt/real_total 122 | } 123 | dump_json(output_list, file_path, indent=4) 124 | dump_json(gen_statistics, gen_statistics_file_path,indent=4) 125 | 126 | 127 | 128 | if __name__ == "__main__": 129 | main() 130 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/cli_demo.py: -------------------------------------------------------------------------------- 1 | from llmtuner import ChatModel 2 | 3 | 4 | def main(): 5 | chat_model = ChatModel() 6 | history = [] 7 | print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") 8 | 9 | while True: 10 | try: 11 | query = input("\nUser: ") 12 | except UnicodeDecodeError: 13 | print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") 14 | continue 15 | except Exception: 16 | raise 17 | 18 | if query.strip() == "exit": 19 | break 20 | 21 | if query.strip() == "clear": 22 | history = [] 23 | print("History has been removed.") 24 | continue 25 | 26 | print("Assistant: ", end="", flush=True) 27 | 28 | response = "" 29 | for new_text in chat_model.stream_chat(query, history): 30 | print(new_text, end="", flush=True) 31 | response += new_text 32 | print() 33 | 34 | history = history + [(query, response)] 35 | 36 | 37 | if __name__ == "__main__": 38 | main() 39 | -------------------------------------------------------------------------------- 
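Both beam-search drivers above read their inputs from data_id/<split>/examples.json and from the retrieval files loaded by data_load_retrieval. Below is a minimal sketch of one such record and of the per-example call, inferred only from the fields beam_output_eva.py accesses; the concrete ID and logical form are placeholders, and the exact split of the prompt between "instruction" and "input" is assumed.

example = {
    "ID": "WebQTest-0",          # placeholder id; used as the key into id2ent / id2rel / id2sub
    "instruction": "Generate a Logical Form query that retrieves the information "
                   "corresponding to the given question. \n",
    "input": "Question: { what does jamaican people speak }",
    "output": "(AND (JOIN ...) ...)",   # gold logical form, placeholder
}

# Per-example flow in beam_output_eva.main():
#   query    = example["instruction"] + example["input"]
#   entity   = id2ent[example["ID"]]
#   relation = id2rel[example["ID"]]
#   subgraph = id2sub[example["ID"]]
#   beams    = chat_model.chat_beam(query, entity, relation, subgraph)  # [(text, score), ...], best first
#   hit      = any(b[0] == example["output"] for b in beams)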
/LLMs/LLaMA/src/llmtuner/__init__.py: -------------------------------------------------------------------------------- 1 | # Level: api, webui > chat > tuner > dsets > extras, hparams 2 | 3 | from llmtuner.api import create_app 4 | from llmtuner.chat import ChatModel 5 | from llmtuner.tuner import export_model, run_exp 6 | from llmtuner.webui import create_ui, create_web_demo 7 | 8 | 9 | __version__ = "0.1.8" 10 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/api/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.api.app import create_app 2 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/api/app.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | from fastapi import FastAPI, HTTPException 3 | from fastapi.middleware.cors import CORSMiddleware 4 | from contextlib import asynccontextmanager 5 | from sse_starlette import EventSourceResponse 6 | from typing import List, Tuple 7 | 8 | from llmtuner.extras.misc import torch_gc 9 | from llmtuner.chat import ChatModel 10 | from llmtuner.api.protocol import ( 11 | Role, 12 | Finish, 13 | ModelCard, 14 | ModelList, 15 | ChatMessage, 16 | DeltaMessage, 17 | ChatCompletionRequest, 18 | ChatCompletionResponse, 19 | ChatCompletionStreamResponse, 20 | ChatCompletionResponseChoice, 21 | ChatCompletionResponseStreamChoice, 22 | ChatCompletionResponseUsage 23 | ) 24 | 25 | 26 | @asynccontextmanager 27 | async def lifespan(app: FastAPI): # collects GPU memory 28 | yield 29 | torch_gc() 30 | 31 | 32 | def create_app(chat_model: ChatModel) -> FastAPI: 33 | app = FastAPI(lifespan=lifespan) 34 | 35 | app.add_middleware( 36 | CORSMiddleware, 37 | allow_origins=["*"], 38 | allow_credentials=True, 39 | allow_methods=["*"], 40 | allow_headers=["*"], 41 | ) 42 | 43 | @app.get("/v1/models", response_model=ModelList) 44 | async def list_models(): 45 | model_card = ModelCard(id="gpt-3.5-turbo") 46 | return ModelList(data=[model_card]) 47 | 48 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 49 | async def create_chat_completion(request: ChatCompletionRequest): 50 | if len(request.messages) < 1 or request.messages[-1].role != Role.USER: 51 | raise HTTPException(status_code=400, detail="Invalid request") 52 | 53 | query = request.messages[-1].content 54 | prev_messages = request.messages[:-1] 55 | if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM: 56 | system = prev_messages.pop(0).content 57 | else: 58 | system = None 59 | 60 | history = [] 61 | if len(prev_messages) % 2 == 0: 62 | for i in range(0, len(prev_messages), 2): 63 | if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT: 64 | history.append([prev_messages[i].content, prev_messages[i+1].content]) 65 | 66 | if request.stream: 67 | generate = predict(query, history, system, request) 68 | return EventSourceResponse(generate, media_type="text/event-stream") 69 | 70 | response, (prompt_length, response_length) = chat_model.chat( 71 | query, history, system, 72 | do_sample=request.do_sample, 73 | temperature=request.temperature, 74 | top_p=request.top_p, 75 | max_new_tokens=request.max_tokens 76 | ) 77 | 78 | usage = ChatCompletionResponseUsage( 79 | prompt_tokens=prompt_length, 80 | completion_tokens=response_length, 81 | total_tokens=prompt_length+response_length 82 | ) 83 | 84 | choice_data = 
ChatCompletionResponseChoice( 85 | index=0, 86 | message=ChatMessage(role=Role.ASSISTANT, content=response), 87 | finish_reason=Finish.STOP 88 | ) 89 | 90 | return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage) 91 | 92 | async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest): 93 | choice_data = ChatCompletionResponseStreamChoice( 94 | index=0, 95 | delta=DeltaMessage(role=Role.ASSISTANT), 96 | finish_reason=None 97 | ) 98 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) 99 | yield chunk.json(exclude_unset=True, ensure_ascii=False) 100 | 101 | for new_text in chat_model.stream_chat( 102 | query, history, system, 103 | do_sample=request.do_sample, 104 | temperature=request.temperature, 105 | top_p=request.top_p, 106 | max_new_tokens=request.max_tokens 107 | ): 108 | if len(new_text) == 0: 109 | continue 110 | 111 | choice_data = ChatCompletionResponseStreamChoice( 112 | index=0, 113 | delta=DeltaMessage(content=new_text), 114 | finish_reason=None 115 | ) 116 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) 117 | yield chunk.json(exclude_unset=True, ensure_ascii=False) 118 | 119 | choice_data = ChatCompletionResponseStreamChoice( 120 | index=0, 121 | delta=DeltaMessage(), 122 | finish_reason=Finish.STOP 123 | ) 124 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) 125 | yield chunk.json(exclude_unset=True, ensure_ascii=False) 126 | yield "[DONE]" 127 | 128 | return app 129 | 130 | 131 | if __name__ == "__main__": 132 | chat_model = ChatModel() 133 | app = create_app(chat_model) 134 | uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) 135 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/api/protocol.py: -------------------------------------------------------------------------------- 1 | import time 2 | from enum import Enum 3 | from pydantic import BaseModel, Field 4 | from typing import List, Optional 5 | 6 | 7 | class Role(str, Enum): 8 | USER = "user" 9 | ASSISTANT = "assistant" 10 | SYSTEM = "system" 11 | 12 | 13 | class Finish(str, Enum): 14 | STOP = "stop" 15 | LENGTH = "length" 16 | 17 | 18 | class ModelCard(BaseModel): 19 | id: str 20 | object: Optional[str] = "model" 21 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 22 | owned_by: Optional[str] = "owner" 23 | root: Optional[str] = None 24 | parent: Optional[str] = None 25 | permission: Optional[list] = [] 26 | 27 | 28 | class ModelList(BaseModel): 29 | object: Optional[str] = "list" 30 | data: Optional[List[ModelCard]] = [] 31 | 32 | 33 | class ChatMessage(BaseModel): 34 | role: Role 35 | content: str 36 | 37 | 38 | class DeltaMessage(BaseModel): 39 | role: Optional[Role] = None 40 | content: Optional[str] = None 41 | 42 | 43 | class ChatCompletionRequest(BaseModel): 44 | model: str 45 | messages: List[ChatMessage] 46 | do_sample: Optional[bool] = True 47 | temperature: Optional[float] = None 48 | top_p: Optional[float] = None 49 | n: Optional[int] = 1 50 | max_tokens: Optional[int] = None 51 | stream: Optional[bool] = False 52 | 53 | 54 | class ChatCompletionResponseChoice(BaseModel): 55 | index: int 56 | message: ChatMessage 57 | finish_reason: Finish 58 | 59 | 60 | class ChatCompletionResponseStreamChoice(BaseModel): 61 | index: int 62 | delta: DeltaMessage 63 | finish_reason: Optional[Finish] = None 64 | 65 | 66 | class ChatCompletionResponseUsage(BaseModel): 
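    # Filled by create_chat_completion() in api/app.py: prompt_tokens and
    # completion_tokens come from the (prompt_length, response_length) pair
    # returned by ChatModel.chat, and total_tokens is their sum.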
67 | prompt_tokens: int 68 | completion_tokens: int 69 | total_tokens: int 70 | 71 | 72 | class ChatCompletionResponse(BaseModel): 73 | id: Optional[str] = "chatcmpl-default" 74 | object: Optional[str] = "chat.completion" 75 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 76 | model: str 77 | choices: List[ChatCompletionResponseChoice] 78 | usage: ChatCompletionResponseUsage 79 | 80 | 81 | class ChatCompletionStreamResponse(BaseModel): 82 | id: Optional[str] = "chatcmpl-default" 83 | object: Optional[str] = "chat.completion.chunk" 84 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 85 | model: str 86 | choices: List[ChatCompletionResponseStreamChoice] 87 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/chat/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.chat.stream_chat import ChatModel 2 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/chat/stream_chat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pro_model.pm import PromptTuningModelForCausalLM 3 | from pro_model.totoken import data_load_retrieval, get_extra_input_ids 4 | from typing import Any, Dict, Generator, List, Optional, Tuple 5 | from threading import Thread 6 | from transformers import GenerationConfig, TextIteratorStreamer 7 | 8 | from llmtuner.extras.misc import dispatch_model, get_logits_processor 9 | from llmtuner.extras.template import get_template_and_fix_tokenizer 10 | from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer 11 | import os 12 | def get_latest_checkpoint_dir(checkpoint_base_dir): 13 | checkpoint_dirs = [d for d in os.listdir(checkpoint_base_dir) if os.path.isdir(os.path.join(checkpoint_base_dir, d))] 14 | if not checkpoint_dirs: 15 | return None 16 | checkpoint_dirs.sort() 17 | latest_checkpoint_dir = checkpoint_dirs[-1] 18 | return latest_checkpoint_dir 19 | 20 | 21 | class ChatModel: 22 | 23 | def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: 24 | model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args) 25 | self.data_args = data_args 26 | self.model_args = model_args 27 | # last_checkpoint_dir = os.path.join(model_args.checkpoint_dir[0], get_latest_checkpoint_dir(model_args.checkpoint_dir[0])) 28 | last_checkpoint_dir = model_args.checkpoint_dir[0] 29 | model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) 30 | model = PromptTuningModelForCausalLM(model, model_args.soft_prompt_length).to(dtype=torch.bfloat16) 31 | 32 | print("load model checkpoint : {}.".format(last_checkpoint_dir+ '/pytorch_model.bin')) 33 | state_dicts = torch.load(last_checkpoint_dir+ '/pytorch_model.bin',map_location='cpu') 34 | # state_dicts = torch.load('/data2/lixinhang/aaai/chatkbqa2/ChatKBQA/Reading/LLaMA-2-7b-hf/WebQSP_default/checkpoint2/checkpoint-105/pytorch_model.bin',map_location='cpu') 35 | if "model.base_model.model.lm_head.0.weight" in state_dicts: 36 | new_key = "model.base_model.model.lm_head.weight" 37 | state_dicts[new_key] =state_dicts.pop("model.base_model.model.lm_head.0.weight") 38 | 39 | model.load_state_dict(state_dicts) 40 | 41 | 42 | self.model = model.to(dtype=torch.bfloat16) 43 | self.model.eval() 44 | 45 | 46 | self.tokenizer.padding_side = "left" 47 | self.model = dispatch_model(self.model) 48 | self.template = 
get_template_and_fix_tokenizer(data_args.template, self.tokenizer) 49 | self.system_prompt = data_args.system_prompt 50 | 51 | def process_args( 52 | self, 53 | query: str, 54 | entity: str, 55 | relation: str, 56 | subgraph: str, 57 | history: Optional[List[Tuple[str, str]]] = None, 58 | system: Optional[str] = None, 59 | **input_kwargs 60 | ) -> Tuple[Dict[str, Any], int]: 61 | system = system or self.system_prompt 62 | 63 | prompt, _ = self.template.encode_oneturn( 64 | tokenizer=self.tokenizer, query=query, resp="", history=history, system=system 65 | ) 66 | entity_ids, relation_ids, subgraph_ids = get_extra_input_ids(self.tokenizer, 67 | entity, relation, subgraph, 68 | self.data_args.extra_infor_len) 69 | gate_ids = [self.tokenizer.bos_token_id] + prompt 70 | if len(gate_ids) > self.data_args.gate_len-1: 71 | gate_ids = gate_ids[:self.data_args.gate_len-1] 72 | pad_len = self.data_args.gate_len-len(gate_ids)-1 73 | pad_gate_ids = gate_ids + [self.tokenizer.pad_token_id]*pad_len + [self.tokenizer.eos_token_id] 74 | 75 | 76 | 77 | input_ids = torch.tensor([prompt], device=self.model.model.device) 78 | gate_ids = torch.tensor([pad_gate_ids], device=self.model.model.device) 79 | 80 | entity_ids = torch.tensor([entity_ids], device=self.model.model.device) 81 | relation_ids = torch.tensor([relation_ids], device=self.model.model.device) 82 | subgraph_ids = torch.tensor([subgraph_ids], device=self.model.model.device) 83 | 84 | prompt_length = len(input_ids[0]) 85 | 86 | do_sample = input_kwargs.pop("do_sample", None) 87 | temperature = input_kwargs.pop("temperature", None) 88 | top_p = input_kwargs.pop("top_p", None) 89 | top_k = input_kwargs.pop("top_k", None) 90 | repetition_penalty = input_kwargs.pop("repetition_penalty", None) 91 | max_length = input_kwargs.pop("max_length", None) 92 | max_new_tokens = input_kwargs.pop("max_new_tokens", None) 93 | 94 | generating_args = self.generating_args.to_dict() 95 | generating_args.update(dict( 96 | do_sample=False, 97 | num_beams = generating_args["num_beams"], 98 | num_beam_groups = generating_args["num_beams"], 99 | diversity_penalty = 1.0, 100 | num_return_sequences=generating_args["num_beams"], 101 | temperature=temperature or generating_args["temperature"], 102 | top_p=top_p or generating_args["top_p"], 103 | top_k=top_k or generating_args["top_k"], 104 | repetition_penalty=repetition_penalty or generating_args["repetition_penalty"], 105 | eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, 106 | pad_token_id=self.tokenizer.pad_token_id 107 | )) 108 | 109 | if max_length: 110 | generating_args.pop("max_new_tokens", None) 111 | generating_args["max_length"] = max_length 112 | 113 | if max_new_tokens: 114 | generating_args.pop("max_length", None) 115 | generating_args["max_new_tokens"] = max_new_tokens 116 | 117 | gen_kwargs = dict( 118 | inputs =input_ids, 119 | gate_ids = gate_ids, 120 | entitys = entity_ids, 121 | relations = relation_ids, 122 | subgraphs = subgraph_ids, 123 | generation_config=GenerationConfig(**generating_args), 124 | logits_processor=get_logits_processor() 125 | ) 126 | 127 | 128 | return gen_kwargs, prompt_length 129 | 130 | @torch.inference_mode() 131 | def chat_beam( 132 | self, 133 | query: str, 134 | entity: str, 135 | relation: str, 136 | subgraph: str, 137 | history: Optional[List[Tuple[str, str]]] = None, 138 | system: Optional[str] = None, 139 | **input_kwargs 140 | ) -> Tuple[str, Tuple[int, int]]: 141 | gen_kwargs, prompt_length = self.process_args(query, entity, relation, 
subgraph, history, system, **input_kwargs) 142 | generation_output = self.model.inference(**gen_kwargs,return_dict_in_generate=True, output_scores=True) 143 | 144 | outputs = [g[prompt_length:] for g in generation_output['sequences'].tolist()] 145 | outputs_scores = [s for s in generation_output['sequences_scores'].tolist()] 146 | response = [self.tokenizer.decode(o, skip_special_tokens=True) for o in outputs] 147 | 148 | response_dict = {} 149 | for resp, score in zip(response, outputs_scores): 150 | if resp not in response_dict or score > response_dict[resp]: 151 | response_dict[resp] = score 152 | 153 | # 将字典转换为元组列表并按得分排序 154 | sorted_responses = sorted(response_dict.items(), key=lambda x: x[1], reverse=True) 155 | # response_length = len(outputs) 156 | return sorted_responses 157 | 158 | @torch.inference_mode() 159 | def chat( 160 | self, 161 | query: str, 162 | history: Optional[List[Tuple[str, str]]] = None, 163 | system: Optional[str] = None, 164 | **input_kwargs 165 | ) -> Tuple[str, Tuple[int, int]]: 166 | gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs) 167 | generation_output = self.model.generate(**gen_kwargs) 168 | outputs = generation_output.tolist()[0][prompt_length:] 169 | response = self.tokenizer.decode(outputs, skip_special_tokens=True) 170 | response_length = len(outputs) 171 | return response, (prompt_length, response_length) 172 | 173 | @torch.inference_mode() 174 | def stream_chat( 175 | self, 176 | query: str, 177 | history: Optional[List[Tuple[str, str]]] = None, 178 | system: Optional[str] = None, 179 | **input_kwargs 180 | ) -> Generator[str, None, None]: 181 | gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs) 182 | streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) 183 | gen_kwargs["streamer"] = streamer 184 | 185 | thread = Thread(target=self.model.generate, kwargs=gen_kwargs) 186 | thread.start() 187 | 188 | yield from streamer 189 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/dsets/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.dsets.loader import get_dataset 2 | from llmtuner.dsets.preprocess import preprocess_dataset 3 | from llmtuner.dsets.utils import split_dataset 4 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/dsets/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import TYPE_CHECKING, List, Union 3 | 4 | from datasets import concatenate_datasets, interleave_datasets, load_dataset 5 | 6 | from llmtuner.dsets.utils import checksum, EXT2TYPE 7 | from llmtuner.extras.logging import get_logger 8 | 9 | if TYPE_CHECKING: 10 | from datasets import Dataset, IterableDataset 11 | from llmtuner.hparams import ModelArguments, DataArguments 12 | 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | def get_dataset( 18 | model_args: "ModelArguments", 19 | data_args: "DataArguments" 20 | ) -> Union["Dataset", "IterableDataset"]: 21 | max_samples = data_args.max_samples 22 | all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets 23 | 24 | for dataset_attr in data_args.dataset_list: 25 | logger.info("Loading dataset {}...".format(dataset_attr)) 26 | 27 | if dataset_attr.load_from == "hf_hub": 28 | data_path = dataset_attr.dataset_name 29 | data_files = None 30 | 
elif dataset_attr.load_from == "script": 31 | data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) 32 | data_files = None 33 | elif dataset_attr.load_from == "file": 34 | data_path = None 35 | data_files: List[str] = [] 36 | 37 | if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory 38 | for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): 39 | data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name)) 40 | if data_path is None: 41 | data_path = EXT2TYPE.get(file_name.split(".")[-1], None) 42 | else: 43 | assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match." 44 | elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file 45 | data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)) 46 | data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None) 47 | else: 48 | raise ValueError("File not found.") 49 | 50 | assert data_path, "File extension must be txt, csv, json or jsonl." 51 | checksum(data_files, dataset_attr.dataset_sha1) 52 | else: 53 | raise NotImplementedError 54 | 55 | dataset = load_dataset( 56 | data_path, 57 | data_files=data_files, 58 | split=data_args.split, 59 | cache_dir=model_args.cache_dir, 60 | streaming=data_args.streaming, 61 | use_auth_token=True if model_args.use_auth_token else None 62 | ) 63 | 64 | if max_samples is not None: 65 | max_samples_temp = min(len(dataset), max_samples) 66 | dataset = dataset.select(range(max_samples_temp)) 67 | 68 | for column_name in ["prompt", "query", "response", "history"]: # align datasets 69 | if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name: 70 | dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name) 71 | 72 | if dataset_attr.system_prompt: # add system prompt 73 | if data_args.streaming: 74 | dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt}) 75 | else: 76 | dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset)) 77 | 78 | all_datasets.append(dataset) 79 | 80 | if len(data_args.dataset_list) == 1: 81 | return all_datasets[0] 82 | elif data_args.mix_strategy == "concat": 83 | if data_args.streaming: 84 | logger.warning("The samples between different datasets will not be mixed in streaming mode.") 85 | return concatenate_datasets(all_datasets) 86 | elif data_args.mix_strategy.startswith("interleave"): 87 | if not data_args.streaming: 88 | logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") 89 | stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted" 90 | return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy) 91 | else: 92 | raise ValueError("Unknown mixing strategy.") 93 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/dsets/utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | from typing import TYPE_CHECKING, Dict, List, Optional, Union 3 | 4 | from llmtuner.extras.logging import get_logger 5 | 6 | if TYPE_CHECKING: 7 | from datasets import Dataset, IterableDataset 8 | from transformers import TrainingArguments 9 | from llmtuner.hparams import DataArguments 10 | 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | EXT2TYPE 
= { 16 | "csv": "csv", 17 | "json": "json", 18 | "jsonl": "json", 19 | "txt": "text" 20 | } 21 | 22 | 23 | def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None: 24 | if file_sha1 is None: 25 | logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") 26 | return 27 | 28 | if len(data_files) != 1: 29 | logger.warning("Checksum failed: too many files.") 30 | return 31 | 32 | with open(data_files[0], "rb") as f: 33 | sha1 = hashlib.sha1(f.read()).hexdigest() 34 | if sha1 != file_sha1: 35 | logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0])) 36 | 37 | 38 | def split_dataset( 39 | dataset: Union["Dataset", "IterableDataset"], 40 | data_args: "DataArguments", 41 | training_args: "TrainingArguments" 42 | ) -> Dict[str, "Dataset"]: 43 | if training_args.do_train: 44 | if data_args.val_size > 1e-6: # Split the dataset 45 | if data_args.streaming: 46 | val_set = dataset.take(int(data_args.val_size)) 47 | train_set = dataset.skip(int(data_args.val_size)) 48 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) 49 | return {"train_dataset": train_set, "eval_dataset": val_set} 50 | else: 51 | val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size 52 | dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed) 53 | return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} 54 | else: 55 | if data_args.streaming: 56 | dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) 57 | return {"train_dataset": dataset} 58 | else: # do_eval or do_predict 59 | return {"eval_dataset": dataset} 60 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/extras/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Applied-Machine-Learning-Lab/AMAR/81da184742e887c512ea7fb393a4dc19e9495446/LLMs/LLaMA/src/llmtuner/extras/__init__.py -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/extras/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | from typing import TYPE_CHECKING 5 | from datetime import timedelta 6 | 7 | from transformers import TrainerCallback 8 | from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR 9 | 10 | from llmtuner.extras.constants import LOG_FILE_NAME 11 | from llmtuner.extras.logging import get_logger 12 | 13 | if TYPE_CHECKING: 14 | from transformers import TrainingArguments, TrainerState, TrainerControl 15 | 16 | 17 | logger = get_logger(__name__) 18 | 19 | 20 | class SavePeftModelCallback(TrainerCallback): 21 | 22 | def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 23 | r""" 24 | Event called after a checkpoint save. 25 | """ 26 | if args.should_save: 27 | output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)) 28 | model = kwargs.pop("model") 29 | if getattr(model, "is_peft_model", False): 30 | getattr(model, "pretrained_model").save_pretrained(output_dir) 31 | 32 | def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 33 | r""" 34 | Event called at the end of training. 
35 | """ 36 | if args.should_save: 37 | model = kwargs.pop("model") 38 | if getattr(model, "is_peft_model", False): 39 | getattr(model, "pretrained_model").save_pretrained(args.output_dir) 40 | 41 | 42 | class LogCallback(TrainerCallback): 43 | 44 | def __init__(self, runner=None): 45 | self.runner = runner 46 | self.in_training = False 47 | self.start_time = time.time() 48 | self.cur_steps = 0 49 | self.max_steps = 0 50 | self.elapsed_time = "" 51 | self.remaining_time = "" 52 | 53 | def timing(self): 54 | cur_time = time.time() 55 | elapsed_time = cur_time - self.start_time 56 | avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0 57 | remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step 58 | self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) 59 | self.remaining_time = str(timedelta(seconds=int(remaining_time))) 60 | 61 | def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 62 | r""" 63 | Event called at the beginning of training. 64 | """ 65 | if state.is_local_process_zero: 66 | self.in_training = True 67 | self.start_time = time.time() 68 | self.max_steps = state.max_steps 69 | if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)): 70 | logger.warning("Previous log file in this folder will be deleted.") 71 | os.remove(os.path.join(args.output_dir, LOG_FILE_NAME)) 72 | 73 | def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 74 | r""" 75 | Event called at the end of training. 76 | """ 77 | if state.is_local_process_zero: 78 | self.in_training = False 79 | self.cur_steps = 0 80 | self.max_steps = 0 81 | 82 | def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 83 | r""" 84 | Event called at the end of an substep during gradient accumulation. 85 | """ 86 | if state.is_local_process_zero and self.runner is not None and self.runner.aborted: 87 | control.should_epoch_stop = True 88 | control.should_training_stop = True 89 | 90 | def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 91 | r""" 92 | Event called at the end of a training step. 93 | """ 94 | if state.is_local_process_zero: 95 | self.cur_steps = state.global_step 96 | self.timing() 97 | if self.runner is not None and self.runner.aborted: 98 | control.should_epoch_stop = True 99 | control.should_training_stop = True 100 | 101 | def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 102 | r""" 103 | Event called after an evaluation phase. 104 | """ 105 | if state.is_local_process_zero and not self.in_training: 106 | self.cur_steps = 0 107 | self.max_steps = 0 108 | 109 | def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs): 110 | r""" 111 | Event called after a successful prediction. 112 | """ 113 | if state.is_local_process_zero and not self.in_training: 114 | self.cur_steps = 0 115 | self.max_steps = 0 116 | 117 | def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None: 118 | r""" 119 | Event called after logging the last logs. 
120 | """ 121 | if not state.is_local_process_zero: 122 | return 123 | 124 | logs = dict( 125 | current_steps=self.cur_steps, 126 | total_steps=self.max_steps, 127 | loss=state.log_history[-1].get("loss", None), 128 | eval_loss=state.log_history[-1].get("eval_loss", None), 129 | predict_loss=state.log_history[-1].get("predict_loss", None), 130 | reward=state.log_history[-1].get("reward", None), 131 | learning_rate=state.log_history[-1].get("learning_rate", None), 132 | epoch=state.log_history[-1].get("epoch", None), 133 | percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, 134 | elapsed_time=self.elapsed_time, 135 | remaining_time=self.remaining_time 136 | ) 137 | os.makedirs(args.output_dir, exist_ok=True) 138 | with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: 139 | f.write(json.dumps(logs) + "\n") 140 | 141 | def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): 142 | r""" 143 | Event called after a prediction step. 144 | """ 145 | eval_dataloader = kwargs.pop("eval_dataloader", None) 146 | if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training: 147 | if self.max_steps == 0: 148 | self.max_steps = len(eval_dataloader) 149 | self.cur_steps += 1 150 | self.timing() 151 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/extras/constants.py: -------------------------------------------------------------------------------- 1 | IGNORE_INDEX = -100 2 | 3 | LOG_FILE_NAME = "trainer_log.jsonl" 4 | 5 | LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2"] 6 | 7 | METHODS = ["full", "freeze", "lora"] 8 | 9 | TRAINING_STAGES = { 10 | "Supervised Fine-Tuning": "sft", 11 | "Reward Modeling": "rm", 12 | "PPO": "ppo", 13 | "DPO": "dpo", 14 | "Pre-Training": "pt" 15 | } 16 | 17 | SUPPORTED_MODELS = { 18 | "LLaMA-7B": "huggyllama/llama-7b", 19 | "LLaMA-13B": "huggyllama/llama-13b", 20 | "LLaMA-30B": "huggyllama/llama-30b", 21 | "LLaMA-65B": "huggyllama/llama-65b", 22 | "LLaMA2-7B": "meta-llama/Llama-2-7b-hf", 23 | "LLaMA2-13B": "meta-llama/Llama-2-13b-hf", 24 | "LLaMA2-70B": "meta-llama/Llama-2-70b-hf", 25 | "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf", 26 | "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf", 27 | "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf", 28 | "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b", 29 | "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b", 30 | "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b", 31 | "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b", 32 | "BLOOM-560M": "bigscience/bloom-560m", 33 | "BLOOM-3B": "bigscience/bloom-3b", 34 | "BLOOM-7B1": "bigscience/bloom-7b1", 35 | "BLOOMZ-560M": "bigscience/bloomz-560m", 36 | "BLOOMZ-3B": "bigscience/bloomz-3b", 37 | "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt", 38 | "Falcon-7B": "tiiuae/falcon-7b", 39 | "Falcon-40B": "tiiuae/falcon-40b", 40 | "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct", 41 | "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct", 42 | "Baichuan-7B": "baichuan-inc/Baichuan-7B", 43 | "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base", 44 | "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat", 45 | "Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base", 46 | "Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base", 47 | "Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat", 48 | "Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat", 49 | 
"InternLM-7B": "internlm/internlm-7b", 50 | "InternLM-20B": "internlm/internlm-20b", 51 | "InternLM-7B-Chat": "internlm/internlm-chat-7b", 52 | "InternLM-20B-Chat": "internlm/internlm-chat-20b", 53 | "Qwen-7B": "Qwen/Qwen-7B", 54 | "Qwen-14B": "Qwen/Qwen-14B", 55 | "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat", 56 | "Qwen-14B-Chat": "Qwen/Qwen-14B-Chat", 57 | "XVERSE-13B": "xverse/XVERSE-13B", 58 | "XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat", 59 | "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b", 60 | "Phi1.5-1.3B": "microsoft/phi-1_5" 61 | } 62 | 63 | DEFAULT_MODULE = { 64 | "LLaMA": "q_proj,v_proj", 65 | "LLaMA2": "q_proj,v_proj", 66 | "ChineseLLaMA2": "q_proj,v_proj", 67 | "BLOOM": "query_key_value", 68 | "BLOOMZ": "query_key_value", 69 | "Falcon": "query_key_value", 70 | "Baichuan": "W_pack", 71 | "Baichuan2": "W_pack", 72 | "InternLM": "q_proj,v_proj", 73 | "Qwen": "c_attn", 74 | "XVERSE": "q_proj,v_proj", 75 | "ChatGLM2": "query_key_value", 76 | "Phi1.5": "Wqkv" 77 | } 78 | 79 | DEFAULT_TEMPLATE = { 80 | "LLaMA2": "llama2", 81 | "ChineseLLaMA2": "llama2_zh", 82 | "Baichuan": "baichuan", 83 | "Baichuan2": "baichuan2", 84 | "InternLM": "intern", 85 | "Qwen": "chatml", 86 | "XVERSE": "xverse", 87 | "ChatGLM2": "chatglm2" 88 | } 89 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/extras/logging.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | 4 | 5 | class LoggerHandler(logging.Handler): 6 | 7 | def __init__(self): 8 | super().__init__() 9 | self.log = "" 10 | 11 | def reset(self): 12 | self.log = "" 13 | 14 | def emit(self, record): 15 | if record.name == "httpx": 16 | return 17 | log_entry = self.format(record) 18 | self.log += log_entry 19 | self.log += "\n\n" 20 | 21 | 22 | def reset_logging(): 23 | r""" 24 | Removes basic config of root logger 25 | """ 26 | root = logging.getLogger() 27 | list(map(root.removeHandler, root.handlers)) 28 | list(map(root.removeFilter, root.filters)) 29 | 30 | 31 | def get_logger(name: str) -> logging.Logger: 32 | formatter = logging.Formatter( 33 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 34 | datefmt="%m/%d/%Y %H:%M:%S" 35 | ) 36 | handler = logging.StreamHandler(sys.stdout) 37 | handler.setFormatter(formatter) 38 | 39 | logger = logging.getLogger(name) 40 | logger.setLevel(logging.INFO) 41 | logger.addHandler(handler) 42 | 43 | return logger 44 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/extras/misc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import torch 3 | from typing import TYPE_CHECKING, Tuple 4 | from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList 5 | 6 | if TYPE_CHECKING: 7 | from transformers.modeling_utils import PreTrainedModel 8 | 9 | 10 | class AverageMeter: 11 | r""" 12 | Computes and stores the average and current value. 13 | """ 14 | def __init__(self): 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: 31 | r""" 32 | Returns the number of trainable parameters and number of all parameters in the model. 
33 | """ 34 | trainable_params, all_param = 0, 0 35 | for param in model.parameters(): 36 | num_params = param.numel() 37 | # if using DS Zero 3 and the weights are initialized empty 38 | if num_params == 0 and hasattr(param, "ds_numel"): 39 | num_params = param.ds_numel 40 | 41 | # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by 2 42 | if param.__class__.__name__ == "Params4bit": 43 | num_params = num_params * 2 44 | 45 | all_param += num_params 46 | if param.requires_grad: 47 | trainable_params += num_params 48 | 49 | return trainable_params, all_param 50 | 51 | 52 | def get_logits_processor() -> LogitsProcessorList: 53 | logits_processor = LogitsProcessorList() 54 | logits_processor.append(InfNanRemoveLogitsProcessor()) 55 | return logits_processor 56 | 57 | 58 | def torch_gc() -> None: 59 | r""" 60 | Collects GPU memory. 61 | """ 62 | gc.collect() 63 | if torch.cuda.is_available(): 64 | torch.cuda.empty_cache() 65 | torch.cuda.ipc_collect() 66 | 67 | 68 | def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": 69 | r""" 70 | Dispatches a pre-trained model to GPUs with balanced memory. 71 | Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 72 | """ 73 | if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing 74 | return model 75 | 76 | if torch.cuda.device_count() > 1: 77 | from accelerate import dispatch_model 78 | from accelerate.utils import infer_auto_device_map, get_balanced_memory 79 | 80 | if model._no_split_modules is None: 81 | raise ValueError("The model class needs to implement the `_no_split_modules` attribute.") 82 | 83 | kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules} 84 | max_memory = get_balanced_memory(model, **kwargs) 85 | # Make sure tied weights are tied before creating the device map. 86 | model.tie_weights() 87 | device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) 88 | return dispatch_model(model, device_map) 89 | else: 90 | return model.cuda() 91 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/extras/patches/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Applied-Machine-Learning-Lab/AMAR/81da184742e887c512ea7fb393a4dc19e9495446/LLMs/LLaMA/src/llmtuner/extras/patches/__init__.py -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/extras/ploting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import matplotlib.pyplot as plt 5 | from typing import List, Optional 6 | from transformers.trainer import TRAINER_STATE_NAME 7 | 8 | from llmtuner.extras.logging import get_logger 9 | 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | def smooth(scalars: List[float]) -> List[float]: 15 | r""" 16 | EMA implementation according to TensorBoard. 
17 | """ 18 | last = scalars[0] 19 | smoothed = list() 20 | weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function 21 | for next_val in scalars: 22 | smoothed_val = last * weight + (1 - weight) * next_val 23 | smoothed.append(smoothed_val) 24 | last = smoothed_val 25 | return smoothed 26 | 27 | 28 | def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None: 29 | 30 | with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: 31 | data = json.load(f) 32 | 33 | for key in keys: 34 | steps, metrics = [], [] 35 | for i in range(len(data["log_history"])): 36 | if key in data["log_history"][i]: 37 | steps.append(data["log_history"][i]["step"]) 38 | metrics.append(data["log_history"][i][key]) 39 | 40 | if len(metrics) == 0: 41 | logger.warning(f"No metric {key} to plot.") 42 | continue 43 | 44 | plt.figure() 45 | plt.plot(steps, metrics, alpha=0.4, label="original") 46 | plt.plot(steps, smooth(metrics), label="smoothed") 47 | plt.title("training {} of {}".format(key, save_dictionary)) 48 | plt.xlabel("step") 49 | plt.ylabel(key) 50 | plt.legend() 51 | plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100) 52 | print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key))) 53 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/extras/save_and_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers.trainer import WEIGHTS_NAME 4 | 5 | from llmtuner.extras.logging import get_logger 6 | 7 | 8 | logger = get_logger(__name__) 9 | 10 | 11 | def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool: 12 | vhead_file = os.path.join(checkpoint_dir, WEIGHTS_NAME) 13 | if not os.path.exists(vhead_file): 14 | logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir)) 15 | return False 16 | vhead_params = torch.load(vhead_file, map_location="cpu") 17 | model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) 18 | model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) 19 | model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False) 20 | model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False) 21 | return True 22 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/hparams/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_args import DataArguments 2 | from .finetuning_args import FinetuningArguments 3 | from .general_args import GeneralArguments 4 | from .generating_args import GeneratingArguments 5 | from .model_args import ModelArguments 6 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/hparams/data_args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Literal, Optional 4 | from dataclasses import dataclass, field 5 | 6 | 7 | @dataclass 8 | class DatasetAttr: 9 | 10 | load_from: str 11 | dataset_name: Optional[str] = None 12 | dataset_sha1: Optional[str] = None 13 | system_prompt: 
Optional[str] = None 14 | ranking: Optional[bool] = False 15 | prompt: Optional[str] = "instruction" 16 | query: Optional[str] = "input" 17 | response: Optional[str] = "output" 18 | history: Optional[str] = None 19 | 20 | def __repr__(self) -> str: 21 | return self.dataset_name 22 | 23 | 24 | @dataclass 25 | class DataArguments: 26 | r""" 27 | Arguments pertaining to what data we are going to input our model for training and evaluation. 28 | """ 29 | template: Optional[str] = field( 30 | default=None, 31 | metadata={"help": "Which template to use for constructing prompts in training and inference."} 32 | ) 33 | dataset: Optional[str] = field( 34 | default="alpaca_en", 35 | metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."} 36 | ) 37 | dataset_dir: Optional[str] = field( 38 | default="data", 39 | metadata={"help": "The name of the folder containing datasets."} 40 | ) 41 | split: Optional[str] = field( 42 | default="train", 43 | metadata={"help": "Which dataset split to use for training and evaluation."} 44 | ) 45 | cutoff_len: Optional[int] = field( 46 | default=1024, 47 | metadata={"help": "The maximum length of the model inputs after tokenization."} 48 | ) 49 | extra_infor_len: Optional[int] = field( 50 | default=128, 51 | metadata={"help": ""} 52 | ) 53 | gate_len: Optional[int] = field( 54 | default=256, 55 | metadata={"help": ""} 56 | ) 57 | streaming: Optional[bool] = field( 58 | default=False, 59 | metadata={"help": "Enable streaming mode."} 60 | ) 61 | buffer_size: Optional[int] = field( 62 | default=16384, 63 | metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."} 64 | ) 65 | mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field( 66 | default="concat", 67 | metadata={"help": "Strategy to use in dataset mixing."} 68 | ) 69 | interleave_probs: Optional[str] = field( 70 | default=None, 71 | metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."} 72 | ) 73 | overwrite_cache: Optional[bool] = field( 74 | default=False, 75 | metadata={"help": "Overwrite the cached training and evaluation sets."} 76 | ) 77 | preprocessing_num_workers: Optional[int] = field( 78 | default=None, 79 | metadata={"help": "The number of processes to use for the preprocessing."} 80 | ) 81 | max_samples: Optional[int] = field( 82 | default=None, 83 | metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} 84 | ) 85 | eval_num_beams: Optional[int] = field( 86 | default=None, 87 | metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"} 88 | ) 89 | ignore_pad_token_for_loss: Optional[bool] = field( 90 | default=True, 91 | metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} 92 | ) 93 | system_prompt: Optional[str] = field( 94 | default=None, 95 | metadata={"help": "System prompt to add before the user query. 
Use `|` to separate multiple prompts in training."} 96 | ) 97 | val_size: Optional[float] = field( 98 | default=0, 99 | metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."} 100 | ) 101 | sft_packing: Optional[bool] = field( 102 | default=False, 103 | metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."} 104 | ) 105 | topk: Optional[int] = field( 106 | default=False, 107 | metadata={"help": "Top k for retrieval of entity relation and subgraph."} 108 | ) 109 | 110 | def init_for_training(self): # support mixing multiple datasets 111 | dataset_names = [ds.strip() for ds in self.dataset.split(",")] 112 | with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f: 113 | dataset_info = json.load(f) 114 | 115 | prompt_list = self.system_prompt.split("|") if self.system_prompt else [None] 116 | prompt_list = prompt_list * (len(dataset_names) // len(prompt_list)) 117 | assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1." 118 | 119 | if self.interleave_probs is not None: 120 | self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")] 121 | 122 | self.dataset_list: List[DatasetAttr] = [] 123 | for i, name in enumerate(dataset_names): 124 | if name not in dataset_info: 125 | raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) 126 | 127 | if "hf_hub_url" in dataset_info[name]: 128 | dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) 129 | elif "script_url" in dataset_info[name]: 130 | dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) 131 | else: 132 | dataset_attr = DatasetAttr( 133 | "file", 134 | dataset_name=dataset_info[name]["file_name"], 135 | dataset_sha1=dataset_info[name].get("file_sha1", None) 136 | ) 137 | 138 | if "columns" in dataset_info[name]: 139 | dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None) 140 | dataset_attr.query = dataset_info[name]["columns"].get("query", None) 141 | dataset_attr.response = dataset_info[name]["columns"].get("response", None) 142 | dataset_attr.history = dataset_info[name]["columns"].get("history", None) 143 | 144 | dataset_attr.ranking = dataset_info[name].get("ranking", False) 145 | dataset_attr.system_prompt = prompt_list[i] 146 | self.dataset_list.append(dataset_attr) 147 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/hparams/finetuning_args.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Literal, Optional 3 | from dataclasses import asdict, dataclass, field 4 | 5 | 6 | @dataclass 7 | class FinetuningArguments: 8 | r""" 9 | Arguments pertaining to which techniques we are going to fine-tuning with. 10 | """ 11 | finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field( 12 | default="lora", 13 | metadata={"help": "Which fine-tuning method to use."} 14 | ) 15 | num_layer_trainable: Optional[int] = field( 16 | default=3, 17 | metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."} 18 | ) 19 | name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( 20 | default="mlp", 21 | metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. 
\ 22 | LLaMA choices: [\"mlp\", \"self_attn\"], \ 23 | BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \ 24 | Qwen choices: [\"mlp\", \"attn\"], \ 25 | Phi-1.5 choices: [\"mlp\", \"mixer\"], \ 26 | LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."} 27 | ) 28 | lora_rank: Optional[int] = field( 29 | default=8, 30 | metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} 31 | ) 32 | lora_alpha: Optional[float] = field( 33 | default=32.0, 34 | metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."} 35 | ) 36 | lora_dropout: Optional[float] = field( 37 | default=0.1, 38 | metadata={"help": "Dropout rate for the LoRA fine-tuning."} 39 | ) 40 | lora_target: Optional[str] = field( 41 | default=None, 42 | metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \ 43 | LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ 44 | BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \ 45 | Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \ 46 | Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \ 47 | Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \ 48 | LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."} 49 | ) 50 | additional_target: Optional[str] = field( 51 | default=None, 52 | metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."} 53 | ) 54 | resume_lora_training: Optional[bool] = field( 55 | default=True, 56 | metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} 57 | ) 58 | ppo_score_norm: Optional[bool] = field( 59 | default=False, 60 | metadata={"help": "Use score normalization in PPO Training."} 61 | ) 62 | dpo_beta: Optional[float] = field( 63 | default=0.1, 64 | metadata={"help": "The beta parameter for the DPO loss."} 65 | ) 66 | 67 | def __post_init__(self): 68 | if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA 69 | self.lora_target = [target.strip() for target in self.lora_target.split(",")] 70 | 71 | if isinstance(self.additional_target, str): 72 | self.additional_target = [target.strip() for target in self.additional_target.split(",")] 73 | 74 | assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method." 75 | 76 | def save_to_json(self, json_path: str): 77 | r"""Saves the content of this instance in JSON format inside `json_path`.""" 78 | json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" 79 | with open(json_path, "w", encoding="utf-8") as f: 80 | f.write(json_string) 81 | 82 | @classmethod 83 | def load_from_json(cls, json_path: str): 84 | r"""Creates an instance from the content of `json_path`.""" 85 | with open(json_path, "r", encoding="utf-8") as f: 86 | text = f.read() 87 | return cls(**json.loads(text)) 88 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/hparams/general_args.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | from dataclasses import dataclass, field 3 | 4 | 5 | @dataclass 6 | class GeneralArguments: 7 | r""" 8 | Arguments pertaining to which stage we are going to perform. 
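    Supported stages: "pt" (pre-training), "sft" (supervised fine-tuning),
    "rm" (reward modeling), "ppo" and "dpo" (preference optimization).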
9 | """ 10 | stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field( 11 | default="sft", 12 | metadata={"help": "Which stage will be performed in training."} 13 | ) 14 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/hparams/generating_args.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | from dataclasses import asdict, dataclass, field 3 | 4 | 5 | @dataclass 6 | class GeneratingArguments: 7 | r""" 8 | Arguments pertaining to specify the decoding parameters. 9 | """ 10 | do_sample: Optional[bool] = field( 11 | default=True, 12 | metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} 13 | ) 14 | temperature: Optional[float] = field( 15 | default=0.95, 16 | metadata={"help": "The value used to modulate the next token probabilities."} 17 | ) 18 | top_p: Optional[float] = field( 19 | default=0.7, 20 | metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} 21 | ) 22 | top_k: Optional[int] = field( 23 | default=50, 24 | metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} 25 | ) 26 | num_beams: Optional[int] = field( 27 | default=1, 28 | metadata={"help": "Number of beams for beam search. 1 means no beam search."} 29 | ) 30 | max_length: Optional[int] = field( 31 | default=None, 32 | metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} 33 | ) 34 | max_new_tokens: Optional[int] = field( 35 | default=512, 36 | metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} 37 | ) 38 | repetition_penalty: Optional[float] = field( 39 | default=1.0, 40 | metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."} 41 | ) 42 | length_penalty: Optional[float] = field( 43 | default=1.0, 44 | metadata={"help": "Exponential penalty to the length that is used with beam-based generation."} 45 | ) 46 | 47 | def to_dict(self) -> Dict[str, Any]: 48 | args = asdict(self) 49 | if args.get("max_new_tokens", None): 50 | args.pop("max_length", None) 51 | return args 52 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/hparams/model_args.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Literal, Optional 3 | from dataclasses import dataclass, field 4 | 5 | 6 | @dataclass 7 | class ModelArguments: 8 | r""" 9 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 
10 | """ 11 | 12 | model_name_or_path: str = field( 13 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} 14 | ) 15 | soft_prompt_length: Optional[int] = field( 16 | default=32, 17 | metadata={"help": "The length of the soft prompt."} 18 | ) 19 | cache_dir: Optional[str] = field( 20 | default=None, 21 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} 22 | ) 23 | use_fast_tokenizer: Optional[bool] = field( 24 | default=True, 25 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} 26 | ) 27 | use_auth_token: Optional[bool] = field( 28 | default=False, 29 | metadata={"help": "Will use the token generated when running `huggingface-cli login`."} 30 | ) 31 | model_revision: Optional[str] = field( 32 | default="main", 33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} 34 | ) 35 | quantization_bit: Optional[int] = field( 36 | default=None, 37 | metadata={"help": "The number of bits to quantize the model."} 38 | ) 39 | quantization_type: Optional[Literal["fp4", "nf4"]] = field( 40 | default="nf4", 41 | metadata={"help": "Quantization data type to use in int4 training."} 42 | ) 43 | double_quantization: Optional[bool] = field( 44 | default=True, 45 | metadata={"help": "Whether to use double quantization in int4 training or not."} 46 | ) 47 | rope_scaling: Optional[Literal["linear", "dynamic"]] = field( 48 | default=None, 49 | metadata={"help": "Adopt scaled rotary positional embeddings."} 50 | ) 51 | flash_attn: Optional[bool] = field( 52 | default=False, 53 | metadata={"help": "Enable FlashAttention-2 for faster training."} 54 | ) 55 | shift_attn: Optional[bool] = field( 56 | default=False, 57 | metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} 58 | ) 59 | checkpoint_dir: Optional[str] = field( 60 | default=None, 61 | metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} 62 | ) 63 | reward_model: Optional[str] = field( 64 | default=None, 65 | metadata={"help": "Path to the directory containing the checkpoints of the reward model."} 66 | ) 67 | plot_loss: Optional[bool] = field( 68 | default=False, 69 | metadata={"help": "Whether to plot the training loss after fine-tuning or not."} 70 | ) 71 | hf_auth_token: Optional[str] = field( 72 | default=None, 73 | metadata={"help": "Auth token to log in with Hugging Face Hub."} 74 | ) 75 | layernorm_dtype: Optional[Literal["auto", "fp16", "bf16", "fp32"]] = field( 76 | default="auto", 77 | metadata={"help": "Data type of the layer norm weights."} 78 | ) 79 | 80 | def __post_init__(self): 81 | self.compute_dtype = None 82 | self.model_max_length = None 83 | 84 | if self.checkpoint_dir is not None: # support merging multiple lora weights 85 | self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] 86 | 87 | if self.quantization_bit is not None: 88 | assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." 
89 | 90 | if self.use_auth_token == True and self.hf_auth_token is not None: 91 | from huggingface_hub.hf_api import HfFolder # lazy load 92 | HfFolder.save_token(self.hf_auth_token) 93 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.tune import export_model, run_exp 2 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/core/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.core.parser import get_train_args, get_infer_args 2 | from llmtuner.tuner.core.loader import load_model_and_tokenizer 3 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/core/adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import TYPE_CHECKING 4 | 5 | from peft import ( 6 | PeftModel, 7 | TaskType, 8 | LoraConfig, 9 | get_peft_model 10 | ) 11 | from peft.utils import CONFIG_NAME, WEIGHTS_NAME,SAFETENSORS_WEIGHTS_NAME 12 | 13 | from llmtuner.extras.logging import get_logger 14 | from llmtuner.tuner.core.utils import find_all_linear_modules 15 | 16 | if TYPE_CHECKING: 17 | from transformers.modeling_utils import PreTrainedModel 18 | from llmtuner.hparams import ModelArguments, FinetuningArguments 19 | 20 | 21 | logger = get_logger(__name__) 22 | 23 | 24 | def init_adapter( 25 | model: "PreTrainedModel", 26 | model_args: "ModelArguments", 27 | finetuning_args: "FinetuningArguments", 28 | is_trainable: bool, 29 | is_mergeable: bool 30 | ) -> "PreTrainedModel": 31 | r""" 32 | Initializes the adapters. 33 | 34 | Support full-parameter, freeze and LoRA training. 35 | 36 | Note that the trainable parameters must be cast to float32. 
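    ("full" upcasts the whole model via `model.float()`, while "freeze" upcasts
    only the selected trainable layers; LoRA adapters are created through peft.)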
37 | """ 38 | 39 | if finetuning_args.finetuning_type == "none" and is_trainable: 40 | raise ValueError("You cannot use finetuning_type=none while training.") 41 | 42 | if finetuning_args.finetuning_type == "full" and is_trainable: 43 | logger.info("Fine-tuning method: Full") 44 | model = model.float() 45 | 46 | if finetuning_args.finetuning_type == "freeze": 47 | logger.info("Fine-tuning method: Freeze") 48 | num_layers = getattr(model.config, "num_layers") 49 | if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 50 | trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)] 51 | else: # fine-tuning the first n layers if num_layer_trainable < 0 52 | trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)] 53 | 54 | trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids] 55 | for name, param in model.named_parameters(): 56 | if not any(trainable_layer in name for trainable_layer in trainable_layers): 57 | param.requires_grad_(False) 58 | else: 59 | param.data = param.data.to(torch.float32) 60 | 61 | if finetuning_args.finetuning_type == "lora": 62 | logger.info("Fine-tuning method: LoRA") 63 | latest_checkpoint = None 64 | 65 | if model_args.checkpoint_dir is not None: 66 | lora_config = LoraConfig( 67 | task_type=TaskType.CAUSAL_LM, 68 | inference_mode=True, 69 | r=finetuning_args.lora_rank, 70 | lora_alpha=finetuning_args.lora_alpha, 71 | lora_dropout=finetuning_args.lora_dropout, 72 | target_modules = finetuning_args.lora_target, 73 | ) 74 | 75 | model = get_peft_model(model, lora_config) 76 | if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923 77 | model.base_model.peft_config = model.peft_config 78 | 79 | 80 | # assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], SAFETENSORS_WEIGHTS_NAME)), \ 81 | # "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]) 82 | # assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \ 83 | # "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead." 
84 | 85 | # if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning 86 | # checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] 87 | # else: 88 | # checkpoints_to_merge = model_args.checkpoint_dir 89 | 90 | # for checkpoint in checkpoints_to_merge: 91 | # model = PeftModel.from_pretrained(model, checkpoint) 92 | # model = model.merge_and_unload() 93 | 94 | # if len(checkpoints_to_merge) > 0: 95 | # logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) 96 | 97 | # if latest_checkpoint is not None: # resume lora training or quantized inference 98 | # model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable) 99 | 100 | if is_trainable and latest_checkpoint is None: # create new lora weights while training 101 | if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": 102 | target_modules = find_all_linear_modules(model, model_args.quantization_bit) 103 | else: 104 | target_modules = finetuning_args.lora_target 105 | 106 | lora_config = LoraConfig( 107 | task_type=TaskType.CAUSAL_LM, 108 | inference_mode=False, 109 | r=finetuning_args.lora_rank, 110 | lora_alpha=finetuning_args.lora_alpha, 111 | lora_dropout=finetuning_args.lora_dropout, 112 | target_modules=target_modules, 113 | modules_to_save=finetuning_args.additional_target 114 | ) 115 | model = get_peft_model(model, lora_config) 116 | if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923 117 | model.base_model.peft_config = model.peft_config 118 | 119 | if model_args.checkpoint_dir is not None: 120 | logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) 121 | 122 | return model 123 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/core/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import TYPE_CHECKING, List, Optional 3 | 4 | from llmtuner.extras.constants import LAYERNORM_NAMES 5 | 6 | if TYPE_CHECKING: 7 | from transformers.modeling_utils import PreTrainedModel 8 | 9 | 10 | def find_all_linear_modules( 11 | model: "PreTrainedModel", 12 | quantization_bit: Optional[int] = None, 13 | output_layer_name: Optional[str] = "lm_head" 14 | ) -> List[str]: 15 | if quantization_bit is not None: 16 | import bitsandbytes as bnb 17 | linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt 18 | else: 19 | linear_cls = torch.nn.Linear 20 | 21 | module_names = set() 22 | for name, module in model.named_modules(): 23 | if output_layer_name not in name and isinstance(module, linear_cls): 24 | module_names.add(name.split(".")[-1]) 25 | 26 | if output_layer_name in module_names: 27 | module_names.pop(output_layer_name) 28 | 29 | return list(module_names) 30 | 31 | 32 | def prepare_model_for_training( 33 | model: "PreTrainedModel", 34 | layernorm_dtype: torch.dtype, 35 | finetuning_type: str, 36 | output_layer_name: Optional[str] = "lm_head", 37 | use_gradient_checkpointing: Optional[bool] = True, 38 | layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES 39 | ) -> "PreTrainedModel": 40 | r""" 41 | Includes: 42 | (1) cast the layernorm in fp32 43 | (2) make output embedding layer require grads 44 | (3) upcast the lm_head to fp32 45 | Inspired by: 
https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33 46 | """ 47 | for name, param in model.named_parameters(): 48 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 49 | param.data = param.data.to(layernorm_dtype) 50 | 51 | if use_gradient_checkpointing: 52 | if hasattr(model, "enable_input_require_grads"): 53 | model.enable_input_require_grads() 54 | else: 55 | def make_inputs_require_grad(module, input, output): 56 | output.requires_grad_(True) 57 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 58 | 59 | model.gradient_checkpointing_enable() 60 | model.config.use_cache = False # turn off when gradient checkpointing is enabled 61 | 62 | if finetuning_type != "full" and hasattr(model, output_layer_name): 63 | output_layer: torch.nn.Linear = getattr(model, output_layer_name) 64 | input_dtype = output_layer.weight.dtype 65 | 66 | class CastOutputToFloat(torch.nn.Sequential): 67 | 68 | def forward(self, x: torch.Tensor) -> torch.Tensor: 69 | return super().forward(x.to(input_dtype)).to(torch.float32) 70 | 71 | setattr(model, output_layer_name, CastOutputToFloat(output_layer)) 72 | 73 | return model 74 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/dpo/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.dpo.workflow import run_dpo 2 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/dpo/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, List, Sequence, Tuple 4 | from transformers import DataCollatorForSeq2Seq 5 | 6 | 7 | @dataclass 8 | class DPODataCollatorWithPadding(DataCollatorForSeq2Seq): 9 | r""" 10 | Data collator for pairwise data. 11 | """ 12 | 13 | def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor: 14 | padded_labels = [] 15 | for feature, (prompt_len, answer_len) in zip(batch, positions): 16 | if self.tokenizer.padding_side == "left": 17 | start, end = feature.size(0) - answer_len, feature.size(0) 18 | else: 19 | start, end = prompt_len, prompt_len + answer_len 20 | padded_tensor = self.label_pad_token_id * torch.ones_like(feature) 21 | padded_tensor[start:end] = feature[start:end] 22 | padded_labels.append(padded_tensor) 23 | return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory 24 | 25 | def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: 26 | r""" 27 | Pads batched data to the longest sequence in the batch. 28 | 29 | We generate 2 * n examples where the first n examples represent chosen examples and 30 | the last n examples represent rejected examples. 
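        For a batch of n features the collated `input_ids` are therefore laid out as
        [prompt_1 + chosen_1, ..., prompt_n + chosen_n, prompt_1 + rejected_1, ...,
        prompt_n + rejected_n], and `labels` keep only the answer spans, with the
        prompt (and padding) positions set to `label_pad_token_id`.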
31 | """ 32 | concatenated_features = [] 33 | label_positions = [] 34 | for key in ("chosen_ids", "rejected_ids"): 35 | for feature in features: 36 | prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key]) 37 | concatenated_features.append({ 38 | "input_ids": feature["prompt_ids"] + feature[key], 39 | "attention_mask": [1] * (prompt_len + answer_len) 40 | }) 41 | label_positions.append((prompt_len, answer_len)) 42 | 43 | batch = self.tokenizer.pad( 44 | concatenated_features, 45 | padding=self.padding, 46 | max_length=self.max_length, 47 | pad_to_multiple_of=self.pad_to_multiple_of, 48 | return_tensors=self.return_tensors, 49 | ) 50 | batch["labels"] = self._pad_labels(batch["input_ids"], label_positions) 51 | return batch 52 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/dpo/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import defaultdict 3 | from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union 4 | from transformers import BatchEncoding, Trainer 5 | from trl import DPOTrainer 6 | from trl.trainer.utils import disable_dropout_in_model 7 | 8 | from llmtuner.extras.constants import IGNORE_INDEX 9 | 10 | if TYPE_CHECKING: 11 | from transformers import PreTrainedModel 12 | 13 | 14 | class CustomDPOTrainer(DPOTrainer): 15 | 16 | def __init__( 17 | self, 18 | beta: float, 19 | model: Union["PreTrainedModel", torch.nn.Module], 20 | ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None, 21 | disable_dropout: Optional[bool] = True, 22 | **kwargs 23 | ): 24 | if disable_dropout: 25 | disable_dropout_in_model(model) 26 | if ref_model is not None: 27 | disable_dropout_in_model(ref_model) 28 | 29 | self.is_encoder_decoder = model.config.is_encoder_decoder 30 | self.ref_model = ref_model 31 | self.use_dpo_data_collator = True # hack to avoid warning 32 | self.label_pad_token_id = IGNORE_INDEX 33 | self.padding_value = 0 34 | self.beta = beta 35 | self._stored_metrics = defaultdict(lambda: defaultdict(list)) 36 | 37 | Trainer.__init__(self, model=model, **kwargs) 38 | if not hasattr(self, "accelerator"): 39 | raise AttributeError("Please update `transformers`.") 40 | 41 | if ref_model is not None: 42 | if self.is_deepspeed_enabled: 43 | self.ref_model, = self.accelerator._prepare_deepspeed(self.ref_model) 44 | self.ref_model.eval() 45 | else: 46 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) 47 | 48 | def concatenated_forward( 49 | self, 50 | model: Optional[torch.nn.Module] = None, 51 | batch: Optional[Dict[str, torch.Tensor]] = None 52 | ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: 53 | batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error 54 | 55 | all_logits = model( 56 | input_ids=batch_copied["input_ids"], 57 | attention_mask=batch_copied["attention_mask"], 58 | return_dict=True 59 | ).logits.to(torch.float32) 60 | 61 | all_logps = self._get_batch_logps( 62 | all_logits, 63 | batch["labels"], 64 | average_log_prob=False 65 | ) 66 | batch_size = batch["input_ids"].size(0) // 2 67 | chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0) 68 | chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0) 69 | return chosen_logps, rejected_logps, chosen_logits, rejected_logits 70 | -------------------------------------------------------------------------------- 
/LLMs/LLaMA/src/llmtuner/tuner/dpo/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py 2 | 3 | from copy import deepcopy 4 | from peft import PeftModel 5 | from typing import TYPE_CHECKING, Optional, List 6 | from transformers import Seq2SeqTrainingArguments 7 | 8 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset 9 | from llmtuner.extras.constants import IGNORE_INDEX 10 | from llmtuner.extras.ploting import plot_loss 11 | from llmtuner.tuner.core import load_model_and_tokenizer 12 | from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding 13 | from llmtuner.tuner.dpo.trainer import CustomDPOTrainer 14 | 15 | if TYPE_CHECKING: 16 | from transformers import TrainerCallback 17 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments 18 | 19 | 20 | def run_dpo( 21 | model_args: "ModelArguments", 22 | data_args: "DataArguments", 23 | training_args: "Seq2SeqTrainingArguments", 24 | finetuning_args: "FinetuningArguments", 25 | callbacks: Optional[List["TrainerCallback"]] = None 26 | ): 27 | dataset = get_dataset(model_args, data_args) 28 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") 29 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") 30 | data_collator = DPODataCollatorWithPadding( 31 | tokenizer=tokenizer, 32 | label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 33 | ) 34 | 35 | training_args_dict = training_args.to_dict() 36 | training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset 37 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 38 | 39 | # Initialize our Trainer 40 | trainer = CustomDPOTrainer( 41 | beta=finetuning_args.dpo_beta, 42 | model=model, 43 | ref_model=deepcopy(model) if not isinstance(model, PeftModel) else None, 44 | args=training_args, 45 | tokenizer=tokenizer, 46 | data_collator=data_collator, 47 | callbacks=callbacks, 48 | **split_dataset(dataset, data_args, training_args) 49 | ) 50 | 51 | # Training 52 | if training_args.do_train: 53 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 54 | trainer.log_metrics("train", train_result.metrics) 55 | trainer.save_metrics("train", train_result.metrics) 56 | trainer.save_state() 57 | trainer.save_model() 58 | if trainer.is_world_process_zero() and model_args.plot_loss: 59 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 60 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.ppo.workflow import run_ppo 2 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/ppo/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple 3 | 4 | from llmtuner.extras.constants import LAYERNORM_NAMES 5 | 6 | if TYPE_CHECKING: 7 | from trl import AutoModelForCausalLMWithValueHead 8 | 9 | 10 | def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None: 11 | if 
target == "reward": # save default head temporarily 12 | valuehead_state_dict = model.v_head.state_dict() 13 | setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone()) 14 | setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone()) 15 | 16 | model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active 17 | model.v_head.load_state_dict({ 18 | "summary.weight": getattr(model, "{}_head_weight".format(target)), 19 | "summary.bias": getattr(model, "{}_head_bias".format(target)) 20 | }) 21 | 22 | 23 | def cast_layernorm_dtype( 24 | model: "AutoModelForCausalLMWithValueHead", 25 | compute_dtype: torch.dtype, 26 | layer_norm_params: Optional[Dict[str, torch.Tensor]] = None, 27 | layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES 28 | ) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]: 29 | 30 | layer_norm_state_dict = {} 31 | 32 | for name, param in model.named_parameters(): 33 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 34 | if layer_norm_params is None: 35 | layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability 36 | param.data = param.data.to(compute_dtype) 37 | else: 38 | param.data = layer_norm_params[name] # restore float32 weights 39 | 40 | return model, layer_norm_state_dict 41 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/ppo/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py 2 | 3 | import math 4 | from trl import PPOConfig 5 | from torch.optim import AdamW 6 | from typing import TYPE_CHECKING, Optional, List 7 | from transformers import DataCollatorWithPadding 8 | from transformers.optimization import get_scheduler 9 | 10 | from llmtuner.dsets import get_dataset, preprocess_dataset 11 | from llmtuner.extras.callbacks import SavePeftModelCallback 12 | from llmtuner.extras.ploting import plot_loss 13 | from llmtuner.tuner.core import load_model_and_tokenizer 14 | from llmtuner.tuner.ppo.trainer import CustomPPOTrainer 15 | 16 | if TYPE_CHECKING: 17 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 18 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments 19 | 20 | 21 | def run_ppo( 22 | model_args: "ModelArguments", 23 | data_args: "DataArguments", 24 | training_args: "Seq2SeqTrainingArguments", 25 | finetuning_args: "FinetuningArguments", 26 | generating_args: "GeneratingArguments", 27 | callbacks: Optional[List["TrainerCallback"]] = None 28 | ): 29 | dataset = get_dataset(model_args, data_args) 30 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo") 31 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo") 32 | 33 | tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training 34 | data_collator = DataCollatorWithPadding(tokenizer=tokenizer) 35 | 36 | ppo_config = PPOConfig( 37 | model_name=model_args.model_name_or_path, 38 | learning_rate=training_args.learning_rate, 39 | mini_batch_size=training_args.per_device_train_batch_size, 40 | batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps, 41 | 
gradient_accumulation_steps=training_args.gradient_accumulation_steps, 42 | ppo_epochs=1, 43 | max_grad_norm=training_args.max_grad_norm, 44 | seed=training_args.seed, 45 | log_with=training_args.report_to, 46 | optimize_cuda_cache=True, 47 | accelerator_kwargs={"step_scheduler_with_optimizer": False} 48 | ) 49 | 50 | if finetuning_args.ppo_score_norm: 51 | ppo_config.use_score_scaling = True 52 | ppo_config.use_score_norm = True 53 | 54 | optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate) 55 | total_train_batch_size = ( 56 | training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size 57 | ) 58 | num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size) 59 | lr_scheduler = get_scheduler( 60 | training_args.lr_scheduler_type, 61 | optimizer=optimizer, 62 | num_warmup_steps=training_args.get_warmup_steps(num_training_steps), 63 | num_training_steps=num_training_steps 64 | ) 65 | 66 | # Initialize our Trainer 67 | ppo_trainer = CustomPPOTrainer( 68 | training_args=training_args, 69 | generating_args=generating_args, 70 | callbacks=callbacks + [SavePeftModelCallback()], 71 | compute_dtype=model_args.compute_dtype, 72 | config=ppo_config, 73 | model=model, 74 | ref_model=None, 75 | tokenizer=tokenizer, 76 | dataset=dataset, 77 | data_collator=data_collator, 78 | optimizer=optimizer, 79 | lr_scheduler=lr_scheduler 80 | ) 81 | 82 | # Training 83 | if training_args.do_train: 84 | ppo_trainer.ppo_train() 85 | ppo_trainer.save_model() 86 | ppo_trainer.save_state() # must be called after save_model to have a folder 87 | if ppo_trainer.is_world_process_zero() and model_args.plot_loss: 88 | plot_loss(training_args.output_dir, keys=["loss", "reward"]) 89 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/pt/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.pt.workflow import run_pt 2 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/pt/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py 2 | 3 | import math 4 | from typing import TYPE_CHECKING, Optional, List 5 | from transformers import DataCollatorForLanguageModeling, Trainer 6 | 7 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset 8 | from llmtuner.extras.ploting import plot_loss 9 | from llmtuner.tuner.core import load_model_and_tokenizer 10 | 11 | if TYPE_CHECKING: 12 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 13 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments 14 | 15 | 16 | def run_pt( 17 | model_args: "ModelArguments", 18 | data_args: "DataArguments", 19 | training_args: "Seq2SeqTrainingArguments", 20 | finetuning_args: "FinetuningArguments", 21 | callbacks: Optional[List["TrainerCallback"]] = None 22 | ): 23 | dataset = get_dataset(model_args, data_args) 24 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt") 25 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt") 26 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 27 | 28 | # Initialize 
our Trainer 29 | trainer = Trainer( 30 | model=model, 31 | args=training_args, 32 | tokenizer=tokenizer, 33 | data_collator=data_collator, 34 | callbacks=callbacks, 35 | **split_dataset(dataset, data_args, training_args) 36 | ) 37 | 38 | # Training 39 | if training_args.do_train: 40 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 41 | trainer.log_metrics("train", train_result.metrics) 42 | trainer.save_metrics("train", train_result.metrics) 43 | trainer.save_state() 44 | trainer.save_model() 45 | if trainer.is_world_process_zero() and model_args.plot_loss: 46 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 47 | 48 | # Evaluation 49 | if training_args.do_eval: 50 | metrics = trainer.evaluate(metric_key_prefix="eval") 51 | try: 52 | perplexity = math.exp(metrics["eval_loss"]) 53 | except OverflowError: 54 | perplexity = float("inf") 55 | 56 | metrics["perplexity"] = perplexity 57 | trainer.log_metrics("eval", metrics) 58 | trainer.save_metrics("eval", metrics) 59 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/rm/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.rm.workflow import run_rm 2 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/rm/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass 3 | from typing import Any, Dict, Sequence 4 | from transformers import DataCollatorWithPadding 5 | 6 | 7 | @dataclass 8 | class PairwiseDataCollatorWithPadding(DataCollatorWithPadding): 9 | r""" 10 | Data collator for pairwise data. 11 | """ 12 | 13 | def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: 14 | r""" 15 | Pads batched data to the longest sequence in the batch. 16 | 17 | We generate 2 * n examples where the first n examples represent chosen examples and 18 | the last n examples represent rejected examples. 
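        `PairwiseTrainer.compute_loss` later splits this 2 * n batch back into its
        chosen (first half) and rejected (second half) rows using
        `batch_size = inputs["input_ids"].size(0) // 2`.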
19 | """ 20 | features = [ 21 | { 22 | "input_ids": feature["prompt_ids"] + feature[key], 23 | "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key])) 24 | } 25 | for key in ("chosen_ids", "rejected_ids") for feature in features 26 | ] 27 | return super().__call__(features) 28 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/rm/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict, Sequence, Tuple, Union 3 | 4 | 5 | def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: 6 | preds, _ = eval_preds 7 | return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])} 8 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/rm/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union 5 | from transformers import Trainer 6 | 7 | from llmtuner.extras.logging import get_logger 8 | 9 | if TYPE_CHECKING: 10 | from transformers.trainer import PredictionOutput 11 | from transformers.modeling_utils import PreTrainedModel 12 | 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | class PairwiseTrainer(Trainer): 18 | r""" 19 | Inherits PeftTrainer to compute pairwise loss. 20 | """ 21 | 22 | def __init__(self, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.can_return_loss = True # override property to return eval_loss 25 | 26 | def compute_loss( 27 | self, 28 | model: "PreTrainedModel", 29 | inputs: Dict[str, torch.Tensor], 30 | return_outputs: Optional[bool] = False 31 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: 32 | r""" 33 | Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. 34 | 35 | Subclass and override to inject custom behavior. 36 | 37 | Note that the first element will be removed from the output tuple. 38 | See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509 39 | """ 40 | # Compute rewards 41 | _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) 42 | if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2 43 | values = torch.transpose(values, 0, 1) 44 | 45 | # Split the inputs and rewards into two parts, chosen and rejected 46 | batch_size = inputs["input_ids"].size(0) // 2 47 | chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:] 48 | chosen_attn_mask, rejected_attn_mask = ( 49 | inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:] 50 | ) 51 | chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:] 52 | chosen_scores, rejected_scores = [], [] 53 | 54 | # Compute pairwise loss. 
Only backprop on the different tokens before padding 55 | # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py 56 | loss = 0 57 | for i in range(batch_size): 58 | chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1 59 | rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1 60 | check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero() 61 | 62 | if len(check_divergence) == 0: 63 | end_index = chosen_length 64 | div_index = end_index - 1 65 | else: 66 | end_index = max(chosen_length, rejected_length) 67 | div_index = check_divergence[0] 68 | 69 | assert div_index > 0 70 | chosen_trunc_rewards = chosen_rewards[i, div_index:end_index] 71 | rejected_trunc_rewards = rejected_rewards[i, div_index:end_index] 72 | if return_outputs: # use the score on the EOS token for inference 73 | chosen_scores.append(chosen_rewards[i, chosen_length-1]) 74 | rejected_scores.append(rejected_rewards[i, rejected_length-1]) 75 | loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean() 76 | 77 | loss = loss / batch_size 78 | if return_outputs: 79 | chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores) 80 | return loss, [loss, chosen_scores, rejected_scores] 81 | 82 | return loss 83 | 84 | def save_predictions( 85 | self, 86 | predict_results: "PredictionOutput" 87 | ) -> None: 88 | r""" 89 | Saves model predictions to `output_dir`. 90 | 91 | A custom behavior that not contained in Seq2SeqTrainer. 92 | """ 93 | if not self.is_world_process_zero(): 94 | return 95 | 96 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") 97 | logger.info(f"Saving prediction results to {output_prediction_file}") 98 | 99 | chosen_scores, rejected_scores = predict_results.predictions 100 | 101 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 102 | res: List[str] = [] 103 | for c_score, r_score in zip(chosen_scores, rejected_scores): 104 | res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)})) 105 | writer.write("\n".join(res)) 106 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/rm/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: 2 | # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py 3 | 4 | from typing import TYPE_CHECKING, Optional, List 5 | from transformers import Seq2SeqTrainingArguments 6 | 7 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset 8 | from llmtuner.extras.callbacks import SavePeftModelCallback 9 | from llmtuner.extras.ploting import plot_loss 10 | from llmtuner.tuner.core import load_model_and_tokenizer 11 | from llmtuner.tuner.rm.metric import compute_accuracy 12 | from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding 13 | from llmtuner.tuner.rm.trainer import PairwiseTrainer 14 | 15 | if TYPE_CHECKING: 16 | from transformers import TrainerCallback 17 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments 18 | 19 | 20 | def run_rm( 21 | model_args: "ModelArguments", 22 | data_args: "DataArguments", 23 | training_args: "Seq2SeqTrainingArguments", 24 | finetuning_args: "FinetuningArguments", 25 | callbacks: Optional[List["TrainerCallback"]] = None 26 | ): 27 | dataset = get_dataset(model_args, data_args) 28 | model, 
tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm") 29 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") 30 | data_collator = PairwiseDataCollatorWithPadding(tokenizer) 31 | 32 | training_args_dict = training_args.to_dict() 33 | training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset 34 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 35 | 36 | # Initialize our Trainer 37 | trainer = PairwiseTrainer( 38 | model=model, 39 | args=training_args, 40 | tokenizer=tokenizer, 41 | data_collator=data_collator, 42 | callbacks=callbacks + [SavePeftModelCallback()], 43 | compute_metrics=compute_accuracy, 44 | **split_dataset(dataset, data_args, training_args) 45 | ) 46 | 47 | # Training 48 | if training_args.do_train: 49 | train_result = trainer.train() 50 | trainer.log_metrics("train", train_result.metrics) 51 | trainer.save_metrics("train", train_result.metrics) 52 | trainer.save_state() 53 | trainer.save_model() 54 | if trainer.is_world_process_zero() and model_args.plot_loss: 55 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 56 | 57 | # Evaluation 58 | if training_args.do_eval: 59 | metrics = trainer.evaluate(metric_key_prefix="eval") 60 | trainer.log_metrics("eval", metrics) 61 | trainer.save_metrics("eval", metrics) 62 | 63 | # Predict 64 | if training_args.do_predict: 65 | predict_results = trainer.predict(dataset, metric_key_prefix="predict") 66 | trainer.log_metrics("predict", predict_results.metrics) 67 | trainer.save_metrics("predict", predict_results.metrics) 68 | trainer.save_predictions(predict_results) 69 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/sft/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.tuner.sft.workflow import run_sft 2 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/sft/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dataclasses import dataclass 3 | from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union 4 | 5 | import jieba 6 | from rouge_chinese import Rouge 7 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 8 | 9 | from llmtuner.extras.constants import IGNORE_INDEX 10 | 11 | if TYPE_CHECKING: 12 | from transformers.tokenization_utils import PreTrainedTokenizer 13 | 14 | 15 | @dataclass 16 | class ComputeMetrics: 17 | r""" 18 | Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. 19 | """ 20 | 21 | tokenizer: "PreTrainedTokenizer" 22 | 23 | def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: 24 | r""" 25 | Uses the model predictions to compute metrics. 
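        Returns sentence-level ROUGE-1/2/L and BLEU-4 (as percentages averaged over
        the evaluation set), with predictions and labels segmented by jieba before
        scoring.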
26 | """ 27 | preds, labels = eval_preds 28 | score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} 29 | 30 | preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) 31 | labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) 32 | 33 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) 34 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 35 | 36 | for pred, label in zip(decoded_preds, decoded_labels): 37 | hypothesis = list(jieba.cut(pred)) 38 | reference = list(jieba.cut(label)) 39 | 40 | if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: 41 | result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} 42 | else: 43 | rouge = Rouge() 44 | scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) 45 | result = scores[0] 46 | 47 | for k, v in result.items(): 48 | score_dict[k].append(round(v["f"] * 100, 4)) 49 | 50 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 51 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 52 | 53 | return {k: float(np.mean(v)) for k, v in score_dict.items()} 54 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/sft/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union 7 | from transformers import Seq2SeqTrainer 8 | 9 | from llmtuner.extras.constants import IGNORE_INDEX 10 | from llmtuner.extras.logging import get_logger 11 | 12 | if TYPE_CHECKING: 13 | from transformers.trainer import PredictionOutput 14 | 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | class CustomSeq2SeqTrainer(Seq2SeqTrainer): 20 | r""" 21 | Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE. 22 | """ 23 | 24 | def prediction_step( 25 | self, 26 | model: nn.Module, 27 | inputs: Dict[str, Union[torch.Tensor, Any]], 28 | prediction_loss_only: bool, 29 | ignore_keys: Optional[List[str]] = None, 30 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 31 | r""" 32 | Removes the prompt part in the generated tokens. 33 | 34 | Subclass and override to inject custom behavior. 35 | """ 36 | if self.args.predict_with_generate: 37 | assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." 38 | assert self.tokenizer.pad_token_id is not None, "Pad token is required." 
39 | prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) 40 | if prompt_len > label_len: 41 | inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) 42 | if label_len > prompt_len: 43 | inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"]) 44 | if "attention_mask" in inputs: 45 | inputs["attention_mask"] = self._pad_tensors_to_target_len( 46 | inputs["attention_mask"], inputs["labels"], pad_token_id=0 47 | ) 48 | if "position_ids" in inputs: 49 | inputs["position_ids"] = self._pad_tensors_to_target_len( 50 | inputs["position_ids"], inputs["labels"], pad_token_id=0 51 | ) 52 | 53 | loss, generated_tokens, labels = super().prediction_step( 54 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 55 | ) 56 | if generated_tokens is not None and self.args.predict_with_generate: 57 | generated_tokens[:, :max(prompt_len, label_len)] = self.tokenizer.pad_token_id 58 | generated_tokens = generated_tokens.contiguous() 59 | 60 | return loss, generated_tokens, labels 61 | 62 | def _pad_tensors_to_target_len( 63 | self, 64 | src_tensor: torch.Tensor, 65 | tgt_tensor: torch.Tensor, 66 | pad_token_id: Optional[int] = None 67 | ) -> torch.Tensor: 68 | r""" 69 | Pads the tensor to the same length as the target tensor. 70 | """ 71 | pad_token_id = pad_token_id if pad_token_id is not None else self.tokenizer.pad_token_id 72 | padded_tensor = pad_token_id * torch.ones_like(tgt_tensor) 73 | padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding 74 | return padded_tensor.contiguous() # in contiguous memory 75 | 76 | def save_predictions( 77 | self, 78 | predict_results: "PredictionOutput" 79 | ) -> None: 80 | r""" 81 | Saves model predictions to `output_dir`. 82 | 83 | A custom behavior that not contained in Seq2SeqTrainer. 
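        Writes one JSON object per line to `generated_predictions.jsonl` in
        `output_dir`, containing the decoded `label` and `predict` strings.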
84 | """ 85 | if not self.is_world_process_zero(): 86 | return 87 | 88 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") 89 | logger.info(f"Saving prediction results to {output_prediction_file}") 90 | 91 | preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) 92 | labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) 93 | 94 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) 95 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, clean_up_tokenization_spaces=True) 96 | 97 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 98 | res: List[str] = [] 99 | for pred, label in zip(decoded_preds, decoded_labels): 100 | res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) 101 | writer.write("\n".join(res)) 102 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/sft/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py 2 | 3 | from typing import TYPE_CHECKING, Optional, List 4 | from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments 5 | 6 | from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset 7 | from llmtuner.extras.constants import IGNORE_INDEX 8 | from llmtuner.extras.misc import get_logits_processor 9 | from llmtuner.extras.ploting import plot_loss 10 | from llmtuner.tuner.core import load_model_and_tokenizer 11 | from llmtuner.tuner.sft.metric import ComputeMetrics 12 | from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer 13 | import torch 14 | from pro_model.pm import PromptTuningModelForCausalLM 15 | 16 | 17 | if TYPE_CHECKING: 18 | from transformers import TrainerCallback 19 | from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments 20 | 21 | 22 | def run_sft( 23 | model_args: "ModelArguments", 24 | data_args: "DataArguments", 25 | training_args: "Seq2SeqTrainingArguments", 26 | finetuning_args: "FinetuningArguments", 27 | generating_args: "GeneratingArguments", 28 | callbacks: Optional[List["TrainerCallback"]] = None 29 | ): 30 | dataset = get_dataset(model_args, data_args) 31 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") 32 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft") 33 | 34 | 35 | model = PromptTuningModelForCausalLM(model, model_args.soft_prompt_length).to(dtype=torch.bfloat16) 36 | 37 | for para in model.named_parameters(): 38 | 39 | print(para[0]) 40 | 41 | 42 | if training_args.predict_with_generate: 43 | tokenizer.padding_side = "left" # use left-padding in generation 44 | 45 | data_collator = DataCollatorForSeq2Seq( 46 | tokenizer=tokenizer, 47 | label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 48 | ) 49 | 50 | # Override the decoding parameters of Seq2SeqTrainer 51 | training_args_dict = training_args.to_dict() 52 | training_args_dict.update(dict( 53 | generation_max_length=training_args.generation_max_length or data_args.cutoff_len, 54 | generation_num_beams=data_args.eval_num_beams or 
training_args.generation_num_beams 55 | )) 56 | training_args = Seq2SeqTrainingArguments(**training_args_dict) 57 | 58 | # Initialize our Trainer 59 | trainer = CustomSeq2SeqTrainer( 60 | model=model, 61 | args=training_args, 62 | tokenizer=tokenizer, 63 | data_collator=data_collator, 64 | callbacks=callbacks, 65 | compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, 66 | **split_dataset(dataset, data_args, training_args) 67 | ) 68 | 69 | # Keyword arguments for `model.generate` 70 | gen_kwargs = generating_args.to_dict() 71 | gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids 72 | gen_kwargs["pad_token_id"] = tokenizer.pad_token_id 73 | gen_kwargs["logits_processor"] = get_logits_processor() 74 | 75 | # Training 76 | if training_args.do_train: 77 | train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) 78 | trainer.log_metrics("train", train_result.metrics) 79 | trainer.save_metrics("train", train_result.metrics) 80 | trainer.save_state() 81 | trainer.save_model() 82 | if trainer.is_world_process_zero() and model_args.plot_loss: 83 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 84 | 85 | # Evaluation 86 | if training_args.do_eval: 87 | metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) 88 | if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled 89 | metrics.pop("eval_loss", None) 90 | trainer.log_metrics("eval", metrics) 91 | trainer.save_metrics("eval", metrics) 92 | 93 | # Predict 94 | if training_args.do_predict: 95 | predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) 96 | if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled 97 | predict_results.metrics.pop("predict_loss", None) 98 | trainer.log_metrics("predict", predict_results.metrics) 99 | trainer.save_metrics("predict", predict_results.metrics) 100 | trainer.save_predictions(predict_results) 101 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/tuner/tune.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Dict, List, Optional 2 | 3 | from llmtuner.extras.callbacks import LogCallback 4 | from llmtuner.extras.logging import get_logger 5 | from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer 6 | from llmtuner.tuner.pt import run_pt 7 | from llmtuner.tuner.sft import run_sft 8 | from llmtuner.tuner.rm import run_rm 9 | # from llmtuner.tuner.ppo import run_ppo 10 | # from llmtuner.tuner.dpo import run_dpo 11 | 12 | if TYPE_CHECKING: 13 | from transformers import TrainerCallback 14 | 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None): 20 | model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args) 21 | callbacks = [LogCallback()] if callbacks is None else callbacks 22 | 23 | print(data_args) 24 | 25 | if general_args.stage == "pt": 26 | run_pt(model_args, data_args, training_args, finetuning_args, callbacks) 27 | elif general_args.stage == "sft": 28 | run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) 29 | elif general_args.stage == "rm": 30 | run_rm(model_args, data_args, training_args, finetuning_args, 
callbacks) 31 | elif general_args.stage == "ppo": 32 | run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) 33 | elif general_args.stage == "dpo": 34 | run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) 35 | else: 36 | raise ValueError("Unknown task.") 37 | 38 | 39 | def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"): 40 | model_args, _, training_args, finetuning_args, _, _ = get_train_args(args) 41 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args) 42 | tokenizer.padding_side = "left" # restore padding side 43 | tokenizer.init_kwargs["padding_side"] = "left" 44 | model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size) 45 | try: 46 | tokenizer.save_pretrained(training_args.output_dir) 47 | except: 48 | logger.warning("Cannot save tokenizer, please copy the files manually.") 49 | 50 | 51 | if __name__ == "__main__": 52 | run_exp() 53 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.webui.interface import create_ui, create_web_demo 2 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Dict, List, Optional, Tuple 3 | 4 | from llmtuner.chat.stream_chat import ChatModel 5 | from llmtuner.extras.misc import torch_gc 6 | from llmtuner.hparams import GeneratingArguments 7 | from llmtuner.webui.common import get_model_path, get_save_dir 8 | from llmtuner.webui.locales import ALERTS 9 | 10 | 11 | class WebChatModel(ChatModel): 12 | 13 | def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None: 14 | if lazy_init: 15 | self.model = None 16 | self.tokenizer = None 17 | self.generating_args = GeneratingArguments() 18 | else: 19 | super().__init__(args) 20 | 21 | def load_model( 22 | self, 23 | lang: str, 24 | model_name: str, 25 | checkpoints: List[str], 26 | finetuning_type: str, 27 | quantization_bit: str, 28 | template: str, 29 | system_prompt: str 30 | ): 31 | if self.model is not None: 32 | yield ALERTS["err_exists"][lang] 33 | return 34 | 35 | if not model_name: 36 | yield ALERTS["err_no_model"][lang] 37 | return 38 | 39 | model_name_or_path = get_model_path(model_name) 40 | if not model_name_or_path: 41 | yield ALERTS["err_no_path"][lang] 42 | return 43 | 44 | if checkpoints: 45 | checkpoint_dir = ",".join( 46 | [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints] 47 | ) 48 | else: 49 | checkpoint_dir = None 50 | 51 | yield ALERTS["info_loading"][lang] 52 | args = dict( 53 | model_name_or_path=model_name_or_path, 54 | checkpoint_dir=checkpoint_dir, 55 | finetuning_type=finetuning_type, 56 | quantization_bit=int(quantization_bit) if quantization_bit and quantization_bit != "None" else None, 57 | template=template, 58 | system_prompt=system_prompt 59 | ) 60 | super().__init__(args) 61 | 62 | yield ALERTS["info_loaded"][lang] 63 | 64 | def unload_model(self, lang: str): 65 | yield ALERTS["info_unloading"][lang] 66 | self.model = None 67 | self.tokenizer = None 68 | torch_gc() 69 | yield ALERTS["info_unloaded"][lang] 70 | 71 | def predict( 72 | self, 73 | chatbot: List[Tuple[str, str]], 74 | query: 
str, 75 | history: List[Tuple[str, str]], 76 | system: str, 77 | max_new_tokens: int, 78 | top_p: float, 79 | temperature: float 80 | ): 81 | chatbot.append([query, ""]) 82 | response = "" 83 | for new_text in self.stream_chat( 84 | query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature 85 | ): 86 | response += new_text 87 | response = self.postprocess(response) 88 | new_history = history + [(query, response)] 89 | chatbot[-1] = [query, response] 90 | yield chatbot, new_history 91 | 92 | def postprocess(self, response: str) -> str: 93 | blocks = response.split("```") 94 | for i, block in enumerate(blocks): 95 | if i % 2 == 0: 96 | blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;") 97 | return "```".join(blocks) 98 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/common.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Dict, Optional 4 | 5 | import gradio as gr 6 | from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME 7 | from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME 8 | 9 | from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES 10 | 11 | 12 | DEFAULT_CACHE_DIR = "cache" 13 | DEFAULT_DATA_DIR = "data" 14 | DEFAULT_SAVE_DIR = "saves" 15 | USER_CONFIG = "user.config" 16 | DATA_CONFIG = "dataset_info.json" 17 | 18 | 19 | def get_save_dir(*args) -> os.PathLike: 20 | return os.path.join(DEFAULT_SAVE_DIR, *args) 21 | 22 | 23 | def get_config_path() -> os.PathLike: 24 | return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) 25 | 26 | 27 | def load_config() -> Dict[str, Any]: 28 | try: 29 | with open(get_config_path(), "r", encoding="utf-8") as f: 30 | return json.load(f) 31 | except: 32 | return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None} 33 | 34 | 35 | def save_config(lang: str, model_name: str, model_path: str) -> None: 36 | os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) 37 | user_config = load_config() 38 | user_config["lang"] = lang or user_config["lang"] 39 | if model_name: 40 | user_config["last_model"] = model_name 41 | user_config["path_dict"][model_name] = model_path 42 | with open(get_config_path(), "w", encoding="utf-8") as f: 43 | json.dump(user_config, f, indent=2, ensure_ascii=False) 44 | 45 | 46 | def get_model_path(model_name: str) -> str: 47 | user_config = load_config() 48 | return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, "")) 49 | 50 | 51 | def get_template(model_name: str) -> str: 52 | if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE: 53 | return DEFAULT_TEMPLATE[model_name.split("-")[0]] 54 | return "default" 55 | 56 | 57 | def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]: 58 | checkpoints = [] 59 | save_dir = get_save_dir(model_name, finetuning_type) 60 | if save_dir and os.path.isdir(save_dir): 61 | for checkpoint in os.listdir(save_dir): 62 | if ( 63 | os.path.isdir(os.path.join(save_dir, checkpoint)) 64 | and any([ 65 | os.path.isfile(os.path.join(save_dir, checkpoint, name)) 66 | for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME) 67 | ]) 68 | ): 69 | checkpoints.append(checkpoint) 70 | return gr.update(value=[], choices=checkpoints) 71 | 72 | 73 | def load_dataset_info(dataset_dir: str) -> Dict[str, Any]: 74 | try: 75 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8")
as f: 76 | return json.load(f) 77 | except: 78 | return {} 79 | 80 | 81 | def list_dataset( 82 | dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0] 83 | ) -> Dict[str, Any]: 84 | dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) 85 | ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"] 86 | datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking] 87 | return gr.update(value=[], choices=datasets) 88 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/components/__init__.py: -------------------------------------------------------------------------------- 1 | from llmtuner.webui.components.top import create_top 2 | from llmtuner.webui.components.train import create_train_tab 3 | from llmtuner.webui.components.eval import create_eval_tab 4 | from llmtuner.webui.components.infer import create_infer_tab 5 | from llmtuner.webui.components.export import create_export_tab 6 | from llmtuner.webui.components.chatbot import create_chat_box 7 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/components/chatbot.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Dict, Optional, Tuple 2 | 3 | import gradio as gr 4 | 5 | if TYPE_CHECKING: 6 | from gradio.blocks import Block 7 | from gradio.components import Component 8 | from llmtuner.webui.chat import WebChatModel 9 | 10 | 11 | def create_chat_box( 12 | chat_model: "WebChatModel", 13 | visible: Optional[bool] = False 14 | ) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]: 15 | with gr.Box(visible=visible) as chat_box: 16 | chatbot = gr.Chatbot() 17 | 18 | with gr.Row(): 19 | with gr.Column(scale=4): 20 | system = gr.Textbox(show_label=False) 21 | query = gr.Textbox(show_label=False, lines=8) 22 | submit_btn = gr.Button(variant="primary") 23 | 24 | with gr.Column(scale=1): 25 | clear_btn = gr.Button() 26 | max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1) 27 | top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01) 28 | temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01) 29 | 30 | history = gr.State([]) 31 | 32 | submit_btn.click( 33 | chat_model.predict, 34 | [chatbot, query, history, system, max_new_tokens, top_p, temperature], 35 | [chatbot, history], 36 | show_progress=True 37 | ).then( 38 | lambda: gr.update(value=""), outputs=[query] 39 | ) 40 | 41 | clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) 42 | 43 | return chat_box, chatbot, history, dict( 44 | system=system, 45 | query=query, 46 | submit_btn=submit_btn, 47 | clear_btn=clear_btn, 48 | max_new_tokens=max_new_tokens, 49 | top_p=top_p, 50 | temperature=temperature 51 | ) 52 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/components/data.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from typing import TYPE_CHECKING, Tuple 3 | 4 | if TYPE_CHECKING: 5 | from gradio.blocks import Block 6 | from gradio.components import Component 7 | 8 | 9 | def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]: 10 | with gr.Box(visible=False, elem_classes="modal-box") as 
preview_box: 11 | with gr.Row(): 12 | preview_count = gr.Number(interactive=False) 13 | 14 | with gr.Row(): 15 | preview_samples = gr.JSON(interactive=False) 16 | 17 | close_btn = gr.Button() 18 | 19 | close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False) 20 | 21 | return preview_box, preview_count, preview_samples, close_btn 22 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/components/eval.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Dict 2 | import gradio as gr 3 | 4 | from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR 5 | from llmtuner.webui.components.data import create_preview_box 6 | from llmtuner.webui.utils import can_preview, get_preview 7 | 8 | if TYPE_CHECKING: 9 | from gradio.components import Component 10 | from llmtuner.webui.runner import Runner 11 | 12 | 13 | def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]: 14 | with gr.Row(): 15 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) 16 | dataset = gr.Dropdown(multiselect=True, scale=4) 17 | data_preview_btn = gr.Button(interactive=False, scale=1) 18 | 19 | preview_box, preview_count, preview_samples, close_btn = create_preview_box() 20 | 21 | dataset_dir.change(list_dataset, [dataset_dir], [dataset]) 22 | dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) 23 | data_preview_btn.click( 24 | get_preview, 25 | [dataset_dir, dataset], 26 | [preview_count, preview_samples, preview_box], 27 | queue=False 28 | ) 29 | 30 | with gr.Row(): 31 | cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1) 32 | max_samples = gr.Textbox(value="100000") 33 | batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1) 34 | predict = gr.Checkbox(value=True) 35 | 36 | with gr.Row(): 37 | max_new_tokens = gr.Slider(10, 2048, value=128, step=1) 38 | top_p = gr.Slider(0.01, 1, value=0.7, step=0.01) 39 | temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01) 40 | 41 | with gr.Row(): 42 | cmd_preview_btn = gr.Button() 43 | start_btn = gr.Button() 44 | stop_btn = gr.Button() 45 | 46 | with gr.Row(): 47 | process_bar = gr.Slider(visible=False, interactive=False) 48 | 49 | with gr.Box(): 50 | output_box = gr.Markdown() 51 | 52 | input_components = [ 53 | top_elems["lang"], 54 | top_elems["model_name"], 55 | top_elems["checkpoints"], 56 | top_elems["finetuning_type"], 57 | top_elems["quantization_bit"], 58 | top_elems["template"], 59 | top_elems["system_prompt"], 60 | dataset_dir, 61 | dataset, 62 | cutoff_len, 63 | max_samples, 64 | batch_size, 65 | predict, 66 | max_new_tokens, 67 | top_p, 68 | temperature 69 | ] 70 | 71 | output_components = [ 72 | output_box, 73 | process_bar 74 | ] 75 | 76 | cmd_preview_btn.click(runner.preview_eval, input_components, output_components) 77 | start_btn.click(runner.run_eval, input_components, output_components) 78 | stop_btn.click(runner.set_abort, queue=False) 79 | 80 | return dict( 81 | dataset_dir=dataset_dir, 82 | dataset=dataset, 83 | data_preview_btn=data_preview_btn, 84 | preview_count=preview_count, 85 | preview_samples=preview_samples, 86 | close_btn=close_btn, 87 | cutoff_len=cutoff_len, 88 | max_samples=max_samples, 89 | batch_size=batch_size, 90 | predict=predict, 91 | max_new_tokens=max_new_tokens, 92 | top_p=top_p, 93 | temperature=temperature, 94 | cmd_preview_btn=cmd_preview_btn, 95 | start_btn=start_btn, 96 | 
stop_btn=stop_btn, 97 | output_box=output_box 98 | ) 99 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/components/export.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Dict 2 | import gradio as gr 3 | 4 | from llmtuner.webui.utils import save_model 5 | 6 | if TYPE_CHECKING: 7 | from gradio.components import Component 8 | 9 | 10 | def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]: 11 | with gr.Row(): 12 | save_dir = gr.Textbox() 13 | max_shard_size = gr.Slider(value=10, minimum=1, maximum=100) 14 | 15 | export_btn = gr.Button() 16 | info_box = gr.Textbox(show_label=False, interactive=False) 17 | 18 | export_btn.click( 19 | save_model, 20 | [ 21 | top_elems["lang"], 22 | top_elems["model_name"], 23 | top_elems["checkpoints"], 24 | top_elems["finetuning_type"], 25 | top_elems["template"], 26 | max_shard_size, 27 | save_dir 28 | ], 29 | [info_box] 30 | ) 31 | 32 | return dict( 33 | save_dir=save_dir, 34 | max_shard_size=max_shard_size, 35 | export_btn=export_btn, 36 | info_box=info_box 37 | ) 38 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/components/infer.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Dict 2 | 3 | import gradio as gr 4 | 5 | from llmtuner.webui.chat import WebChatModel 6 | from llmtuner.webui.components.chatbot import create_chat_box 7 | 8 | if TYPE_CHECKING: 9 | from gradio.components import Component 10 | 11 | 12 | def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]: 13 | with gr.Row(): 14 | load_btn = gr.Button() 15 | unload_btn = gr.Button() 16 | 17 | info_box = gr.Textbox(show_label=False, interactive=False) 18 | 19 | chat_model = WebChatModel(lazy_init=True) 20 | chat_box, chatbot, history, chat_elems = create_chat_box(chat_model) 21 | 22 | load_btn.click( 23 | chat_model.load_model, 24 | [ 25 | top_elems["lang"], 26 | top_elems["model_name"], 27 | top_elems["checkpoints"], 28 | top_elems["finetuning_type"], 29 | top_elems["quantization_bit"], 30 | top_elems["template"], 31 | top_elems["system_prompt"] 32 | ], 33 | [info_box] 34 | ).then( 35 | lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box] 36 | ) 37 | 38 | unload_btn.click( 39 | chat_model.unload_model, [top_elems["lang"]], [info_box] 40 | ).then( 41 | lambda: ([], []), outputs=[chatbot, history] 42 | ).then( 43 | lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box] 44 | ) 45 | 46 | return dict( 47 | info_box=info_box, 48 | load_btn=load_btn, 49 | unload_btn=unload_btn, 50 | **chat_elems 51 | ) 52 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/components/top.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Dict 2 | 3 | import gradio as gr 4 | 5 | from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS 6 | from llmtuner.extras.template import templates 7 | from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config 8 | from llmtuner.webui.utils import can_quantize 9 | 10 | if TYPE_CHECKING: 11 | from gradio.components import Component 12 | 13 | 14 | def create_top() -> Dict[str, "Component"]: 15 | available_models = 
list(SUPPORTED_MODELS.keys()) + ["Custom"] 16 | 17 | with gr.Row(): 18 | lang = gr.Dropdown(choices=["en", "zh"], scale=1) 19 | model_name = gr.Dropdown(choices=available_models, scale=3) 20 | model_path = gr.Textbox(scale=3) 21 | 22 | with gr.Row(): 23 | finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1) 24 | checkpoints = gr.Dropdown(multiselect=True, scale=5) 25 | refresh_btn = gr.Button(scale=1) 26 | 27 | with gr.Accordion(label="Advanced config", open=False) as advanced_tab: 28 | with gr.Row(): 29 | quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1) 30 | template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1) 31 | system_prompt = gr.Textbox(scale=2) 32 | 33 | lang.change(save_config, [lang, model_name, model_path]) 34 | 35 | model_name.change( 36 | list_checkpoint, [model_name, finetuning_type], [checkpoints] 37 | ).then( 38 | get_model_path, [model_name], [model_path] 39 | ).then( 40 | get_template, [model_name], [template] 41 | ) # do not save config since the below line will save 42 | 43 | model_path.change(save_config, [lang, model_name, model_path]) 44 | 45 | finetuning_type.change( 46 | list_checkpoint, [model_name, finetuning_type], [checkpoints] 47 | ).then( 48 | can_quantize, [finetuning_type], [quantization_bit] 49 | ) 50 | 51 | refresh_btn.click( 52 | list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False 53 | ) 54 | 55 | return dict( 56 | lang=lang, 57 | model_name=model_name, 58 | model_path=model_path, 59 | finetuning_type=finetuning_type, 60 | checkpoints=checkpoints, 61 | refresh_btn=refresh_btn, 62 | advanced_tab=advanced_tab, 63 | quantization_bit=quantization_bit, 64 | template=template, 65 | system_prompt=system_prompt 66 | ) 67 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/components/train.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Dict 2 | from transformers.trainer_utils import SchedulerType 3 | 4 | import gradio as gr 5 | 6 | from llmtuner.extras.constants import TRAINING_STAGES 7 | from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR 8 | from llmtuner.webui.components.data import create_preview_box 9 | from llmtuner.webui.utils import can_preview, get_preview, gen_plot 10 | 11 | if TYPE_CHECKING: 12 | from gradio.components import Component 13 | from llmtuner.webui.runner import Runner 14 | 15 | 16 | def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]: 17 | with gr.Row(): 18 | training_stage = gr.Dropdown( 19 | choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2 20 | ) 21 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) 22 | dataset = gr.Dropdown(multiselect=True, scale=4) 23 | data_preview_btn = gr.Button(interactive=False, scale=1) 24 | 25 | preview_box, preview_count, preview_samples, close_btn = create_preview_box() 26 | 27 | training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset]) 28 | dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset]) 29 | dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn]) 30 | data_preview_btn.click( 31 | get_preview, 32 | [dataset_dir, dataset], 33 | [preview_count, preview_samples, preview_box], 34 | queue=False 35 | ) 36 | 37 | with gr.Row(): 38 | cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, 
step=1) 39 | learning_rate = gr.Textbox(value="5e-5") 40 | num_train_epochs = gr.Textbox(value="3.0") 41 | max_samples = gr.Textbox(value="100000") 42 | compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16") 43 | 44 | with gr.Row(): 45 | batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1) 46 | gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1) 47 | lr_scheduler_type = gr.Dropdown( 48 | choices=[scheduler.value for scheduler in SchedulerType], value="cosine" 49 | ) 50 | max_grad_norm = gr.Textbox(value="1.0") 51 | val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001) 52 | 53 | with gr.Accordion(label="Advanced config", open=False) as advanced_tab: 54 | with gr.Row(): 55 | logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5) 56 | save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10) 57 | warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1) 58 | flash_attn = gr.Checkbox(value=False) 59 | rope_scaling = gr.Checkbox(value=False) 60 | 61 | with gr.Accordion(label="LoRA config", open=False) as lora_tab: 62 | with gr.Row(): 63 | lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1) 64 | lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1) 65 | lora_target = gr.Textbox(scale=2) 66 | resume_lora_training = gr.Checkbox(value=True, scale=1) 67 | 68 | with gr.Accordion(label="RLHF config", open=False) as rlhf_tab: 69 | with gr.Row(): 70 | dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2) 71 | reward_model = gr.Dropdown(scale=2) 72 | refresh_btn = gr.Button(scale=1) 73 | 74 | refresh_btn.click( 75 | list_checkpoint, 76 | [top_elems["model_name"], top_elems["finetuning_type"]], 77 | [reward_model], 78 | queue=False 79 | ) 80 | 81 | with gr.Row(): 82 | cmd_preview_btn = gr.Button() 83 | start_btn = gr.Button() 84 | stop_btn = gr.Button() 85 | 86 | with gr.Row(): 87 | with gr.Column(scale=3): 88 | with gr.Row(): 89 | output_dir = gr.Textbox() 90 | 91 | with gr.Row(): 92 | process_bar = gr.Slider(visible=False, interactive=False) 93 | 94 | with gr.Box(): 95 | output_box = gr.Markdown() 96 | 97 | with gr.Column(scale=1): 98 | loss_viewer = gr.Plot() 99 | 100 | input_components = [ 101 | top_elems["lang"], 102 | top_elems["model_name"], 103 | top_elems["checkpoints"], 104 | top_elems["finetuning_type"], 105 | top_elems["quantization_bit"], 106 | top_elems["template"], 107 | top_elems["system_prompt"], 108 | training_stage, 109 | dataset_dir, 110 | dataset, 111 | cutoff_len, 112 | learning_rate, 113 | num_train_epochs, 114 | max_samples, 115 | compute_type, 116 | batch_size, 117 | gradient_accumulation_steps, 118 | lr_scheduler_type, 119 | max_grad_norm, 120 | val_size, 121 | logging_steps, 122 | save_steps, 123 | warmup_steps, 124 | flash_attn, 125 | rope_scaling, 126 | lora_rank, 127 | lora_dropout, 128 | lora_target, 129 | resume_lora_training, 130 | dpo_beta, 131 | reward_model, 132 | output_dir 133 | ] 134 | 135 | output_components = [ 136 | output_box, 137 | process_bar 138 | ] 139 | 140 | cmd_preview_btn.click(runner.preview_train, input_components, output_components) 141 | start_btn.click(runner.run_train, input_components, output_components) 142 | stop_btn.click(runner.set_abort, queue=False) 143 | 144 | process_bar.change( 145 | gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False 146 | ) 147 | 148 | return dict( 149 | training_stage=training_stage, 150 | dataset_dir=dataset_dir, 
151 | dataset=dataset, 152 | data_preview_btn=data_preview_btn, 153 | preview_count=preview_count, 154 | preview_samples=preview_samples, 155 | close_btn=close_btn, 156 | cutoff_len=cutoff_len, 157 | learning_rate=learning_rate, 158 | num_train_epochs=num_train_epochs, 159 | max_samples=max_samples, 160 | compute_type=compute_type, 161 | batch_size=batch_size, 162 | gradient_accumulation_steps=gradient_accumulation_steps, 163 | lr_scheduler_type=lr_scheduler_type, 164 | max_grad_norm=max_grad_norm, 165 | val_size=val_size, 166 | advanced_tab=advanced_tab, 167 | logging_steps=logging_steps, 168 | save_steps=save_steps, 169 | warmup_steps=warmup_steps, 170 | flash_attn=flash_attn, 171 | rope_scaling=rope_scaling, 172 | lora_tab=lora_tab, 173 | lora_rank=lora_rank, 174 | lora_dropout=lora_dropout, 175 | lora_target=lora_target, 176 | resume_lora_training=resume_lora_training, 177 | rlhf_tab=rlhf_tab, 178 | dpo_beta=dpo_beta, 179 | reward_model=reward_model, 180 | refresh_btn=refresh_btn, 181 | cmd_preview_btn=cmd_preview_btn, 182 | start_btn=start_btn, 183 | stop_btn=stop_btn, 184 | output_dir=output_dir, 185 | output_box=output_box, 186 | loss_viewer=loss_viewer 187 | ) 188 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/css.py: -------------------------------------------------------------------------------- 1 | CSS = r""" 2 | .modal-box { 3 | position: fixed !important; 4 | top: 50%; 5 | left: 50%; 6 | transform: translate(-50%, -50%); /* center horizontally */ 7 | max-width: 1000px; 8 | max-height: 750px; 9 | overflow-y: scroll !important; 10 | background-color: var(--input-background-fill); 11 | border: 2px solid black !important; 12 | z-index: 1000; 13 | } 14 | 15 | .dark .modal-box { 16 | border: 2px solid white !important; 17 | } 18 | """ 19 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/interface.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from transformers.utils.versions import require_version 3 | 4 | from llmtuner.webui.components import ( 5 | create_top, 6 | create_train_tab, 7 | create_eval_tab, 8 | create_infer_tab, 9 | create_export_tab, 10 | create_chat_box 11 | ) 12 | from llmtuner.webui.chat import WebChatModel 13 | from llmtuner.webui.css import CSS 14 | from llmtuner.webui.manager import Manager 15 | from llmtuner.webui.runner import Runner 16 | 17 | 18 | require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0") 19 | 20 | 21 | def create_ui() -> gr.Blocks: 22 | runner = Runner() 23 | 24 | with gr.Blocks(title="Web Tuner", css=CSS) as demo: 25 | top_elems = create_top() 26 | 27 | with gr.Tab("Train"): 28 | train_elems = create_train_tab(top_elems, runner) 29 | 30 | with gr.Tab("Evaluate"): 31 | eval_elems = create_eval_tab(top_elems, runner) 32 | 33 | with gr.Tab("Chat"): 34 | infer_elems = create_infer_tab(top_elems) 35 | 36 | with gr.Tab("Export"): 37 | export_elems = create_export_tab(top_elems) 38 | 39 | elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems] 40 | manager = Manager(elem_list) 41 | 42 | demo.load( 43 | manager.gen_label, 44 | [top_elems["lang"]], 45 | [elem for elems in elem_list for elem in elems.values()], 46 | ) 47 | 48 | top_elems["lang"].change( 49 | manager.gen_label, 50 | [top_elems["lang"]], 51 | [elem for elems in elem_list for elem in elems.values()], 52 | queue=False 53 | ) 54 | 55 | return demo 56 | 57 | 58 | 
def create_web_demo() -> gr.Blocks: 59 | chat_model = WebChatModel(lazy_init=False) 60 | 61 | with gr.Blocks(title="Web Demo", css=CSS) as demo: 62 | lang = gr.Dropdown(choices=["en", "zh"], value="en") 63 | 64 | _, _, _, chat_elems = create_chat_box(chat_model, visible=True) 65 | 66 | manager = Manager([{"lang": lang}, chat_elems]) 67 | 68 | demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values())) 69 | 70 | lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False) 71 | 72 | return demo 73 | 74 | 75 | if __name__ == "__main__": 76 | demo = create_ui() 77 | demo.queue() 78 | demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True) 79 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/manager.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from gradio.components import Component 3 | from typing import Any, Dict, List 4 | 5 | from llmtuner.webui.common import get_model_path, list_dataset, load_config 6 | from llmtuner.webui.locales import LOCALES 7 | from llmtuner.webui.utils import get_time 8 | 9 | 10 | class Manager: 11 | 12 | def __init__(self, elem_list: List[Dict[str, Component]]): 13 | self.elem_list = elem_list 14 | 15 | def gen_refresh(self, lang: str) -> Dict[str, Any]: 16 | refresh_dict = { 17 | "dataset": {"choices": list_dataset()["choices"]}, 18 | "output_dir": {"value": get_time()} 19 | } 20 | 21 | user_config = load_config() 22 | if not lang: 23 | if user_config.get("lang", None): 24 | lang = user_config["lang"] 25 | else: 26 | lang = "en" 27 | 28 | refresh_dict["lang"] = {"value": lang} 29 | 30 | if user_config.get("last_model", None): 31 | refresh_dict["model_name"] = {"value": user_config["last_model"]} 32 | refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])} 33 | 34 | return refresh_dict 35 | 36 | def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING 37 | update_dict = {} 38 | refresh_dict = self.gen_refresh(lang) 39 | 40 | for elems in self.elem_list: 41 | for name, component in elems.items(): 42 | update_dict[component] = gr.update( 43 | **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {}) 44 | ) 45 | 46 | return update_dict 47 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/llmtuner/webui/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import gradio as gr 4 | import matplotlib.figure 5 | import matplotlib.pyplot as plt 6 | from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple 7 | from datetime import datetime 8 | 9 | from llmtuner.extras.ploting import smooth 10 | from llmtuner.tuner import export_model 11 | from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG 12 | from llmtuner.webui.locales import ALERTS 13 | 14 | if TYPE_CHECKING: 15 | from llmtuner.extras.callbacks import LogCallback 16 | 17 | 18 | def update_process_bar(callback: "LogCallback") -> Dict[str, Any]: 19 | if not callback.max_steps: 20 | return gr.update(visible=False) 21 | 22 | percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0 23 | label = "Running {:d}/{:d}: {} < {}".format( 24 | callback.cur_steps, 25 | callback.max_steps, 26 | callback.elapsed_time, 27 | callback.remaining_time 28 | ) 
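# Reveal the initially hidden slider and reuse it as a progress bar carrying the formatted status label.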
29 | return gr.update(label=label, value=percentage, visible=True) 30 | 31 | 32 | def get_time() -> str: 33 | return datetime.now().strftime('%Y-%m-%d-%H-%M-%S') 34 | 35 | 36 | def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]: 37 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: 38 | dataset_info = json.load(f) 39 | 40 | if ( 41 | len(dataset) > 0 42 | and "file_name" in dataset_info[dataset[0]] 43 | and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"])) 44 | ): 45 | return gr.update(interactive=True) 46 | else: 47 | return gr.update(interactive=False) 48 | 49 | 50 | def get_preview( 51 | dataset_dir: str, dataset: list, start: Optional[int] = 0, end: Optional[int] = 2 52 | ) -> Tuple[int, list, Dict[str, Any]]: 53 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: 54 | dataset_info = json.load(f) 55 | 56 | data_file: str = dataset_info[dataset[0]]["file_name"] 57 | with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f: 58 | if data_file.endswith(".json"): 59 | data = json.load(f) 60 | elif data_file.endswith(".jsonl"): 61 | data = [json.loads(line) for line in f] 62 | else: 63 | data = [line for line in f] 64 | return len(data), data[start:end], gr.update(visible=True) 65 | 66 | 67 | def can_quantize(finetuning_type: str) -> Dict[str, Any]: 68 | if finetuning_type != "lora": 69 | return gr.update(value="None", interactive=False) 70 | else: 71 | return gr.update(interactive=True) 72 | 73 | 74 | def gen_cmd(args: Dict[str, Any]) -> str: 75 | if args.get("do_train", None): 76 | args["plot_loss"] = True 77 | cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "] 78 | for k, v in args.items(): 79 | if v is not None and v != "": 80 | cmd_lines.append(" --{} {} ".format(k, str(v))) 81 | cmd_text = "\\\n".join(cmd_lines) 82 | cmd_text = "```bash\n{}\n```".format(cmd_text) 83 | return cmd_text 84 | 85 | 86 | def get_eval_results(path: os.PathLike) -> str: 87 | with open(path, "r", encoding="utf-8") as f: 88 | result = json.dumps(json.load(f), indent=4) 89 | return "```json\n{}\n```\n".format(result) 90 | 91 | 92 | def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure: 93 | log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl") 94 | if not os.path.isfile(log_file): 95 | return None 96 | 97 | plt.close("all") 98 | fig = plt.figure() 99 | ax = fig.add_subplot(111) 100 | steps, losses = [], [] 101 | with open(log_file, "r", encoding="utf-8") as f: 102 | for line in f: 103 | log_info = json.loads(line) 104 | if log_info.get("loss", None): 105 | steps.append(log_info["current_steps"]) 106 | losses.append(log_info["loss"]) 107 | 108 | if len(losses) == 0: 109 | return None 110 | 111 | ax.plot(steps, losses, alpha=0.4, label="original") 112 | ax.plot(steps, smooth(losses), label="smoothed") 113 | ax.legend() 114 | ax.set_xlabel("step") 115 | ax.set_ylabel("loss") 116 | return fig 117 | 118 | 119 | def save_model( 120 | lang: str, 121 | model_name: str, 122 | checkpoints: List[str], 123 | finetuning_type: str, 124 | template: str, 125 | max_shard_size: int, 126 | save_dir: str 127 | ) -> Generator[str, None, None]: 128 | if not model_name: 129 | yield ALERTS["err_no_model"][lang] 130 | return 131 | 132 | model_name_or_path = get_model_path(model_name) 133 | if not model_name_or_path: 134 | yield ALERTS["err_no_path"][lang] 135 | return 136 | 137 | if not checkpoints: 138 | yield 
ALERTS["err_no_checkpoint"][lang] 139 | return 140 | 141 | checkpoint_dir = ",".join( 142 | [get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints] 143 | ) 144 | 145 | if not save_dir: 146 | yield ALERTS["err_no_save_dir"][lang] 147 | return 148 | 149 | args = dict( 150 | model_name_or_path=model_name_or_path, 151 | checkpoint_dir=checkpoint_dir, 152 | finetuning_type=finetuning_type, 153 | template=template, 154 | output_dir=save_dir 155 | ) 156 | 157 | yield ALERTS["info_exporting"][lang] 158 | export_model(args, max_shard_size="{}GB".format(max_shard_size)) 159 | yield ALERTS["info_exported"][lang] 160 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/pro_model/cross_token.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | 7 | 8 | class ConsistencyModule(nn.Module): 9 | 10 | def __init__(self, embed_dim): 11 | super(ConsistencyModule, self).__init__() 12 | self.entity_cross_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4,batch_first=True) 13 | self.relation_cross_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4,batch_first=True) 14 | self.entity_self_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4,batch_first=True) 15 | self.relation_self_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=4,batch_first=True) 16 | 17 | 18 | def forward(self, entity, relation, subgraph): 19 | 20 | #bs*num*hs 21 | entity_subgraph_attn, _ = self.entity_cross_attention(entity, subgraph, subgraph) 22 | relation_subgraph_attn, _ = self.relation_cross_attention(relation, subgraph, subgraph) 23 | 24 | entity_mlp_output, _ = self.entity_self_attention(entity, entity, entity) 25 | relation_mlp_output, _ = self.relation_self_attention(relation, relation, relation) 26 | return entity_subgraph_attn, relation_subgraph_attn, entity_mlp_output, relation_mlp_output 27 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/pro_model/gate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | 6 | class MLP(nn.Module): 7 | def __init__(self, input_dim, output_dim): 8 | super(MLP, self).__init__() 9 | self.linear1 = nn.Linear(input_dim, input_dim // 2) 10 | self.gelu = nn.GELU() 11 | self.linear2 = nn.Linear( input_dim // 2, output_dim) 12 | 13 | def forward(self, x): 14 | x = self.linear1(x) 15 | x = self.gelu(x) 16 | x = self.linear2(x) 17 | return x 18 | 19 | class SiameseNetwork(nn.Module): 20 | def __init__(self, embedding_dim): 21 | super(SiameseNetwork, self).__init__() 22 | self.mlp = MLP(embedding_dim, embedding_dim) 23 | 24 | def forward(self, question, retrieval): 25 | question = self.mlp(question) 26 | retrieval = self.mlp(retrieval) 27 | simi = torch.bmm(retrieval,question.permute(0,2,1)) # B x L_retrieval x L_question 28 | simi = torch.mean(simi,dim=-1,keepdim=True) # B x L_retrieval x 1 29 | simi = torch.sigmoid(simi) 30 | return simi 31 | 32 | 33 | class GatingModule(nn.Module): 34 | def __init__(self, token_dim): 35 | super(GatingModule, self).__init__() 36 | self.dense_question_mlp = MLP(token_dim, token_dim) 37 | self.siamese_entity = SiameseNetwork(token_dim) 38 | self.siamese_relation = SiameseNetwork(token_dim) 39 | 
self.siamese_subgraph = SiameseNetwork(token_dim) 40 | 41 | #bs*length*embed 42 | def forward(self, question, entity_token, relation_token, subgraph_token): 43 | dense_question = self.dense_question_mlp(question) 44 | 45 | gate_entity = self.siamese_entity(dense_question, entity_token) 46 | gate_relation = self.siamese_relation(dense_question, relation_token) 47 | gate_subgraph = self.siamese_subgraph(dense_question, subgraph_token) 48 | 49 | weighted_entity = (gate_entity * entity_token) # B x L_retrieval x 1 and B x L_retrieval x embed 50 | weighted_relation = (gate_relation * relation_token) 51 | weighted_subgraph = (gate_subgraph * subgraph_token) 52 | 53 | combined_embedding = torch.cat([ weighted_entity, weighted_relation, weighted_subgraph], dim=1) 54 | 55 | return combined_embedding -------------------------------------------------------------------------------- /LLMs/LLaMA/src/pro_model/map_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | 5 | class Mapnet(nn.Module): 6 | def __init__(self, word_embeddings, hidden_size): 7 | super(Mapnet, self).__init__() 8 | self.word_embeddings = word_embeddings # 4* topk * 256 * 1024 9 | self.entity_map_module = nn.Sequential(OrderedDict([ 10 | ("linear1", nn.Linear(hidden_size, hidden_size // 16)), 11 | ("gelu", nn.GELU()), 12 | ("linear2", nn.Linear(hidden_size // 16, hidden_size)) 13 | ])) 14 | self.relation_map_module = nn.Sequential(OrderedDict([ 15 | ("linear1", nn.Linear(hidden_size, hidden_size // 16)), 16 | ("gelu", nn.GELU()), 17 | ("linear2", nn.Linear(hidden_size // 16, hidden_size)) 18 | ])) 19 | self.subgraph_map_module = nn.Sequential(OrderedDict([ 20 | ("linear1", nn.Linear(hidden_size, hidden_size // 16)), 21 | ("gelu", nn.GELU()), 22 | ("linear2", nn.Linear(hidden_size // 16, hidden_size)) 23 | ])) 24 | 25 | 26 | 27 | 28 | def forward(self, entitys, relations, subgraphs): 29 | entity_embs = self.word_embeddings(entitys) 30 | relation_embs = self.word_embeddings(relations) 31 | subgraph_embs = self.word_embeddings(subgraphs) 32 | 33 | 34 | mean_entity_embs = torch.mean(entity_embs, dim=2) 35 | entity_vector = self.entity_map_module(mean_entity_embs) # 4* topk * 1024 36 | 37 | mean_relation_embs = torch.mean(relation_embs, dim=2) 38 | relation_vector = self.relation_map_module(mean_relation_embs) 39 | 40 | mean_subgraph_embs = torch.mean(subgraph_embs, dim=2) 41 | subgraph_vector = self.subgraph_map_module(mean_subgraph_embs) 42 | 43 | return entity_vector, relation_vector, subgraph_vector 44 | 45 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/pro_model/pm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import inspect 3 | from peft.utils import _get_batch_size 4 | from .map_layer import Mapnet 5 | from .cross_token import ConsistencyModule 6 | from .gate import GatingModule 7 | import torch.nn as nn 8 | import warnings 9 | 10 | class PromptTuningModelForCausalLM(nn.Module): 11 | 12 | def __init__( 13 | self, model: torch.nn.Module, soft_prompt_length, **kwargs 14 | ) -> None: 15 | super().__init__() 16 | self.model = model 17 | self.config = self.model.config 18 | self.soft_token_num = soft_prompt_length 19 | self.model_prepare_inputs_for_generation = self.model.base_model.model.prepare_inputs_for_generation 20 | self.mapnet = Mapnet(self.model.base_model.model.model.embed_tokens, 
self.config.hidden_size) 21 | self.consistnet = ConsistencyModule(self.config.hidden_size) 22 | self.gatenet= GatingModule(self.config.hidden_size) 23 | self.prefix_encoder = self.create_prefix_encoder(self.soft_token_num) 24 | self.prefix_tokens = torch.arange(self.soft_token_num).to(torch.int64) 25 | 26 | 27 | 28 | def create_prefix_encoder(self, num_prefix_tokens): 29 | prefix_encoder = nn.Embedding(num_prefix_tokens, self.config.hidden_size) 30 | # prefix_encoder = nn.Sequential(prefix_embedding) 31 | return prefix_encoder 32 | 33 | 34 | def get_soft_prompts(self,batch_size): 35 | prefix_encoder = self.prefix_encoder 36 | table_prompt_tokens = ( 37 | self.prefix_tokens 38 | .unsqueeze(0) 39 | .expand(batch_size, -1) 40 | ).to(prefix_encoder.weight.device) 41 | soft_prompts = prefix_encoder(table_prompt_tokens) 42 | return soft_prompts 43 | 44 | 45 | def forward( 46 | self, 47 | input_ids=None, 48 | gate_ids=None, 49 | entitys=None, 50 | relations=None, 51 | subgraphs=None, 52 | attention_mask=None, 53 | inputs_embeds=None, 54 | labels=None, 55 | output_attentions=None, 56 | output_hidden_states=None, 57 | return_dict=None, 58 | **kwargs, 59 | ): 60 | 61 | 62 | batch_size = _get_batch_size(input_ids, inputs_embeds) 63 | 64 | 65 | # 1. prepare input_embeds 66 | if inputs_embeds is None: 67 | inputs_embeds = self.model.base_model.model.model.embed_tokens(input_ids) 68 | gate_embeds = self.model.base_model.model.model.embed_tokens(gate_ids) 69 | 70 | # 2. 71 | entity_vector, relation_vector, subgraph_vector = self.mapnet(entitys, relations, subgraphs) 72 | # 3. 73 | es_consistency_token, rs_consistency_token, e_consistency_token, r_consistency_token = self.consistnet(entity_vector, relation_vector, subgraph_vector) 74 | #4. 75 | combined_embedding = self.gatenet(gate_embeds, e_consistency_token, r_consistency_token, es_consistency_token+rs_consistency_token).to(inputs_embeds.dtype) 76 | #5. 77 | soft_prompts = self.get_soft_prompts(batch_size).to(inputs_embeds.dtype) 78 | 79 | all_prefix = torch.cat((soft_prompts,combined_embedding), dim=1) 80 | prefix_length = all_prefix.shape[1] 81 | 82 | if attention_mask is not None: 83 | # concat prompt attention mask 84 | prefix_attention_mask = torch.ones(batch_size, prefix_length).to(attention_mask.device) 85 | attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) 86 | 87 | if kwargs.get("position_ids", None) is not None: 88 | warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.") 89 | kwargs["position_ids"] = None 90 | if kwargs.get("token_type_ids", None) is not None: 91 | warnings.warn("Token type ids are not supported for parameter efficient tuning. 
Ignoring token type ids") 92 | kwargs["token_type_ids"] = None 93 | kwargs.update( 94 | { 95 | "attention_mask": attention_mask, 96 | "labels": labels, 97 | "output_attentions": output_attentions, 98 | "output_hidden_states": output_hidden_states, 99 | "return_dict": return_dict, 100 | } 101 | ) 102 | # # concat prompt labels 103 | if labels is not None: 104 | prefix_labels = torch.full((batch_size, prefix_length), -100).to(labels.device) 105 | kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1) 106 | 107 | 108 | 109 | inputs_embeds = torch.cat((all_prefix, inputs_embeds), dim=1) 110 | return self.model(inputs_embeds=inputs_embeds, **kwargs) 111 | 112 | 113 | def inference(self, **kwargs): 114 | 115 | self.num_beams = kwargs['generation_config'].num_beams 116 | 117 | self.gate_ids = kwargs['gate_ids'].repeat(self.num_beams, 1) 118 | self.entitys = kwargs['entitys'].repeat(self.num_beams, 1, 1) 119 | self.relations = kwargs['relations'].repeat(self.num_beams, 1, 1) 120 | self.subgraphs = kwargs['subgraphs'].repeat(self.num_beams, 1, 1) 121 | 122 | del kwargs["gate_ids"] 123 | del kwargs["entitys"] 124 | del kwargs["relations"] 125 | del kwargs["subgraphs"] 126 | 127 | outputs = self.generate( 128 | **kwargs 129 | ) 130 | 131 | return outputs 132 | 133 | 134 | 135 | 136 | 137 | def generate(self, **kwargs): 138 | self.model.base_model.model.prepare_inputs_for_generation = self.prepare_inputs_for_generation 139 | 140 | 141 | self.model.base_model.generation_config = kwargs['generation_config'] 142 | self.model.base_model.model.generation_config = kwargs['generation_config'] 143 | try: 144 | 145 | outputs = self.model.generate(**kwargs) 146 | except: 147 | self.model.base_model.model.prepare_inputs_for_generation = self.model_prepare_inputs_for_generation 148 | raise 149 | else: 150 | self.model.base_model.model.prepare_inputs_for_generation = self.model_prepare_inputs_for_generation 151 | return outputs 152 | 153 | 154 | 155 | def prepare_inputs_for_generation(self, *args, **kwargs): 156 | model_kwargs = self.model_prepare_inputs_for_generation(*args, **kwargs) 157 | batch_size = model_kwargs["input_ids"].shape[0] 158 | 159 | if model_kwargs.get("position_ids", None) is not None: 160 | warnings.warn("Position ids are not supported for parameter efficient tuning. 
Ignoring position ids.") 161 | model_kwargs["position_ids"] = None 162 | 163 | 164 | 165 | if model_kwargs["past_key_values"] is None: 166 | batch_size=model_kwargs["input_ids"].shape[0] 167 | self.device = model_kwargs["input_ids"].device 168 | self.batch_size = batch_size 169 | inputs_embeds = self.model.base_model.model.model.embed_tokens(model_kwargs["input_ids"]) 170 | gate_embeds = self.model.base_model.model.model.embed_tokens(self.gate_ids) 171 | 172 | entity_vector, relation_vector, subgraph_vector = self.mapnet(self.entitys, self.relations, self.subgraphs) 173 | es_consistency_token, rs_consistency_token, e_consistency_token, r_consistency_token = self.consistnet(entity_vector, relation_vector, subgraph_vector) 174 | combined_embedding = self.gatenet(gate_embeds, e_consistency_token, r_consistency_token, es_consistency_token+rs_consistency_token).to(inputs_embeds.dtype) 175 | soft_prompts = self.get_soft_prompts(batch_size).to(inputs_embeds.dtype) 176 | 177 | all_prefix = torch.cat((soft_prompts,combined_embedding), dim=1) 178 | self.prefix_length = all_prefix.shape[1] 179 | inputs_embeds = torch.cat((all_prefix, inputs_embeds), dim=1) 180 | model_kwargs["inputs_embeds"] = inputs_embeds 181 | model_kwargs["input_ids"] = None 182 | 183 | 184 | if model_kwargs.get("attention_mask", None) is not None: 185 | 186 | prefix_attention_mask = torch.ones( 187 | self.batch_size, self.prefix_length 188 | ).to(self.device) 189 | model_kwargs["attention_mask"] = torch.cat( 190 | (prefix_attention_mask, model_kwargs["attention_mask"]), dim=1 191 | ) 192 | 193 | 194 | 195 | 196 | return model_kwargs 197 | -------------------------------------------------------------------------------- /LLMs/LLaMA/src/pro_model/totoken.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import json 3 | import math 4 | def process_ent_list(ent_list, data_args): 5 | if len(ent_list) >= data_args.topk: 6 | return ent_list[:data_args.topk] 7 | else: 8 | new_ent_list = ent_list * math.ceil(data_args.topk / len(ent_list)) 9 | return new_ent_list[:data_args.topk] 10 | 11 | def data_load_retrieval(data_args): 12 | 13 | if 'WebQSP' in data_args.dataset: 14 | dataset = 'WebQSP' 15 | splits = ['train','test'] 16 | else: 17 | dataset = 'CWQ' 18 | splits = ['train','test','dev'] 19 | id2rel = {} 20 | for split in splits: 21 | with open(f'data/retrieval_data/{dataset}_{split}_cand_rels_sorted.json','r') as f: 22 | id2rel.update(json.load(f)) 23 | id2rel = {Id:rel[:data_args.topk] for Id, rel in id2rel.items()} 24 | 25 | id2ent_ = {} 26 | for split in splits: 27 | with open(f'data/retrieval_data/{dataset}_{split}_merged_cand_entities_elq_facc1.json','r') as f: 28 | id2ent_.update(json.load(f)) 29 | id2ent = {} 30 | for Id, ent in id2ent_.items(): 31 | ent = [en['label'] for en in process_ent_list(ent, data_args)] 32 | id2ent[Id] = ent 33 | 34 | subgraphs = [] 35 | for split in splits: 36 | with open(f'data/retrieval_data/{dataset}_{split}_subgraph_BM25.json','r') as f: 37 | subgraphs.extend(json.load(f)) 38 | id2subgraph = {} 39 | for subg in subgraphs: 40 | id2subgraph[subg['QuestionId']] = [ctx['text'] for ctx in subg['ctxs'][:data_args.topk]] 41 | 42 | return id2rel, id2ent, id2subgraph 43 | 44 | def get_ids(tokenizer, sequence, max_length): 45 | sequence = tokenizer.encode(sequence) 46 | sequence = [tokenizer.bos_token_id] + sequence + [tokenizer.eos_token_id] 47 | if len(sequence) < max_length: 48 | sequence = sequence + [tokenizer.pad_token_id] * (max_length 
- len(sequence)) 49 | return sequence[:max_length] 50 | 51 | def get_extra_input_ids(tokenizer, 52 | entitys: List[str], 53 | relations: List[str], 54 | subgraphs: List[str], 55 | max_length: int 56 | ): 57 | 58 | entity_ids = [get_ids(tokenizer, entity, max_length) for entity in entitys] 59 | relation_ids = [get_ids(tokenizer, relation, max_length) for relation in relations] 60 | subgraph_ids = [get_ids(tokenizer, subgraph, max_length) for subgraph in subgraphs] 61 | 62 | 63 | return entity_ids, relation_ids, subgraph_ids -------------------------------------------------------------------------------- /LLMs/LLaMA/src/train_bash.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["WANDB_DISABLED"] = "true" 3 | 4 | from llmtuner import run_exp 5 | 6 | 7 | def main(): 8 | run_exp() 9 | 10 | 11 | def _mp_fn(index): 12 | # For xla_spawn (TPUs) 13 | main() 14 | 15 | 16 | if __name__ == "__main__": 17 | main() 18 | -------------------------------------------------------------------------------- /LLMs/data_id/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "WebQSP_Freebase_NQ_train": { 3 | "script_url": "WebQSP_Freebase_NQ_train", 4 | "columns": { 5 | "prompt": "instruction", 6 | "query": "input", 7 | "response": "output", 8 | "history": "history" 9 | } 10 | }, 11 | "WebQSP_Freebase_NQ_test": { 12 | "script_url": "WebQSP_Freebase_NQ_test", 13 | "columns": { 14 | "prompt": "instruction", 15 | "query": "input", 16 | "response": "output", 17 | "history": "history" 18 | } 19 | }, 20 | "CWQ_Freebase_NQ_train": { 21 | "script_url": "CWQ_Freebase_NQ_train", 22 | "columns": { 23 | "prompt": "instruction", 24 | "query": "input", 25 | "response": "output", 26 | "history": "history" 27 | } 28 | }, 29 | "CWQ_Freebase_NQ_test": { 30 | "script_url": "CWQ_Freebase_NQ_test", 31 | "columns": { 32 | "prompt": "instruction", 33 | "query": "input", 34 | "response": "output", 35 | "history": "history" 36 | } 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AMAR 2 | 3 | This is the code for the paper 'Harnessing Large Language Models for Knowledge Graph Question Answering via Adaptive Multi-Aspect Retrieval-Augmentation' (accepted by AAAI 2025). 4 | 5 | ## Environment Setup 6 | conda create -n amar python=3.9.18 7 | pip install -r requirements.txt 8 | 9 | ### Freebase Setup 10 | Both datasets use Freebase as the knowledge source. You may refer to [Freebase Virtuoso Setup](https://github.com/dki-lab/Freebase-Setup) to set up a Virtuoso triplestore service. 
We briefly list some key steps below: 11 | 12 | 13 | Download OpenLink Virtuoso from https://github.com/openlink/virtuoso-opensource/releases and put it in `Amar/`. 14 | 15 | Install the ODBC dependencies: 16 | 17 | sudo apt install unixodbc unixodbc-dev 18 | 19 | Download the database: 20 | 21 | cd Freebase-Setup 22 | wget https://www.dropbox.com/s/q38g0fwx1a3lz8q/virtuoso_db.zip 23 | tar -zxvf virtuoso_db.zip 24 | 25 | To start the Virtuoso service: 26 | 27 | python3 virtuoso.py start 3001 -d virtuoso_db 28 | 29 | To stop a currently running service on the same port: 30 | 31 | python3 virtuoso.py stop 3001 32 | 33 | 34 | ## KGQA Datasets and Retrieval Data 35 | 36 | Download the data from [Google Drive](https://drive.google.com/drive/folders/1uOcpPoBcFeL2JWE6-Wpj6kuR-7Vdgiyz?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/12w4bCqFhDKp6iPVW6i2cFw?pwd=p9da), and unzip `data.zip` to `Amar/data`. 37 | 38 | 39 | More details on entity/relation retrieval can be found in [GMT-KBQA](https://github.com/HXX97/GMT-KBQA), and on subgraph retrieval in [DECAF](https://github.com/awslabs/decode-answer-logical-form). 40 | 41 | ## Reproduction from SFT checkpoint 42 | Change the `--model_name_or_path` in `run_ft.sh` to your LLM checkpoint path. 43 | 44 | Download the checkpoints from [Google Drive](https://drive.google.com/drive/folders/1uOcpPoBcFeL2JWE6-Wpj6kuR-7Vdgiyz?usp=sharing). The checkpoint files are located in `WebQSP_webqsp_100_7_32_16/checkpoint/`. 45 | 46 | We also include the inference results (`evaluation_beam/beam_test_gen_statistics.json` and `evaluation_beam/beam_test_top_k_predictions.json`) and the results of querying Freebase (all files with the prefix `beam_test_top_k_predictions.json_*`). 47 | 48 | ## Reproduction from scratch 49 | 50 | 51 | Change the `--model_name_or_path` in `run_ft.sh` to your LLM checkpoint path. 52 | 53 | Reproduce the results for CWQ and WebQSP by executing the following: 54 | 55 | bash run_all.sh 56 | 57 | Alternatively, you can run the commands step by step as shown below: 58 | ### Finetuning 59 | CUDA_VISIBLE_DEVICES=0 bash run_ft.sh WebQSP LLaMA-2-7b-hf webqsp_100_7_32_16 train 100 7 0 32 16 15 60 | CUDA_VISIBLE_DEVICES=0 bash run_ft.sh CWQ LLaMA-2-13b-hf cwq_4_16_32_16 train 4 16 0 32 16 8 61 | 62 | ### Inference 63 | CUDA_VISIBLE_DEVICES=0 bash run_ft.sh WebQSP LLaMA-2-7b-hf webqsp_100_7_32_16 test 100 7 0 32 16 15 64 | CUDA_VISIBLE_DEVICES=0 bash run_ft.sh CWQ LLaMA-2-13b-hf cwq_4_16_32_16 test 4 16 0 32 16 8 65 | 66 | ### Querying on Freebase 67 | CUDA_VISIBLE_DEVICES=0 python -u eval_final.py --dataset WebQSP --pred_file Reading/LLaMA-2-7b-hf/WebQSP_webqsp_100_7_32_16/evaluation_beam/beam_test_top_k_predictions.json 68 | CUDA_VISIBLE_DEVICES=0 python -u eval_final.py --dataset CWQ --pred_file Reading/LLaMA-2-13b-hf/CWQ_cwq_4_16_32_16/evaluation_beam/beam_test_top_k_predictions.json 69 | 70 | Querying with golden entities: 71 | 72 | CUDA_VISIBLE_DEVICES=0 python -u eval_final.py --dataset WebQSP --pred_file Reading/LLaMA-2-7b-hf/WebQSP_webqsp_100_7_32_16/evaluation_beam/beam_test_top_k_predictions.json --golden_ent 73 | CUDA_VISIBLE_DEVICES=0 python -u eval_final.py --dataset CWQ --pred_file Reading/LLaMA-2-13b-hf/CWQ_cwq_4_16_32_16/evaluation_beam/beam_test_top_k_predictions.json --golden_ent 74 | 75 | 76 | This repo builds on [ChatKBQA](https://github.com/LHRLAB/ChatKBQA), [GMT-KBQA](https://github.com/HXX97/GMT-KBQA) and [DECAF](https://github.com/awslabs/decode-answer-logical-form). Thanks for their great work! 
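For reference, `process_NQ.py` builds the `{dataset}_Freebase_NQ_{split}` examples in the instruction-tuning format declared in `LLMs/data_id/dataset_info.json` (prompt, query, response and history columns). A record looks roughly like the sketch below; the question and logical form are fabricated placeholders rather than entries copied from the released data:

    {
        "instruction": "Generate a Logical Form query that retrieves the information corresponding to the given question. \n",
        "input": "Question: { what country is the grand bahama island in }",
        "output": "( JOIN ( R [ location , location , containedby ] ) [ Grand Bahama ] )",
        "history": []
    }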
77 | -------------------------------------------------------------------------------- /components/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | class ListDataset(Dataset): 4 | def __init__(self, examples): 5 | self.examples = examples 6 | 7 | def __len__(self): 8 | return len(self.examples) 9 | 10 | def __getitem__(self, i): 11 | return self.examples[i] 12 | 13 | def __iter__(self): 14 | return iter(self.examples) 15 | 16 | class LFCandidate: 17 | def __init__(self, s_expr, normed_expr, ex=None, f1=None, edist=None): 18 | self.s_expr = s_expr 19 | self.normed_expr = normed_expr 20 | self.ex = ex 21 | self.f1 = f1 22 | self.edist = edist 23 | 24 | def __str__(self): 25 | return '{}\n\t->{}\n'.format(self.s_expr, self.normed_expr) 26 | 27 | def __repr__(self): 28 | return self.__str__() 29 | -------------------------------------------------------------------------------- /components/utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | import os 4 | import shutil 5 | import re 6 | from typing import List 7 | from executor.sparql_executor import get_label_with_odbc 8 | 9 | 10 | def dump_to_bin(obj, fname): 11 | with open(fname, "wb") as f: 12 | pickle.dump(obj, f) 13 | 14 | 15 | def load_bin(fname): 16 | with open(fname, "rb") as f: 17 | return pickle.load(f) 18 | 19 | 20 | def load_json(fname, mode="r", encoding="utf8"): 21 | if "b" in mode: 22 | encoding = None 23 | with open(fname, mode=mode, encoding=encoding) as f: 24 | return json.load(f) 25 | 26 | 27 | def dump_json(obj, fname, indent=4, mode='w' ,encoding="utf8", ensure_ascii=False): 28 | if "b" in mode: 29 | encoding = None 30 | with open(fname, "w", encoding=encoding) as f: 31 | return json.dump(obj, f, indent=indent, ensure_ascii=ensure_ascii) 32 | 33 | 34 | def mkdir_f(prefix): 35 | if os.path.exists(prefix): 36 | shutil.rmtree(prefix) 37 | os.makedirs(prefix) 38 | 39 | 40 | def mkdir_p(prefix): 41 | if not os.path.exists(prefix): 42 | os.makedirs(prefix) 43 | 44 | 45 | illegal_xml_re = re.compile(u'[\x00-\x08\x0b-\x1f\x7f-\x84\x86-\x9f\ud800-\udfff\ufdd0-\ufddf\ufffe-\uffff]') 46 | def clean_str(s: str) -> str: 47 | """remove illegal unicode characters""" 48 | return illegal_xml_re.sub('',s) 49 | 50 | 51 | 52 | def tokenize_s_expr(expr): 53 | expr = expr.replace('(', ' ( ') 54 | expr = expr.replace(')', ' ) ') 55 | toks = expr.split(' ') 56 | toks = [x for x in toks if len(x)] 57 | return toks 58 | 59 | def extract_mentioned_entities_from_sexpr(expr:str) -> List[str]: 60 | expr = expr.replace('(', ' ( ') 61 | expr = expr.replace(')', ' ) ') 62 | toks = expr.split(' ') 63 | toks = [x for x in toks if len(x)] 64 | entitiy_tokens = [] 65 | for t in toks: 66 | # normalize entity 67 | if t.startswith('m.') or t.startswith('g.'): 68 | entitiy_tokens.append(t) 69 | return entitiy_tokens 70 | 71 | def extract_mentioned_entities_from_sparql(sparql:str) -> List[str]: 72 | """extract entity from sparql""" 73 | sparql = sparql.replace('(',' ( ').replace(')',' ) ') 74 | toks = sparql.split(' ') 75 | toks = [x.replace('\t.','') for x in toks if len(x)] 76 | entity_tokens = [] 77 | for t in toks: 78 | if t.startswith('ns:m.') or t.startswith('ns:g.'): 79 | entity_tokens.append(t[3:]) 80 | 81 | entity_tokens = list(set(entity_tokens)) 82 | return entity_tokens 83 | 84 | def extract_mentioned_relations_from_sparql(sparql:str): 85 | """extract relation from sparql""" 86 | 
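    # Freebase relations appear in these SPARQL queries as "ns:"-prefixed dotted identifiers with
    # three or four segments (e.g. ns:people.person.place_of_birth); the regexes below collect such
    # tokens and strip the leading "ns:" before deduplicating them.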
sparql = sparql.replace('(',' ( ').replace(')',' ) ') 87 | toks = sparql.split(' ') 88 | toks = [x for x in toks if len(x)] 89 | relation_tokens = [] 90 | for t in toks: 91 | if (re.match("ns:[a-zA-Z_0-9]*\.[a-zA-Z_0-9]*\.[a-zA-Z_0-9]*",t.strip()) 92 | or re.match("ns:[a-zA-Z_0-9]*\.[a-zA-Z_0-9]*\.[a-zA-Z_0-9]*\.[a-zA-Z_0-9]*",t.strip())): 93 | relation_tokens.append(t[3:]) 94 | 95 | relation_tokens = list(set(relation_tokens)) 96 | return relation_tokens 97 | 98 | 99 | def extract_mentioned_relations_from_sexpr(sexpr:str)->List[str]: 100 | sexpr = sexpr.replace('(',' ( ').replace(')',' ) ') 101 | toks = sexpr.split(' ') 102 | toks = [x for x in toks if len(x)] 103 | relation_tokens = [] 104 | 105 | for t in toks: 106 | if (re.match("[a-zA-Z_]*\.[a-zA-Z_]*\.[a-zA-z_]*",t.strip()) 107 | or re.match("[a-zA-Z_]*\.[a-zA-Z_]*\.[a-zA-Z_]*\.[a-zA-Z_]*",t.strip())): 108 | relation_tokens.append(t) 109 | relation_tokens = list(set(relation_tokens)) 110 | return relation_tokens 111 | 112 | def vanilla_sexpr_linearization_method(expr, entity_label_map={}, relation_label_map={}, linear_origin_map={}): 113 | """ 114 | textualize a logical form, replace mids with labels 115 | 116 | Returns: 117 | (str): normalized s_expr 118 | """ 119 | expr = expr.replace("(", " ( ") # add space for parantheses 120 | expr = expr.replace(")", " ) ") 121 | toks = expr.split(" ") # split by space 122 | toks = [x for x in toks if len(x)] 123 | 124 | norm_toks = [] 125 | for t in toks: 126 | 127 | # original token 128 | origin_t = t 129 | 130 | if t.startswith("m.") or t.startswith("g."): # replace entity with its name 131 | if t in entity_label_map: 132 | t = entity_label_map[t] 133 | else: 134 | # name = get_label(t) 135 | name = get_label_with_odbc(t) 136 | if name is not None: 137 | entity_label_map[t] = name 138 | t = name 139 | t = '[ '+t+' ]' 140 | elif "XMLSchema" in t: # remove xml type 141 | format_pos = t.find("^^") 142 | t = t[:format_pos] 143 | elif t == "ge": # replace ge/gt/le/lt 144 | t = "GREATER EQUAL" 145 | elif t == "gt": 146 | t = "GREATER THAN" 147 | elif t == "le": 148 | t = "LESS EQUAL" 149 | elif t == "lt": 150 | t = "LESS THAN" 151 | else: 152 | t = t.replace("_", " ") # replace "_" with " " 153 | t = t.replace(".", " , ") # replace "." with " , " 154 | 155 | if "." in origin_t: # relation 156 | t = "[ "+t+" ]" 157 | relation_label_map[origin_t]=t 158 | 159 | norm_toks.append(t) 160 | linear_origin_map[t] = origin_t # for reverse transduction 161 | 162 | return " ".join(norm_toks) 163 | 164 | def _textualize_relation(r): 165 | """return a relation string with '_' and '.' replaced""" 166 | if "_" in r: # replace "_" with " " 167 | r = r.replace("_", " ") 168 | if "." in r: # replace "." 
with " , " 169 | r = r.replace(".", " , ") 170 | return r -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from transformers import ( 7 | AutoTokenizer, 8 | AutoConfig, 9 | ) 10 | 11 | ELQ_SERVICE_URL = "http://localhost:5688/entity_linking" 12 | FREEBASE_SPARQL_WRAPPER_URL = "http://localhost:8890/sparql" 13 | FREEBASE_ODBC_PORT = "13001" 14 | 15 | def set_seed(args): 16 | random.seed(args.seed) 17 | np.random.seed(args.seed) 18 | torch.manual_seed(args.seed) 19 | if args.n_gpu > 0: 20 | torch.cuda.manual_seed_all(args.seed) 21 | 22 | 23 | def to_list(tensor): 24 | return tensor.detach().cpu().tolist() 25 | 26 | def register_args(parser): 27 | # Required parameters 28 | parser.add_argument( 29 | "--dataset", 30 | default=None, 31 | type=str, 32 | required=True, 33 | help="dataset to operate on", 34 | ) 35 | parser.add_argument( 36 | "--model_type", 37 | default=None, 38 | type=str, 39 | required=True, 40 | help="Model type", 41 | ) 42 | parser.add_argument( 43 | "--model_name_or_path", 44 | default=None, 45 | type=str, 46 | required=True, 47 | help="Path to pretrained model or model identifier from huggingface.co/models", 48 | ) 49 | parser.add_argument( 50 | "--output_dir", 51 | default=None, 52 | type=str, 53 | required=True, 54 | help="The output directory where the model checkpoints and predictions will be written.", 55 | ) 56 | 57 | # Other parameters 58 | parser.add_argument( 59 | "--data_dir", 60 | default=None, 61 | type=str, 62 | help="The input data dir. Should contain the .json files for the task." 63 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 64 | ) 65 | parser.add_argument( 66 | "--train_file", 67 | default=None, 68 | type=str, 69 | help="The input training file. If a data dir is specified, will look for the file there" 70 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 71 | ) 72 | parser.add_argument( 73 | "--predict_file", 74 | default=None, 75 | type=str, 76 | help="The input evaluation file. If a data dir is specified, will look for the file there" 77 | + "If no data dir or train/predict files are specified, will run with tensorflow_datasets.", 78 | ) 79 | parser.add_argument( 80 | "--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name" 81 | ) 82 | parser.add_argument( 83 | "--tokenizer_name", 84 | default="", 85 | type=str, 86 | help="Pretrained tokenizer name or path if not the same as model_name", 87 | ) 88 | parser.add_argument( 89 | "--cache_dir", 90 | default=None, 91 | type=str, 92 | help="Where do you want to store the pre-trained models downloaded from s3", 93 | ) 94 | 95 | parser.add_argument( 96 | "--max_seq_length", 97 | default=96, 98 | type=int, 99 | help="The maximum total input sequence length after WordPiece tokenization. 
Sequences " 100 | "longer than this will be truncated, and sequences shorter than this will be padded.", 101 | ) 102 | parser.add_argument("--do_train", action="store_true", help="Whether to run training.") 103 | parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.") 104 | parser.add_argument("--do_predict", action="store_true", help="Whether to do prediction.") 105 | parser.add_argument( 106 | "--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step." 107 | ) 108 | parser.add_argument( 109 | "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model." 110 | ) 111 | 112 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.") 113 | parser.add_argument( 114 | "--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation." 115 | ) 116 | parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") 117 | parser.add_argument( 118 | "--gradient_accumulation_steps", 119 | type=int, 120 | default=1, 121 | help="Number of updates steps to accumulate before performing a backward/update pass.", 122 | ) 123 | parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.") 124 | parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.") 125 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 126 | parser.add_argument( 127 | "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform." 128 | ) 129 | parser.add_argument( 130 | "--max_steps", 131 | default=-1, 132 | type=int, 133 | help="If > 0: set total number of training steps to perform. Override num_train_epochs.", 134 | ) 135 | parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.") 136 | parser.add_argument("--warmup_ratio", default=0.0, type=float, help="Linear warmup over warmup ratio.") 137 | parser.add_argument( 138 | "--verbose_logging", 139 | action="store_true", 140 | help="If true, all of the warnings related to data processing will be printed. 
" 141 | "A number of warnings are expected for a normal SQuAD evaluation.", 142 | ) 143 | 144 | parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.") 145 | parser.add_argument("--eval_steps", type=int, default=500, help="Eval every X updates steps.") 146 | parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.") 147 | parser.add_argument( 148 | "--eval_all_checkpoints", 149 | action="store_true", 150 | help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number", 151 | ) 152 | parser.add_argument( 153 | "--disable_tqdm", action="store_true", help="Disable tqdm bar" 154 | ) 155 | parser.add_argument("--num_contrast_sample", type=int, default=20, help="number of samples in a batch.") 156 | parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available") 157 | parser.add_argument( 158 | "--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory" 159 | ) 160 | parser.add_argument( 161 | "--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets" 162 | ) 163 | parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") 164 | 165 | parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") 166 | parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.") 167 | parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.") 168 | 169 | parser.add_argument("--threads", type=int, default=1, help="multiple threads for converting example to features") 170 | 171 | # train curriculum 172 | parser.add_argument("--training_curriculum", default="random",type=str, choices=["random", "bootstrap", "mixbootstrap"]) 173 | parser.add_argument("--bootstrapping_start", default=None, type=int, help="when to start bootstrapping sampling") 174 | parser.add_argument("--bootstrapping_ticks", default=None, type=str, help="when to update scores for bootstrapping in addition to the startpoint") 175 | 176 | # textualizing choices 177 | parser.add_argument("--linear_method", default="vanilla",type=str, choices=["vanilla", "naive_text", "reduct_text"]) 178 | 179 | # logger 180 | parser.add_argument("--logger",default=None, help="logger") 181 | 182 | def validate_args(args): 183 | # validate before loading data 184 | if args.training_curriculum == "random": 185 | args.bootstrapping_update_epochs = [] 186 | else: 187 | assert args.bootstrapping_start is not None 188 | assert args.bootstrapping_start > 0 189 | 190 | if args.bootstrapping_ticks is None: 191 | bootstrapping_update_epochs = [args.bootstrapping_start] 192 | else: 193 | additional_update_epochs = [int(x) for x in args.bootstrapping_ticks.split(',')] 194 | bootstrapping_update_epochs = [args.bootstrapping_start] + additional_update_epochs 195 | args.bootstrapping_update_epochs = bootstrapping_update_epochs 196 | 197 | def load_untrained_model(args): 198 | args.model_type = args.model_type.lower() 199 | config = AutoConfig.from_pretrained( 200 | args.config_name if args.config_name else args.model_name_or_path, 201 | cache_dir=args.cache_dir if args.cache_dir else None, 202 | ) 203 | tokenizer = AutoTokenizer.from_pretrained( 204 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 205 | do_lower_case=args.do_lower_case, 206 | 
cache_dir=args.cache_dir if args.cache_dir else None, 207 | ) 208 | model_class = MODEL_TYPE_DICT[args.model_type] 209 | model = model_class.from_pretrained( 210 | args.model_name_or_path, 211 | from_tf=bool(".ckpt" in args.model_name_or_path), 212 | config=config, 213 | cache_dir=args.cache_dir if args.cache_dir else None, 214 | ) 215 | 216 | return config, tokenizer, model 217 | 218 | def get_model_class(args): 219 | return MODEL_TYPE_DICT[args.model_type] -------------------------------------------------------------------------------- /entity_retrieval/aqqu_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | from nltk import word_tokenize 3 | 4 | 5 | def normalize_entity_name(name): 6 | name = name.lower() 7 | # name = name.replace('!', '') 8 | # name = name.replace('.', '') 9 | # name = name.replace(',', '') 10 | # name = name.replace('-', '') 11 | # name = name.replace('_', '') 12 | # name = name.replace(' ', '') 13 | # name = name.replace('\'', '') 14 | # name = name.replace('"', '') 15 | # name = name.replace('\\', '') 16 | 17 | 18 | # the following is only for freebase_complete_all_mention 19 | name = ' '.join(word_tokenize(name)) 20 | # word_tokenize from nltk will change the left " to ``, which is pretty weird. Fix it here 21 | name = name.replace('``', '"').replace("''", '"') 22 | 23 | return name 24 | 25 | 26 | def read_abbreviations(abbreviations_file): 27 | ''' 28 | Return a set of abbreviations. 29 | :param abbreviations_file: 30 | :return: 31 | ''' 32 | abbreviations = set() 33 | with open(abbreviations_file, 'r') as f: 34 | for line in f: 35 | abbreviations.add(line.strip().decode('utf-8').lower()) 36 | return abbreviations 37 | 38 | 39 | def remove_abbreviations_from_entity_name(entity_name, 40 | abbreviations): 41 | tokens = entity_name.lower().split(' ') 42 | non_abbr_tokens = [t for t in tokens if t not in abbreviations] 43 | return ' '.join(non_abbr_tokens) 44 | 45 | 46 | def remove_prefixes_from_name(name): 47 | if name.startswith('the'): 48 | name = name[3:] 49 | return name 50 | 51 | 52 | def remove_suffixes_from_name(name): 53 | if '#' in name or '(' in name: 54 | name = remove_number_suffix(name) 55 | name = remove_bracket_suffix(name) 56 | return name 57 | 58 | 59 | def remove_number_suffix(name): 60 | res = re.match(r'.*( #[0-9]+)$', name) 61 | if res: 62 | name = name[:res.start(1)] 63 | return name 64 | else: 65 | return name 66 | 67 | 68 | def remove_bracket_suffix(name): 69 | res = re.match(r'.*( \([^\(\)]+\))$', name) 70 | if res: 71 | name = name[:res.start(1)] 72 | return name 73 | else: 74 | return name -------------------------------------------------------------------------------- /generation/cwq_evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from executor import sparql_executor 3 | from components.utils import dump_json, load_json 4 | from tqdm import tqdm 5 | import os 6 | 7 | 8 | def cwq_evaluate_valid_results(args): 9 | """Compute P, R and F1 for CWQ""" 10 | pred_data = load_json(args.pred_file) 11 | # origin dataset 12 | dataset_data = load_json(f'data/CWQ/origin/ComplexWebQuestions_{args.split}.json') 13 | 14 | dataset_dict = {x["ID"]:x for x in dataset_data} 15 | 16 | p_list = [] 17 | r_list = [] 18 | f_list = [] 19 | hit_list = [] 20 | p_dict = {} 21 | r_dict = {} 22 | f_dict = {} 23 | hit_dict = {} 24 | acc_num = 0 25 | 26 | pred_dict = {} 27 | acc_qid_list = [] # Pred Answer ACC 28 | for pred in pred_data: 29 | qid = 
pred['qid'] 30 | pred_answer = set(pred['answer']) 31 | pred_dict[qid]=pred_answer 32 | 33 | for qid,example in tqdm(dataset_dict.items()): 34 | 35 | gt_sparql = example['sparql'] 36 | if 'answer' in example: 37 | gt_answer = set(example['answer']) 38 | else: 39 | gt_answer = set(sparql_executor.execute_query(gt_sparql)) 40 | 41 | # for dev split 42 | # gt_answer = set([item["answer_id"] for item in example["answers"]]) 43 | 44 | pred_answer = set(pred_dict.get(qid,{})) 45 | 46 | # assert len(pred_answer)>0 and len(gt_answer)>0 47 | if pred_answer == gt_answer: 48 | acc_num+=1 49 | acc_qid_list.append(qid) 50 | 51 | if len(pred_answer)== 0: 52 | if len(gt_answer)==0: 53 | p=1 54 | r=1 55 | f=1 56 | hit=1 57 | else: 58 | p=0 59 | r=0 60 | f=0 61 | hit=0 62 | elif len(gt_answer)==0: 63 | p=0 64 | r=0 65 | f=0 66 | hit=0 67 | else: 68 | p = len(pred_answer & gt_answer)/ len(pred_answer) 69 | r = len(pred_answer & gt_answer)/ len(gt_answer) 70 | f = 2*(p*r)/(p+r) if p+r>0 else 0 71 | hit = 1 if len(pred_answer & gt_answer)>0 else 0 72 | 73 | p_list.append(p) 74 | r_list.append(r) 75 | f_list.append(f) 76 | hit_list.append(hit) 77 | p_dict[qid] = p 78 | r_dict[qid] = r 79 | f_dict[qid] = f 80 | hit_dict[qid] = hit 81 | 82 | p_average = sum(p_list)/len(p_list) 83 | r_average = sum(r_list)/len(r_list) 84 | f_average = sum(f_list)/len(f_list) 85 | hits1 = sum(hit_list)/len(hit_list) 86 | 87 | res = f'Total: {len(p_list)}, ACC:{acc_num/len(p_list)}, AVGP: {p_average}, AVGR: {r_average}, AVGF: {f_average}, Hits@1: {hits1}' 88 | print(res) 89 | dirname = os.path.dirname(args.pred_file) 90 | filename = os.path.basename(args.pred_file) 91 | with open (os.path.join(dirname,f'{filename}_final_eval_results.txt'),'w') as f: 92 | f.write(res) 93 | f.flush() 94 | 95 | # Write answer acc result to prediction file 96 | for pred in pred_data: 97 | qid = pred['qid'] 98 | if qid in acc_qid_list: 99 | pred['answer_acc'] = True 100 | else: 101 | pred['answer_acc'] = False 102 | pred['precision'] = p_dict[qid] if qid in p_dict else None 103 | pred['recall'] = r_dict[qid] if qid in r_dict else None 104 | pred['f1'] = f_dict[qid] if qid in f_dict else None 105 | pred['hits1'] = hit_dict[qid] if qid in hit_dict else None 106 | 107 | dump_json(pred_data, os.path.join(dirname, f'{filename}_new.json'), indent=4) 108 | 109 | 110 | if __name__ == "__main__": 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument( 113 | "--split", 114 | type=str, 115 | required=True, 116 | help="split to operate on, can be `test`, `dev` and `train`", 117 | ) 118 | parser.add_argument( 119 | "--pred_file", type=str, default=None, help="prediction results file" 120 | ) 121 | 122 | args = parser.parse_args() 123 | 124 | cwq_evaluate_valid_results(args) 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /generation/webqsp_evaluate_offcial.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import argparse 4 | 5 | def dump_json(obj, fname, indent=4, mode='w' ,encoding="utf8", ensure_ascii=False): 6 | if "b" in mode: 7 | encoding = None 8 | with open(fname, "w", encoding=encoding) as f: 9 | return json.dump(obj, f, indent=indent, ensure_ascii=ensure_ascii) 10 | 11 | def load_json(fname, mode="r", encoding="utf8"): 12 | if "b" in mode: 13 | encoding = None 14 | with open(fname, mode=mode, encoding=encoding) as f: 15 | return json.load(f) 16 | 17 | def webqsp_evaluate_valid_results(args): 18 | if args.split 
== 'dev': 19 | res = main(args.pred_file, f'data/WebQSP/origin/WebQSP.pdev.json') 20 | else: 21 | res = main(args.pred_file, f'data/WebQSP/origin/WebQSP.{args.split}.json') 22 | dirname = os.path.dirname(args.pred_file) 23 | filename = os.path.basename(args.pred_file) 24 | with open (os.path.join(dirname,f'{filename}_final_eval_results_official.txt'),'w') as f: 25 | f.write(res) 26 | f.flush() 27 | 28 | def FindInList(entry,elist): 29 | for item in elist: 30 | if entry == item: 31 | return True 32 | return False 33 | 34 | def CalculatePRF1(goldAnswerList, predAnswerList): 35 | if len(goldAnswerList) == 0: 36 | if len(predAnswerList) == 0: 37 | return [1.0, 1.0, 1.0, 1] # consider it 'correct' when there is no labeled answer, and also no predicted answer 38 | else: 39 | return [0.0, 1.0, 0.0, 1] # precision=0 and recall=1 when there is no labeled answer, but has some predicted answer(s) 40 | elif len(predAnswerList)==0: 41 | return [1.0, 0.0, 0.0, 0] # precision=1 and recall=0 when there is labeled answer(s), but no predicted answer 42 | else: 43 | glist =[x["AnswerArgument"] for x in goldAnswerList] 44 | plist =predAnswerList 45 | 46 | tp = 1e-40 # numerical trick 47 | fp = 0.0 48 | fn = 0.0 49 | 50 | for gentry in glist: 51 | if FindInList(gentry,plist): 52 | tp += 1 53 | else: 54 | fn += 1 55 | for pentry in plist: 56 | if not FindInList(pentry,glist): 57 | fp += 1 58 | 59 | 60 | precision = tp/(tp + fp) 61 | recall = tp/(tp + fn) 62 | 63 | f1 = (2*precision*recall)/(precision+recall) 64 | 65 | if tp > 1e-40: 66 | hit = 1 67 | else: 68 | hit = 0 69 | return [precision, recall, f1, hit] 70 | 71 | 72 | def main(pred_data, dataset_data): 73 | 74 | goldData = load_json(dataset_data) 75 | predAnswers = load_json(pred_data) 76 | 77 | PredAnswersById = {} 78 | 79 | for item in predAnswers: 80 | PredAnswersById[item["QuestionId"]] = item["Answers"] 81 | 82 | total = 0.0 83 | f1sum = 0.0 84 | recSum = 0.0 85 | precSum = 0.0 86 | hitSum = 0 87 | numCorrect = 0 88 | prediction_res = [] 89 | if "Questions" in goldData: 90 | goldData = goldData["Questions"] 91 | for entry in goldData: 92 | 93 | skip = True 94 | for pidx in range(0,len(entry["Parses"])): 95 | np = entry["Parses"][pidx] 96 | if np["AnnotatorComment"]["QuestionQuality"] == "Good" and np["AnnotatorComment"]["ParseQuality"] == "Complete": 97 | skip = False 98 | 99 | if(len(entry["Parses"])==0 or skip): 100 | continue 101 | 102 | total += 1 103 | 104 | id = entry["QuestionId"] 105 | 106 | if id not in PredAnswersById: 107 | print("The problem " + id + " is not in the prediction set") 108 | print("Continue to evaluate the other entries") 109 | continue 110 | 111 | if len(entry["Parses"]) == 0: 112 | print("Empty parses in the gold set. 
Breaking!!") 113 | break 114 | 115 | predAnswers = PredAnswersById[id] 116 | 117 | bestf1 = -9999 118 | bestf1Rec = -9999 119 | bestf1Prec = -9999 120 | besthit = 0 121 | for pidx in range(0,len(entry["Parses"])): 122 | pidxAnswers = entry["Parses"][pidx]["Answers"] 123 | prec,rec,f1,hit = CalculatePRF1(pidxAnswers,predAnswers) 124 | if f1 > bestf1: 125 | bestf1 = f1 126 | bestf1Rec = rec 127 | bestf1Prec = prec 128 | if hit > besthit: 129 | besthit = hit 130 | 131 | f1sum += bestf1 132 | recSum += bestf1Rec 133 | precSum += bestf1Prec 134 | hitSum += besthit 135 | 136 | pred = {} 137 | pred['qid'] = id 138 | pred['precision'] = bestf1Prec 139 | pred['recall'] = bestf1Rec 140 | pred['f1'] = bestf1 141 | pred['hit'] = besthit 142 | prediction_res.append(pred) 143 | 144 | if bestf1 == 1.0: 145 | numCorrect += 1 146 | 147 | print("Number of questions:", int(total)) 148 | print("Average precision over questions: %.3f" % (precSum / total)) 149 | print("Average recall over questions: %.3f" % (recSum / total)) 150 | print("Average f1 over questions (accuracy): %.3f" % (f1sum / total)) 151 | print("F1 of average recall and average precision: %.3f" % (2 * (recSum / total) * (precSum / total) / (recSum / total + precSum / total))) 152 | print("True accuracy (ratio of questions answered exactly correctly): %.3f" % (numCorrect / total)) 153 | print("Hits@1 over questions: %.3f" % (hitSum / total)) 154 | res = f'Number of questions:{int(total)}\n, Average precision over questions: {(precSum / total)}\n, Average recall over questions: {(recSum / total)}\n, Average f1 over questions (accuracy): {(f1sum / total)}\n, F1 of average recall and average precision: {(2 * (recSum / total) * (precSum / total) / (recSum / total + precSum / total))}\n, True accuracy (ratio of questions answered exactly correctly): {(numCorrect / total)}\n, Hits@1 over questions: {(hitSum / total)}' 155 | dirname = os.path.dirname(pred_data) 156 | filename = os.path.basename(pred_data) 157 | # dump_json(prediction_res, os.path.join(dirname, f'{filename}_new.json')) # 可以查看每个样本的三个指标得分 158 | return res 159 | 160 | if __name__ == "__main__": 161 | parser = argparse.ArgumentParser() 162 | parser.add_argument( 163 | "--split", 164 | type=str, 165 | required=True, 166 | help="split to operate on, can be `test`, `dev` and `train`", 167 | ) 168 | parser.add_argument( 169 | "--pred_file", type=str, default=None, help="prediction results file" 170 | ) 171 | args = parser.parse_args() 172 | 173 | webqsp_evaluate_valid_results(args) -------------------------------------------------------------------------------- /lib/virtodbc.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Applied-Machine-Learning-Lab/AMAR/81da184742e887c512ea7fb393a4dc19e9495446/lib/virtodbc.so -------------------------------------------------------------------------------- /ontology/README.md: -------------------------------------------------------------------------------- 1 | Files under this folder originate from [GrailQA](https://github.com/dki-lab/GrailQA) -------------------------------------------------------------------------------- /process_NQ.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | from components.utils import load_json 5 | from tqdm import tqdm 6 | 7 | def load_data(split, args): 8 | data_file_name = 'data/{}/generation/merged/{}_{}.json'.format(args.dataset_type,args.dataset_type,split) 9 | 
print('Loading data from:',data_file_name) 10 | data_dict = load_json(data_file_name) 11 | return data_dict 12 | 13 | def _parse_args(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--dataset_type', default="WebQSP", type=str, help="CWQ | WebQSP") 16 | args = parser.parse_args() 17 | return args 18 | 19 | def prepare_dataloader(args,split): 20 | assert split in ['train','test','dev','train_sample','dev_sample','test_sample'] 21 | 22 | data = load_data(split, args) 23 | print(f'Origin {split} dataset len: {len(data)}') 24 | assert type(data)==list 25 | if 'train' in split or 'dev' in split: 26 | # for train and dev, filter the examples without sexpr 27 | examples = [] 28 | for x in data: 29 | if x['sexpr'].lower()!="null": 30 | examples.append(x) 31 | else: 32 | examples = [x for x in data] 33 | print(f'Real {split} dataset len: {len(examples)}') 34 | 35 | json_data=[] 36 | instruction='Generate a Logical Form query that retrieves the information corresponding to the given question. \n' 37 | for cnt, item in tqdm(enumerate(examples)): 38 | question=item['question'] 39 | input = 'Question: { '+question+' }' 40 | output = item['normed_sexpr'] 41 | json_data.append({"instruction":instruction,"input":input,"output":output,"history":[]}) 42 | 43 | 44 | output_dir = 'LLMs/data/{}_Freebase_NQ_{}/examples.json'.format(args.dataset_type, split) 45 | 46 | if not os.path.exists(os.path.dirname(output_dir)): 47 | os.mkdir(os.path.dirname(output_dir)) 48 | 49 | with open(output_dir, "w", encoding="utf-8") as file: 50 | json.dump(json_data, file) 51 | 52 | 53 | if __name__=='__main__': 54 | 55 | args = _parse_args() 56 | print(args) 57 | prepare_dataloader(args,'train') 58 | prepare_dataloader(args, 'test') 59 | print('Finished') 60 | 61 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets==2.18.0 2 | deepspeed==0.14.0 3 | Flask==3.0.3 4 | huggingface-hub==0.24.3 5 | jieba==0.42.1 6 | jsonlines==4.0.0 7 | networkx==3.2.1 8 | nltk==3.6.5 9 | numpy==1.19.5 10 | pandas==1.4.0 11 | peft==0.6.0 12 | pillow==10.2.0 13 | ptvsd==4.3.2 14 | pyodbc==5.1.0 15 | pyparsing==3.1.1 16 | pyserini==0.22.1 17 | pytz==2023.3.post1 18 | PyYAML==6.0 19 | requests==2.31.0 20 | requests-toolbelt==1.0.0 21 | rich==13.7.1 22 | rouge-chinese==1.0.3 23 | ruff==0.5.5 24 | safetensors==0.4.2 25 | scikit-learn==1.0.2 26 | scipy==1.5.4 27 | sentence-transformers==2.2.2 28 | simcse==0.4 29 | SPARQLWrapper==2.0.0 30 | tiktoken==0.7.0 31 | timm==0.9.7 32 | tokenizers==0.13.3 33 | tomlkit==0.12.0 34 | torch==1.13.1+cu117 35 | torchaudio==0.13.1+cu117 36 | torchvision==0.14.1+cu117 37 | tqdm==4.62.3 38 | transformers==4.31.0 39 | trl==0.9.4 40 | typer==0.12.3 41 | tyro==0.8.5 42 | urllib3==2.1.0 43 | uvicorn==0.30.4 44 | wandb==0.16.3 45 | glob2==0.7 46 | faiss-gpu==1.7.2 47 | antlr4-python3-runtime==4.9.2 48 | argparse==1.4.0 49 | graphq-trans==0.1.0 50 | rdflib==6.1.1 51 | func-timeout==4.3.5 52 | sympy==1.13.1 53 | cvxpy==1.5.2 54 | colorama==0.4.6 55 | termcolor==2.4.0 56 | flair 57 | pytorch-transformers 58 | sse-starlette 59 | fastapi 60 | matplotlib 61 | rouge_chinese 62 | gradio -------------------------------------------------------------------------------- /run_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # cwq 4 | CUDA_VISIBLE_DEVICES=0 bash run_ft.sh CWQ LLaMA-2-13b-hf cwq_4_16_32_16 train-test 4 16 
1 32 16 8 5 | sleep 1m 6 | CUDA_VISIBLE_DEVICES=0 bash run_ft.sh CWQ LLaMA-2-13b-hf cwq_4_16_32_16 train 4 16 1 32 16 8 7 | CUDA_VISIBLE_DEVICES=0 python -u eval_final.py --dataset CWQ --pred_file Reading/LLaMA-2-13b-hf/CWQ_cwq_4_16_32_16/evaluation_beam/beam_test_top_k_predictions.json 8 | CUDA_VISIBLE_DEVICES=0 python -u eval_final.py --dataset CWQ --pred_file Reading/LLaMA-2-13b-hf/CWQ_cwq_4_16_32_16/evaluation_beam/beam_test_top_k_predictions.json --golden_ent 9 | 10 | # webqsp 11 | CUDA_VISIBLE_DEVICES=0 bash run_ft.sh WebQSP LLaMA-2-7b-hf webqsp_100_7_32_16 train 100 7 0 32 16 15 12 | sleep 1m 13 | CUDA_VISIBLE_DEVICES=0 bash run_ft.sh WebQSP LLaMA-2-7b-hf webqsp_100_7_32_16 test 100 7 0 32 16 15 14 | CUDA_VISIBLE_DEVICES=0 python -u eval_final.py --dataset WebQSP --pred_file Reading/LLaMA-2-7b-hf/WebQSP_webqsp_100_7_32_16/evaluation_beam/beam_test_top_k_predictions.json 15 | CUDA_VISIBLE_DEVICES=0 python -u eval_final.py --dataset WebQSP --pred_file Reading/LLaMA-2-7b-hf/WebQSP_webqsp_100_7_32_16/evaluation_beam/beam_test_top_k_predictions.json --golden_ent 16 | 17 | -------------------------------------------------------------------------------- /run_ft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | data=$1 # 'CWQ' 4 | llm_model=$2 # 'LLaMA-2-13b-hf' 5 | setting=$3 6 | mode=$4 7 | topk=$5 8 | soft_prompt_length=$6 9 | cuda=$7 10 | extra_infor_len=$8 11 | gate_len=${9} 12 | num_beams=${10} 13 | 14 | if [ "$data" == "WebQSP" ]; then 15 | num_train_epochs=80.0 16 | else 17 | num_train_epochs=8.0 18 | fi 19 | 20 | 21 | if [[ $mode == *"train"* ]]; then 22 | python -u LLMs/LLaMA/src/train_bash.py \ 23 | --stage sft \ 24 | --model_name_or_path ./../../LLM_checkpoint/${llm_model} \ 25 | --do_train \ 26 | --dataset_dir LLMs/data_id \ 27 | --dataset ${data}_Freebase_NQ_train \ 28 | --template default \ 29 | --finetuning_type lora \ 30 | --lora_target gate_proj,down_proj,up_proj \ 31 | --output_dir Reading/${llm_model}/${data}_${setting}/checkpoint \ 32 | --overwrite_cache \ 33 | --per_device_train_batch_size 4 \ 34 | --gradient_accumulation_steps 4 \ 35 | --lr_scheduler_type cosine \ 36 | --logging_steps 10 \ 37 | --save_strategy no \ 38 | --learning_rate 5e-5 \ 39 | --num_train_epochs ${num_train_epochs} \ 40 | --plot_loss \ 41 | --overwrite_output_dir \ 42 | --topk ${topk} \ 43 | --soft_prompt_length ${soft_prompt_length} \ 44 | --extra_infor_len ${extra_infor_len} \ 45 | --gate_len ${gate_len} 46 | fi 47 | 48 | 49 | 50 | if [[ $mode == *"test"* ]]; then 51 | python -u LLMs/LLaMA/src/beam_output_eva.py \ 52 | --model_name_or_path ./../../LLM_checkpoint/${llm_model} \ 53 | --dataset_dir LLMs/data_id \ 54 | --dataset ${data}_Freebase_NQ_test \ 55 | --template default \ 56 | --finetuning_type lora \ 57 | --lora_target gate_proj,down_proj,up_proj \ 58 | --checkpoint_dir Reading/${llm_model}/${data}_${setting}/checkpoint \ 59 | --num_beams ${num_beams} \ 60 | --max_new_tokens 256 \ 61 | --topk ${topk} \ 62 | --soft_prompt_length ${soft_prompt_length} \ 63 | --extra_infor_len ${extra_infor_len} \ 64 | --gate_len ${gate_len} 65 | fi 66 | -------------------------------------------------------------------------------- /run_generator_final.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | from components.utils import dump_json 5 | 6 | def prepare_dataloader(args): 7 | print('Loading data from:',args.data_file_name) 8 | with open(args.data_file_name, 
'r', encoding='utf-8') as f: 9 | # read each line of the JSONL file and parse it into a dict 10 | data = [json.loads(line) for line in f] 11 | print(f'Dataset len: {len(data)}') 12 | return data 13 | 14 | 15 | def _parse_args(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--data_file_name',default='Reading/LLaMA2-13b/WebQSP_Freebase_NQ_lora_epoch100/evaluation_beam/generated_predictions.jsonl') 18 | 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def run_prediction(args,dataloader,output_dir,output_predictions=True): 24 | print() 25 | print('Start predicting ') 26 | 27 | ex_cnt = 0 28 | contains_ex_cnt = 0 29 | output_list = [] 30 | real_total = 0 31 | for i,pred in enumerate(dataloader): 32 | predictions = pred['predict'] 33 | gen_label = pred['label'] 34 | 35 | output_list.append({ 36 | 'predictions':predictions, 37 | 'gen_label':gen_label, 38 | }) 39 | 40 | if predictions[0].lower()==gen_label.lower(): 41 | ex_cnt+=1 42 | 43 | if any([x.lower()==gen_label.lower() for x in predictions]): 44 | contains_ex_cnt+=1 45 | 46 | if gen_label.lower()!='null': 47 | real_total+=1 48 | 49 | 50 | print(f"""total:{len(output_list)}, 51 | ex_cnt:{ex_cnt}, 52 | ex_rate:{ex_cnt/len(output_list)}, 53 | real_ex_rate:{ex_cnt/real_total}, 54 | contains_ex_cnt:{contains_ex_cnt}, 55 | contains_ex_rate:{contains_ex_cnt/len(output_list)} 56 | real_contains_ex_rate:{contains_ex_cnt/real_total} 57 | """) 58 | 59 | 60 | if output_predictions: 61 | file_path = os.path.join(output_dir,f'beam_test_top_k_predictions.json') 62 | 63 | gen_statistics_file_path = os.path.join(output_dir,f'beam_test_gen_statistics.json') 64 | gen_statistics = { 65 | 'total':len(output_list), 66 | 'exmatch_num': ex_cnt, 67 | 'exmatch_rate': ex_cnt/len(output_list), 68 | 'real_exmatch_rate':ex_cnt/real_total, 69 | 'contains_ex_num':contains_ex_cnt, 70 | 'contains_ex_rate':contains_ex_cnt/len(output_list), 71 | 'real_contains_ex_rate':contains_ex_cnt/real_total 72 | } 73 | dump_json(output_list, file_path, indent=4) 74 | dump_json(gen_statistics, gen_statistics_file_path,indent=4) 75 | 76 | 77 | if __name__=='__main__': 78 | 79 | args = _parse_args() 80 | print(args) 81 | 82 | test_dataloader = prepare_dataloader(args) 83 | run_prediction(args,test_dataloader,output_dir=os.path.dirname(args.data_file_name),output_predictions=True) 84 | 85 | print('Prediction Finished') 86 | 87 | --------------------------------------------------------------------------------