├── tests ├── __init__.py ├── ljp_task.py └── model_generate.py ├── .gitignore ├── lm_eval ├── __init__.py ├── models │ ├── __init__.py │ ├── base.py │ ├── openai_model.py │ └── hf_model.py ├── tasks │ ├── __init__.py │ ├── base.py │ └── ljp.py ├── utils.py ├── parser.py └── evaluate.py ├── resources ├── static │ ├── fig_setting_example_v2.jpg │ └── fig_setting_example_v2.pdf ├── paper_version_acc.csv └── paper_version_f1.csv ├── run.sh ├── requirements.txt ├── config ├── ljp_prompt.json ├── default_openai.json └── default_hf.json ├── compress.sh ├── scripts ├── csv2htlm.py ├── copy_paper_results.sh ├── get_result_table.py └── parse_all_output.py ├── download_data.sh ├── main.py └── readme.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.env* 3 | wandb/ 4 | outputs/ 5 | data_hub/ 6 | runs/ -------------------------------------------------------------------------------- /lm_eval/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Language model evaluation of open-end generations. 3 | """ -------------------------------------------------------------------------------- /lm_eval/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .hf_model import HF_AutoModel, HF_Model_Config, HF_Gen_Conf 2 | # from .openai_model import Openai_Model -------------------------------------------------------------------------------- /resources/static/fig_setting_example_v2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srhthu/LM-CompEval-Legal/HEAD/resources/static/fig_setting_example_v2.jpg -------------------------------------------------------------------------------- /resources/static/fig_setting_example_v2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/srhthu/LM-CompEval-Legal/HEAD/resources/static/fig_setting_example_v2.pdf -------------------------------------------------------------------------------- /lm_eval/models/base.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from dataclasses import dataclass, asdict 3 | from typing import Optional, Union, List, Dict, Any 4 | 5 | -------------------------------------------------------------------------------- /lm_eval/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | """Prepare the task data, e.g., add prompt""" 2 | from .base import TaskBase 3 | from .ljp import JudgmentPrediction_Config, JudgmentPrediction_Task -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python main.py \ 2 | --config ./config/default_hf.json \ 3 | --output_dir ./runs/chatglm_6b \ 4 | --model_type hf \ 5 | --model THUDM/chatglm-6b -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | environs==9.5.0 2 | jieba==0.42.1 3 | numpy==1.24.3 4 | openai==0.27.8 5 | pandas==2.0.1 6 | peft==0.5.0 7 
| rank_bm25==0.2.2 8 | scikit_learn==1.2.2 9 | tiktoken==0.4.0 10 | torch==2.0.1 11 | tqdm==4.65.0 12 | transformers==4.31.0 13 | -------------------------------------------------------------------------------- /config/ljp_prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "instruction_free_zs": "根据中国刑法判断犯罪嫌疑人的罪名,直接输出罪名。", 3 | "instruction_free_fs": "根据中国刑法判断犯罪嫌疑人的罪名", 4 | "instruction_multi_zs": "根据中国刑法判断犯罪嫌疑人的罪名", 5 | "instruction_multi_fs": "根据中国刑法判断犯罪嫌疑人的罪名", 6 | "demo_template": "案件事实: {input}\n罪名: {output}" 7 | } -------------------------------------------------------------------------------- /compress.sh: -------------------------------------------------------------------------------- 1 | if [ ! -d "data_hub" ]; then 2 | mkdir data_hub 3 | fi 4 | 5 | # compress eval data into tar.gz 6 | cd data_hub 7 | [ -d "ljp" ] && tar czvf ljp_data.tar.gz ljp || echo "Error: no ljp folder" 8 | 9 | cd ../runs 10 | [ -d "paper_version" ] && tar czvf paper_version.tar.gz paper_version || echo "Error: no paper_version" -------------------------------------------------------------------------------- /lm_eval/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | def read_jsonl(path): 4 | return [json.loads(k) for k in open(path, encoding='utf8')] 5 | 6 | def read_json(path): 7 | return json.load(open(path, encoding='utf8')) 8 | 9 | def save_jsonl(data, path): 10 | with open(path, 'w', encoding='utf8') as f: 11 | for d in data: 12 | f.write(json.dumps(d, ensure_ascii=False) + '\n') -------------------------------------------------------------------------------- /config/default_openai.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | }, 4 | "gen_config": { 5 | "num_return_sequences": 5, 6 | "max_new_tokens": 30, 7 | "temperature": 0.8 8 | }, 9 | "task_config": { 10 | "data_path": "./data_hub/ljp", 11 | "prompt_config_file": "./config/ljp_prompt.json", 12 | "query_max_len": 1000, 13 | "demo_max_len": 500 14 | } 15 | } -------------------------------------------------------------------------------- /scripts/csv2htlm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Convert the csv table to html table 3 | """ 4 | # %% 5 | import pandas as pd 6 | 7 | df = pd.read_csv('../resources/paper_version_f1.csv') 8 | # %% 9 | df2 = df.reset_index() 10 | df2['index'] = df2['index'] + 1 11 | df2 = df2.rename({'index': 'rank'}, axis= 1) 12 | # %% 13 | print(df2.to_html(index = False, float_format=lambda k: f'{k*100:.2f}', 14 | justify = 'justify')) 15 | # %% 16 | -------------------------------------------------------------------------------- /config/default_hf.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_config": { 3 | "base_model": "THUDM/chatglm-6b", 4 | "trust_remote_code": true, 5 | "save_memory": true 6 | }, 7 | "gen_config": { 8 | "num_return_sequences": 5, 9 | "max_new_tokens": 30, 10 | "temperature": 0.8, 11 | "do_sample": true 12 | }, 13 | "task_config": { 14 | "data_path": "./data_hub/ljp", 15 | "prompt_config_file": "./config/ljp_prompt.json", 16 | "query_max_len": 1000, 17 | "demo_max_len": 500 18 | } 19 | } -------------------------------------------------------------------------------- /scripts/copy_paper_results.sh: -------------------------------------------------------------------------------- 1 | for model in gpt4 
chatgpt bloomz_7B chatglm_6B vicuna_13B; 2 | do 3 | if [ ! -d "runs/paper_version/${model}" ]; then 4 | mkdir "runs/paper_version/${model}" 5 | fi 6 | for ttype in free multi 7 | do 8 | for shot in {0..4} 9 | do 10 | [ $shot -eq 0 ] && set_suffix="" || set_suffix="_sim" 11 | cur_dir="runs/paper_version/${model}/${ttype}-${shot}shot" 12 | [ ! -d $cur_dir ] && mkdir $cur_dir 13 | echo $cur_dir 14 | cp "/storage_fast/rhshui/workspace/LJP/lm/runs/benchmark/${ttype}_${shot}shot${set_suffix}/${model}/raw_test_output.txt" "${cur_dir}/" 15 | done 16 | 17 | done 18 | done -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | echo "Download the evaluation data..." 2 | [ ! -d "data_hub" ] && mkdir data_hub 3 | 4 | # https://drive.google.com/file/d/1H-wReUapuUIXnJe3lUKZoN4rLxbwT_q6/view?usp=share_link 5 | wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1H-wReUapuUIXnJe3lUKZoN4rLxbwT_q6' -O ./data_hub/ljp_data.tar.gz 6 | 7 | tar xzf data_hub/ljp_data.tar.gz -C data_hub 8 | 9 | echo "Download model generated results..." 10 | [ ! -d "runs" ] && mkdir runs 11 | # https://drive.google.com/file/d/14Zfy60udBaymsYEd8zI9z236JBS1kHvt/view?usp=share_link 12 | wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=14Zfy60udBaymsYEd8zI9z236JBS1kHvt' -O ./runs/paper_version.tar.gz 13 | 14 | tar xzf runs/paper_version.tar.gz -C runs -------------------------------------------------------------------------------- /lm_eval/parser.py: -------------------------------------------------------------------------------- 1 | """Parse open generation text to pre-defined labels""" 2 | import numpy as np 3 | from rank_bm25 import BM25Okapi 4 | import jieba 5 | from typing import List 6 | 7 | class BM25_Parser: 8 | def __init__(self, label2id): 9 | self.all_labels = list(label2id.keys()) 10 | corpus = [list(jieba.cut(k, cut_all = True)) for k in self.all_labels] 11 | self.bm25 = BM25Okapi(corpus) 12 | 13 | def __call__(self, choices: List[str]): 14 | """ 15 | Given several sampled outputs, map them to a label. 
16 | """ 17 | tokenized_choices = [list(jieba.cut(c, cut_all = True)) for c in choices] 18 | cho_score = [self.bm25.get_scores(c) for c in tokenized_choices] 19 | # average the similarity score across all outputs 20 | cho_s = np.mean(cho_score, axis = 0) 21 | # select the 22 | top_idx = np.argsort(cho_s)[-1] 23 | return self.all_labels[top_idx] -------------------------------------------------------------------------------- /resources/paper_version_acc.csv: -------------------------------------------------------------------------------- 1 | model,score,free-0shot,free-1shot,free-2shot,free-3shot,free-4shot,multi-0shot,multi-1shot,multi-2shot,multi-3shot,multi-4shot 2 | gpt4,0.6517857313156128,0.5517857074737549,0.6482142806053162,0.6910714507102966,0.6982142925262451,0.7196428775787354,0.6392857432365417,0.7124999761581421,0.7250000238418579,0.737500011920929,0.7446428537368774 3 | chatgpt,0.5950892865657806,0.4660714268684387,0.6000000238418579,0.6285714507102966,0.6482142806053162,0.6696428656578064,0.6160714030265808,0.6446428298950195,0.6696428656578064,0.7035714387893677,0.6714285612106323 4 | chatglm_6b,0.49196428805589676,0.41428571939468384,0.5178571343421936,0.5,0.5035714507102966,0.5053571462631226,0.5571428537368774,0.5053571462631226,0.49642857909202576,0.49464285373687744,0.4732142984867096 5 | bloomz_7b,0.46741071343421936,0.4982142746448517,0.5482142567634583,0.5267857313156128,0.5249999761581421,0.512499988079071,0.5339285731315613,0.3196428716182709,0.3107142746448517,0.27321428060531616,0.2660714387893677 6 | vicuna_13b,0.42276785522699356,0.28214284777641296,0.5035714507102966,0.49642857909202576,0.5178571343421936,0.3589285612106323,0.47857141494750977,0.4482142925262451,0.43392857909202576,0.3571428656578064,0.19464285671710968 7 | -------------------------------------------------------------------------------- /resources/paper_version_f1.csv: -------------------------------------------------------------------------------- 1 | model,score,free-0shot,free-1shot,free-2shot,free-3shot,free-4shot,multi-0shot,multi-1shot,multi-2shot,multi-3shot,multi-4shot 2 | gpt4,0.6304570350991271,0.5051931565034457,0.6272116440992214,0.6753696451590082,0.6861244013457889,0.710234078091221,0.6231494728731667,0.7041617582501042,0.7181158658608878,0.7324428742682257,0.739957116296402 3 | chatgpt,0.5812671747722895,0.43136043185938144,0.5842372260675832,0.6185609038892743,0.6439740854703412,0.6616275308044312,0.6066751113861076,0.6350939040357199,0.6684722519543949,0.6959456364130734,0.6661553541487751 4 | chatglm_6b,0.47738158557111326,0.41889583161496635,0.5030107089455829,0.47759512313083746,0.48587183175692494,0.48672203185013746,0.5374210476731485,0.4925884631241774,0.4756143398655009,0.4760878652160164,0.453166523399671 5 | bloomz_7b,0.4414142089914668,0.46902045546913246,0.5328335520299806,0.5106431496090086,0.5090090468892534,0.4925548276684269,0.5067579087564413,0.2924788595094575,0.2792353221312849,0.25273131184632086,0.23371439193392973 6 | vicuna_13b,0.3982888531124886,0.25499944279024334,0.4884546818052457,0.476366857931984,0.49491859218128015,0.3982025769734254,0.446988003347312,0.4173250607751382,0.4148011083804151,0.3503226224930532,0.2160533911591092 7 | -------------------------------------------------------------------------------- /tests/ljp_task.py: -------------------------------------------------------------------------------- 1 | """Test Legal Judgment Prediction Task""" 2 | from tiktoken import encoding_for_model 3 | from transformers import AutoTokenizer 4 | 5 | 
from lm_eval.tasks.ljp import JudgmentPrediction_Task 6 | 7 | def test_build_task(args, tokenizer): 8 | tasker = JudgmentPrediction_Task( 9 | config = { 10 | 'data_path': args.data_path, 11 | 'prompt_config_file': args.prompt_config 12 | }, 13 | tokenizer = tokenizer 14 | ) 15 | for subtask in tasker.get_all_subtask(): 16 | print(f'Building subtask: {subtask}') 17 | data = tasker.build_subtask(subtask) 18 | print(data[0]) 19 | 20 | if __name__ == '__main__': 21 | from argparse import ArgumentParser 22 | parser = ArgumentParser() 23 | parser.add_argument('--data_path', default = './data_hub/ljp') 24 | parser.add_argument('--prompt_config', default = './config/ljp_prompt.json') 25 | 26 | args = parser.parse_args() 27 | 28 | # print('#'*15 + 'Test ChatGPT') 29 | # tokenizer = encoding_for_model('gpt-3.5-turbo') 30 | # test_build_task(args, tokenizer) 31 | # print('\n\n') 32 | 33 | print('#'*15 + 'Test ChatGLM') 34 | tokenizer = AutoTokenizer.from_pretrained('THUDM/chatglm-6b', trust_remote_code = True) 35 | test_build_task(args, tokenizer) 36 | 37 | -------------------------------------------------------------------------------- /scripts/get_result_table.py: -------------------------------------------------------------------------------- 1 | """ 2 | Merge results of all models 3 | """ 4 | import json 5 | import pandas as pd 6 | from pathlib import Path 7 | 8 | def merge_metric(exp_dir, metric)->pd.DataFrame: 9 | """Merge the specified metric of all models""" 10 | sub_tasks = [f'{tt}-{n}shot' for tt in ['free', 'multi'] for n in range(5)] 11 | 12 | lines = [] 13 | for model in ['gpt4', 'chatgpt', 'bloomz_7b', 'chatglm_6b', 'vicuna_13b']: 14 | eval_file = Path(exp_dir) / model / 'eval_results.txt' 15 | if eval_file.exists(): 16 | with open(eval_file) as f: 17 | model_results = [json.loads(k) for k in f] 18 | task2metric = {rec['subtask']: rec['metrics'][metric] for rec in model_results} 19 | else: 20 | task2metric = {} 21 | lines.append([model] + [task2metric.get(k) for k in sub_tasks]) 22 | df = pd.DataFrame(lines, columns = ['model'] + sub_tasks) 23 | return df 24 | 25 | def main(exp_dir, metric, save_path): 26 | df = merge_metric(exp_dir, metric) 27 | tot_score = (df['free-0shot'] + df['multi-0shot'] + df['free-2shot'] + df['multi-2shot']) / 4 28 | df.insert(1, 'score', tot_score) 29 | df = df.sort_values(by = ['score'], ascending = False) 30 | df.to_csv(save_path, index = False) 31 | 32 | if __name__ == '__main__': 33 | import argparse 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument('--exp_dir', help = 'dir of all model results', default = 'runs/paper_version') 36 | parser.add_argument('--metric', help = 'metric name, f1 or acc', default = 'f1') 37 | parser.add_argument('--save_path', help = 'where to save the table', default = 'resources/f1.csv') 38 | args = parser.parse_args() 39 | main(**vars(args)) -------------------------------------------------------------------------------- /tests/model_generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test the function of CausalLM and Seq2SeqLM model generation 3 | """ 4 | from lm_eval.models import HF_AutoModel, HF_Model_Config, HF_Gen_Conf 5 | 6 | def display(source, targets): 7 | print(f'Input: {source}') 8 | for i,t in enumerate(targets): 9 | print(f'Output #{i}: {t}') 10 | 11 | def main(args): 12 | model = HF_AutoModel( 13 | HF_Model_Config( 14 | base_model = args.model, 15 | peft_model = args.peft_model, 16 | trust_remote_code = True, 17 | save_memory = False 18 | ), 19 
| gen_config = HF_Gen_Conf( 20 | num_return_sequences = 3 21 | ) 22 | ) 23 | 24 | source = 'Today is a good day.' 25 | print(f'Test save_memory={model.config.save_memory}') 26 | targets, out_ids = model.generate(source, return_ids = True) 27 | display(source, targets) 28 | all_tokens = [model.tokenizer.convert_ids_to_tokens(k) for k in out_ids] 29 | for i, tokens in enumerate(all_tokens): 30 | print('Tokens #{}: {}'.format(i, ' '.join(tokens))) 31 | 32 | # print('Test save_memory=False') 33 | # model.config.save_memory = False 34 | # targets = model.generate(source) 35 | # print(targets) 36 | 37 | # print('Test greedy decoding') 38 | # model.gen_config.do_sample = False 39 | # targets = model.generate(source, num_output=1) 40 | # print(targets) 41 | 42 | if __name__ == '__main__': 43 | import argparse 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('--model') 46 | parser.add_argument('--peft_model') 47 | # parser.add_argument('--is_seq2seq', action = 'store_true') 48 | 49 | args = parser.parse_args() 50 | 51 | main(args) 52 | 53 | -------------------------------------------------------------------------------- /lm_eval/tasks/base.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Dict, Any, List 3 | from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast 4 | from tiktoken import Encoding 5 | 6 | class TaskBase: 7 | """ 8 | The base class of Task that produce task data e.g. prompts. 9 | """ 10 | def __init__(self): 11 | self._task_data = {} 12 | 13 | def build_subtask(self, name) -> List[Dict[str, Any]]: 14 | """Return subtask data""" 15 | raise NotImplementedError 16 | 17 | def get_all_subtask(self) -> List[str]: 18 | """Return the name of subtasks""" 19 | raise NotImplementedError 20 | 21 | def get_task_data(self): 22 | """Return task data for evaluation""" 23 | raise NotImplementedError 24 | 25 | def cut_text(self, text, max_len): 26 | """Truncate text to max length""" 27 | tokenizer = self.tokenizer 28 | if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): 29 | # handle transformers tokenizer 30 | outs = tokenizer(text, truncation = True, max_length = max_len) 31 | new_text = tokenizer.decode( 32 | outs.input_ids, skip_special_tokens=True, 33 | clean_up_tokenization_spaces = True, # ensure chinese characters are not splited by space 34 | ) 35 | elif isinstance(tokenizer, Encoding): 36 | # handle OpenAI tokenizer 37 | token_ids = tokenizer.encode(text)[:max_len] 38 | new_text = tokenizer.decode(token_ids) 39 | rep_chr = chr(65533) # this is the token of error utf-8 codes 40 | # replace the last error token if exists 41 | if new_text[-1] == rep_chr: 42 | new_text = new_text[:-1] 43 | else: 44 | raise ValueError(f'Unknown tokenizer type: {tokenizer.__class__.__name__}') 45 | return new_text -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Comprehensive evaluation of a language model on Legal Judgment Prediction 3 | """ 4 | from pathlib import Path 5 | import json 6 | 7 | from lm_eval.models import HF_AutoModel 8 | from lm_eval.tasks.ljp import JudgmentPrediction_Task 9 | from lm_eval.evaluate import evaluate_subtasks 10 | 11 | 12 | def main(run_config, output_dir, sub_tasks): 13 | # Initialize model 14 | model = HF_AutoModel(run_config['model_config'], run_config['gen_config']) 15 | # Initialize task 16 | task = 
JudgmentPrediction_Task(run_config['task_config'], model.tokenizer) 17 | 18 | Path(output_dir).mkdir(parents=True, exist_ok=True) 19 | with open(Path(output_dir) / 'run_config', 'w') as f: 20 | json.dump(run_config, f, indent = 4, ensure_ascii = False) 21 | 22 | evaluate_subtasks(model, task, output_dir = output_dir, sub_tasks = sub_tasks) 23 | 24 | if __name__ == '__main__': 25 | import argparse 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('--config', help = 'default config file path') 28 | parser.add_argument('--output_dir', help = 'path to save evaluation results') 29 | parser.add_argument('--model_type', help = 'openai or hf') 30 | parser.add_argument('--model', help = 'path of the model') 31 | parser.add_argument('--sub_tasks', help = 'subtask names saperated by comma', default = 'all') 32 | parser.add_argument('--speed', action = 'store_true', 33 | help = 'Improve inference speed by consuming more GPU memory. Recommended when GPU memory > 24G') 34 | 35 | args = parser.parse_args() 36 | 37 | run_config = json.load(open(args.config)) 38 | if args.model_type == 'hf': 39 | run_config['model_config']['base_model'] = args.model 40 | run_config['model_config']['save_memory'] = not args.speed 41 | elif args.model_type == 'openai': 42 | ... 43 | else: 44 | raise ValueError(f'model_type: {args.model_type}') 45 | 46 | main(run_config, args.output_dir, args.sub_tasks) 47 | -------------------------------------------------------------------------------- /scripts/parse_all_output.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from pathlib import Path 5 | import time 6 | from typing import List 7 | import numpy as np 8 | from sklearn.metrics import precision_recall_fscore_support 9 | sys.path.insert(0, str(Path(__file__).absolute().parents[1])) 10 | 11 | 12 | from lm_eval.parser import BM25_Parser 13 | from lm_eval.utils import read_jsonl, save_jsonl 14 | from lm_eval.tasks import JudgmentPrediction_Task 15 | 16 | all_run_dir = Path('runs/paper_version') 17 | 18 | def main(): 19 | label2id = json.load(open('/storage/rhshui/workspace/LM-CompEval-Legal/data_hub/ljp/charge2id_clean.json')) 20 | test_ds = [json.loads(k) for k in open('/storage/rhshui/workspace/LM-CompEval-Legal/data_hub/ljp/test_data.json')] 21 | grounds = [label2id[JudgmentPrediction_Task.convert_label(k['charge'])] for k in test_ds] 22 | parser = BM25_Parser(label2id) 23 | 24 | sub_tasks = [f'{tt}-{n}shot' for tt in ['free', 'multi'] for n in range(5)] 25 | 26 | for model in ['gpt4', 'chatgpt', 'bloomz_7b', 'chatglm_6b', 'vicuna_13b']: 27 | output_dir = all_run_dir / model 28 | metric_file = output_dir / 'eval_results.txt' 29 | for task_name in sub_tasks: 30 | subtask_dir = output_dir / task_name 31 | save_file = subtask_dir / 'raw_output.txt' 32 | parse_file = subtask_dir / 'parse_results.txt' 33 | if parse_file.exists(): 34 | continue 35 | # Parse 36 | print(f'Parse: {save_file}') 37 | outputs = read_jsonl(save_file) 38 | idx2out = {k['idx']: k['choices'] for k in outputs} 39 | finish_all = all([k['idx'] in idx2out for k in test_ds]) 40 | assert finish_all 41 | parse_results = [{'idx': k['idx'], 'pred': parser(k['choices'])} for k in outputs] 42 | save_jsonl(parse_results, parse_file) 43 | # Evaluate 44 | idx2pred = {k['idx']: label2id[k['pred']] for k in parse_results} 45 | preds = [idx2pred[k['idx']] for k in test_ds] 46 | metrics = get_ljp_metrics(preds, grounds) 47 | metrics = {k:float(v) for k,v in metrics.items()} # for 
serialization 48 | log = {'time': time.time(), 'subtask': task_name, 'metrics': metrics} 49 | print('Evaluation results:\n' + str(log)) 50 | with open(metric_file, 'a') as f: 51 | f.write(json.dumps(log) + '\n') 52 | 53 | 54 | def get_ljp_metrics(preds: List[int], targets: List[int]): 55 | preds = np.array(preds, dtype = np.int64) 56 | targets = np.array(targets, dtype = np.int64) 57 | acc = (preds == targets).astype(np.float32).mean() 58 | p,r,f1, _ = precision_recall_fscore_support(targets, preds, average = 'macro') 59 | metrics = {'acc': acc, 60 | 'precision': p, 61 | 'recall': r, 62 | 'f1': f1} 63 | return metrics 64 | 65 | if __name__ == '__main__': 66 | main() -------------------------------------------------------------------------------- /lm_eval/models/openai_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from environs import Env 3 | import openai 4 | from dataclasses import dataclass 5 | from typing import Optional, Union, List, Dict, Any 6 | 7 | @dataclass 8 | class OpenAI_Conf: 9 | """Model name and OpenAI generation parameters""" 10 | model: str 11 | num_output: int = 1 12 | max_tokens: Optional[int] = 30 13 | temperature: Optional[float] = 1.0 14 | stop: Union[str, List[str]] = '\n' 15 | 16 | class OpenAI_Model: 17 | """ 18 | Wrapper of the OpenAI chat and completion endpoints. 19 | 20 | Attributes: 21 | - endpoint: chat for chat models and complete for lm 22 | """ 23 | def __init__(self, config: Union[OpenAI_Conf, dict]): 24 | self.config = config if isinstance(config, OpenAI_Conf) else OpenAI_Conf(**config) 25 | self.model = self.config.model # read from the normalized config so dict configs also work 26 | self.endpoint = 'chat' if self.is_chat_endpoint(self.model) else 'complete' 27 | self.set_api_key() 28 | 29 | def set_api_key(self): 30 | env = Env() 31 | env.read_env() 32 | openai.api_key = os.environ.get("OPENAI_API_KEY") 33 | 34 | def get_gen_kws(self): 35 | """Return generation arguments in dict""" 36 | return dict( 37 | max_tokens = self.config.max_tokens, 38 | temperature = self.config.temperature, 39 | stop = self.config.stop 40 | ) 41 | 42 | def complete(self, prompt, n) -> List[str]: 43 | response = openai.Completion.create( 44 | model = self.model, 45 | prompt = prompt, 46 | n = n, 47 | **self.get_gen_kws() 48 | ) 49 | choices = [c.text for c in response.choices] 50 | return choices 51 | 52 | def _chat_complete(self, messages, n) -> List[str]: 53 | """Use the chat endpoint as a completion endpoint""" 54 | response = openai.ChatCompletion.create( 55 | model = self.model, 56 | messages= messages, 57 | n = n, 58 | **self.get_gen_kws() 59 | ) 60 | choices = [c['message']['content'] for c in response['choices']] 61 | return choices 62 | 63 | def chatcomplete(self, prompt: Union[str, List[dict]], n) -> List[str]: 64 | """ 65 | Args: 66 | prompt: a context or messages 67 | n: number of outputs 68 | """ 69 | if isinstance(prompt, str): 70 | messages = [{"role": "user", "content": prompt},] 71 | else: 72 | messages = prompt 73 | return self._chat_complete(messages, n) 74 | 75 | def generate(self, input_text, num_output: Optional[int] = None): 76 | num_output = num_output or self.config.num_output or 1 77 | if self.endpoint == 'chat': 78 | choices = self.chatcomplete(input_text, num_output) 79 | elif self.endpoint == 'complete': 80 | choices = self.complete(input_text, num_output) 81 | else: 82 | raise ValueError(f'endpoint: {self.endpoint}') 83 | return choices 84 | 85 | @staticmethod 86 | def is_chat_endpoint(model): 87 | return 'gpt' in model
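# Minimal usage sketch: assumes OPENAI_API_KEY is available via the environment or a .env file
# and relies on the legacy openai==0.27.8 interface pinned in requirements.txt; the model name
# and the prompt below are only illustrative examples.
if __name__ == '__main__':
    demo_model = OpenAI_Model({'model': 'gpt-3.5-turbo', 'num_output': 5, 'max_tokens': 30, 'temperature': 0.8})
    # The sampled choices can later be mapped to charge labels by lm_eval.parser.BM25_Parser
    print(demo_model.generate('根据中国刑法判断犯罪嫌疑人的罪名,直接输出罪名。\n案件事实: ...\n罪名: '))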
-------------------------------------------------------------------------------- /lm_eval/tasks/ljp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Generate prompts for legal judgment prediction. 3 | """ 4 | 5 | import re 6 | from pathlib import Path 7 | import pandas as pd 8 | import json 9 | import random 10 | import numpy as np 11 | import pickle 12 | from collections import Counter 13 | from transformers import PreTrainedTokenizer 14 | from tiktoken import Encoding 15 | from dataclasses import dataclass 16 | import jieba 17 | from typing import Union, List, Dict, Any 18 | 19 | from .base import TaskBase 20 | from lm_eval.utils import read_jsonl, read_json 21 | 22 | @dataclass 23 | class JudgmentPrediction_Config: 24 | data_path: str 25 | prompt_config_file: str 26 | query_max_len: int = 1000 27 | demo_max_len: int = 400 28 | 29 | class JudgmentPrediction_Task(TaskBase): 30 | """ 31 | Legal judgment prediction as a multi-class classification task. 32 | 33 | Subtasks: 34 | {free, multi}-{0..5}-shot 35 | - free is free generation, multi is multi-choice question 36 | - support zero-shot and few-shot 37 | 38 | Data Folder: 39 | - `test_data.json`: dict of fields of idx, fact, charge, sim_demo_idx, cdd_charge_list 40 | - `train_data.json`: dict of fields of idx, fact, charge 41 | - `charge2id.json`: mapping from charge names to charge label id 42 | 43 | Prompt Configuration: 44 | Stored as a dict of fields: 45 | - instruction_{free, multi}_{zs, fs}: instruction of four settings 46 | - demo_template: which has the slot of *input* and *output* 47 | Prompt: 48 | Build: + + + 49 | Note: 50 | - is present in multi-choice setting 51 | - is present in few-shot setting 52 | """ 53 | def __init__( 54 | self, 55 | config: Union[JudgmentPrediction_Config, dict], 56 | tokenizer: Union[PreTrainedTokenizer, Encoding] 57 | ): 58 | self.config = (config if isinstance(config,JudgmentPrediction_Config) 59 | else JudgmentPrediction_Config(**config)) 60 | 61 | self.prompt_config = read_json(self.config.prompt_config_file) 62 | data_dir = Path(self.config.data_path) 63 | 64 | train_ds = read_jsonl(data_dir / 'train_data.json') 65 | self.train_ds_map = {k['idx']:k for k in train_ds} 66 | 67 | self.test_ds = read_jsonl(data_dir / 'test_data.json') 68 | 69 | self.tokenizer = tokenizer 70 | 71 | label2id = read_json(data_dir / 'charge2id.json') 72 | self.label2id = {self.convert_label(k):v for k,v in label2id.items()} 73 | 74 | def get_task_data(self): 75 | return self.test_ds 76 | 77 | @staticmethod 78 | def convert_label(label: str): 79 | """Remove the square brackets in original charge names""" 80 | return re.sub(r'[\[\]]', '',label) 81 | 82 | def get_all_subtask(self): 83 | return [f'{t}-{i}shot' for t in ['free', 'multi'] for i in range(5)] 84 | 85 | def build_subtask(self, name)->List[Dict[str, Any]]: 86 | """ 87 | Return a list of example prompt data, each of which is a dict of idx and prompt 88 | """ 89 | # parse subtask name to get task type and demo number 90 | ttype, n_shot = name.split('-') 91 | n_shot = int(n_shot[0]) 92 | 93 | prompt_data = [ 94 | { 95 | 'idx': example['idx'], 96 | 'prompt': self.build_example_prompt(example, ttype, n_shot) 97 | } for example in self.test_ds 98 | ] 99 | return prompt_data 100 | 101 | def build_example_prompt(self, example, task_type, n_shot): 102 | # get instruction 103 | p_config = self.prompt_config 104 | shot_s = 'zs' if n_shot == 0 else 'fs' 105 | instruct = p_config[f'instruction_{task_type}_{shot_s}'] 106 | 107 | # 
determin label candidate list 108 | if task_type == 'multi': 109 | instruct += '\n' + '\n'.join(map(self.convert_label, example['cdd_charge_list'])) 110 | 111 | # add demonstrations 112 | all_demo_str = [] 113 | for demo_id in example['sim_demo_idx'][:n_shot]: 114 | demo_example = self.train_ds_map[demo_id] 115 | demo_str = p_config['demo_template'].format( 116 | input = self.cut_text(demo_example['fact'], self.config.demo_max_len), 117 | output = self.convert_label(demo_example['charge']) 118 | ) 119 | all_demo_str.append(demo_str) 120 | 121 | # add query example 122 | query_str = p_config['demo_template'].format( 123 | input = self.cut_text(example['fact'], self.config.query_max_len), 124 | output = '' 125 | ) 126 | 127 | # join all prompt components 128 | all_demo_str = '\n\n'.join(all_demo_str) 129 | if len(all_demo_str) > 0: 130 | all_demo_str = '\n\n' + all_demo_str 131 | prompt = instruct + all_demo_str + '\n\n' + query_str 132 | return prompt 133 | -------------------------------------------------------------------------------- /lm_eval/evaluate.py: -------------------------------------------------------------------------------- 1 | """Evaluate a model on subtasks, save results and metrics""" 2 | import os 3 | import numpy as np 4 | from pathlib import Path 5 | import json 6 | from tqdm import tqdm 7 | import traceback 8 | from sklearn.metrics import precision_recall_fscore_support 9 | import time 10 | from typing import List 11 | 12 | from lm_eval.models import HF_AutoModel 13 | from lm_eval.tasks import JudgmentPrediction_Task, TaskBase 14 | from lm_eval.utils import read_jsonl, save_jsonl 15 | from lm_eval.parser import BM25_Parser 16 | 17 | def perform_task_timely_save( 18 | tasks, map_func, save_file, 19 | id_field = 'idx', skip_error = True, simple_log = True 20 | ): 21 | """ 22 | Apply map_func on each task and save returned results to a file immediately. 23 | 24 | This function support resume after disruption by skipping finished tasks in save_file 25 | 26 | Args: 27 | - tasks: a list of task to be applied on by map_func 28 | - map_func: a callable function to apply on each task 29 | - save_file: file path to save map_func results 30 | - id_field: the field name to identify examples 31 | - skip_error: if error occurs in map_func, whether to skip the example 32 | - simple_log: True to print error message, otherwise print the traceback 33 | """ 34 | # Load previous finished tasks if exist 35 | if Path(save_file).exists(): 36 | prev_task = [json.loads(k) for k in open(save_file, encoding='utf8')] 37 | else: 38 | prev_task = [] 39 | prev_idx = set([k[id_field] for k in prev_task]) 40 | print(f'Previous finished: {len(prev_idx)}. 
Total: {len(tasks)}') 41 | 42 | # Keep unfinished tasks 43 | left_tasks = list(filter(lambda k:k[id_field] not in prev_idx, tasks)) 44 | 45 | # Perform tasks 46 | for sample in tqdm(left_tasks, ncols = 80): 47 | try: 48 | results = map_func(sample) 49 | # add id field 50 | if id_field not in results: 51 | results[id_field] = sample[id_field] 52 | # write results to a file 53 | with open(save_file, 'a', encoding='utf8') as f: 54 | f.write(json.dumps(results, ensure_ascii=False) + '\n') 55 | except Exception as e: 56 | if simple_log: 57 | err_str = str(e) 58 | else: 59 | err_str = traceback.format_exc() 60 | tqdm.write(f'Error {id_field}={sample[id_field]}, {err_str}') 61 | if not skip_error: 62 | exit() 63 | 64 | def evaluate_subtasks( 65 | model, 66 | task: JudgmentPrediction_Task, 67 | output_dir: str, 68 | sub_tasks: str = 'all' 69 | ): 70 | """ 71 | Args: 72 | - model: a HF_Model or Openai_Model 73 | - task: a TaskBase instance 74 | - output_dir: directory to save subtask results. Each subtask will create its own sub-directory. 75 | - sub_tasks: 'all' to evaluate on all sub_tasks, or subtask names joint by comma (,) 76 | 77 | Output dir structure: 78 | 79 | raw_output.txt 80 | ... 81 | eval_results.txt 82 | 83 | Each line of `eval_results.txt` is a json dict of time, subtask, metrics 84 | """ 85 | # Create output_dir 86 | output_dir = Path(output_dir) 87 | output_dir.mkdir(parents = True, exist_ok = True) 88 | metric_file = output_dir / 'eval_results.txt' 89 | 90 | # prepare some objects for evaluation 91 | test_ds = task.get_task_data() 92 | label2id = task.label2id 93 | grounds = [label2id[task.convert_label(k['charge'])] for k in test_ds] 94 | parser = BM25_Parser(label2id) 95 | 96 | if sub_tasks == 'all': 97 | sub_tasks = task.get_all_subtask() 98 | else: 99 | sub_tasks = sub_tasks.split(',') 100 | 101 | def ljp_handler(data): 102 | choices = model.generate(data['prompt']) 103 | return {'idx': data['idx'], 'choices': choices, 'prompt': data['prompt']} 104 | 105 | for task_name in sub_tasks: 106 | # Build task data 107 | print(f'Build {task_name} data') 108 | task_data = task.build_subtask(task_name) 109 | 110 | # Create subtask dir 111 | subtask_dir = output_dir / task_name 112 | subtask_dir.mkdir(exist_ok = True) 113 | save_file = subtask_dir / 'raw_output.txt' 114 | # Perform subtask 115 | perform_task_timely_save(task_data, ljp_handler, save_file, id_field = 'idx') 116 | 117 | # Evaluate subtask 118 | outputs = read_jsonl(save_file) 119 | idx2out = {k['idx']: k['choices'] for k in outputs} 120 | finish_all = all([k['idx'] in idx2out for k in test_ds]) 121 | if not finish_all: 122 | print(( 123 | 'Inference of some examples failed. Skip evaluation.\n' 124 | 'To finish all test examples, run the script again.\n' 125 | 'If the failure is due to limited resources, e.g., GPU Memory,' 126 | 'adjust some hyperparameters, e.g., max_len, and run the script again.' 
127 | )) 128 | continue 129 | # parse open generated text to pre-defined label names 130 | print('Save parsed results') 131 | parse_results = [{'idx': k['idx'], 'pred': parser(k['choices'])} for k in outputs] 132 | save_jsonl(parse_results, subtask_dir / 'parse_results.txt') 133 | 134 | # prepare for calculating metrics 135 | idx2pred = {k['idx']: label2id[k['pred']] for k in parse_results} 136 | preds = [idx2pred[k['idx']] for k in test_ds] 137 | metrics = get_ljp_metrics(preds, grounds) 138 | metrics = {k:float(v) for k,v in metrics.items()} # for serialization 139 | log = {'time': time.time(), 'subtask': task_name, 'metrics': metrics} 140 | print('Evaluation results:\n' + str(log)) 141 | with open(metric_file, 'a') as f: 142 | f.write(json.dumps(log) + '\n') 143 | 144 | def get_ljp_metrics(preds: List[int], targets: List[int]): 145 | preds = np.array(preds, dtype = np.int64) 146 | targets = np.array(targets, dtype = np.int64) 147 | acc = (preds == targets).astype(np.float32).mean() 148 | p,r,f1, _ = precision_recall_fscore_support(targets, preds, average = 'macro') 149 | metrics = {'acc': acc, 150 | 'precision': p, 151 | 'recall': r, 152 | 'f1': f1} 153 | return metrics -------------------------------------------------------------------------------- /lm_eval/models/hf_model.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch 3 | from transformers import LlamaConfig, AutoModelForCausalLM, AutoModelForSeq2SeqLM,AutoTokenizer, PreTrainedModel, GenerationConfig, AutoConfig 4 | from peft import PeftModel, PeftConfig, AutoPeftModelForCausalLM 5 | from dataclasses import dataclass, asdict 6 | from typing import Optional, Union, List, Dict, Any 7 | 8 | @dataclass 9 | class HF_Model_Config: 10 | """ 11 | Class that holds arguments to build transformers model. 12 | Args: 13 | base_model: the name or path of the base model 14 | peft_model: model_id of peft model 15 | trust_remote_code: should be true if using third-party models 16 | device_map: default to 'auto' to distribute parameters among all gpus 17 | save_memory: If true, generate one output each time. 
Otherwise decode multiple outputs together 18 | """ 19 | base_model: str = None 20 | peft_model: str = None 21 | trust_remote_code: bool = False 22 | torch_dtype: Union[str, torch.dtype] = torch.float16 23 | device_map: Union[str, Dict] = 'auto' 24 | save_memory: bool = True 25 | 26 | def __post_init__(self): 27 | # infer base_model from peft config if not specified 28 | if not (self.base_model or self.peft_model): 29 | raise ValueError(f'Please specify at least one of `base_model` or `peft_model`') 30 | if not self.base_model: 31 | self.base_model = PeftConfig.from_pretrained(self.peft_model).base_model_name_or_path 32 | 33 | # convert torch_dtype 34 | if not isinstance(self.torch_dtype, torch.dtype): 35 | if self.torch_dtype == 'float16': 36 | self.torch_dtype = torch.float16 37 | else: 38 | self.torch_dtype = 'auto' 39 | 40 | @dataclass 41 | class HF_Gen_Conf: 42 | """LM generation arguments similar to transformers.GenerationConfig""" 43 | num_return_sequences: Optional[int] = 1 44 | max_new_tokens: Optional[int] = 30 45 | # sampling strategy 46 | do_sample: bool = True 47 | temperature: float = 1.0 48 | top_k: Optional[int] = None 49 | top_p: Optional[float] = None 50 | repetition_penalty: Optional[float] = None 51 | 52 | pad_token_id: Optional[int] = None 53 | eos_token_id: Optional[int] = None 54 | 55 | def to_dict(self): 56 | return {k: deepcopy(v) for k,v in asdict(self).items() if v is not None} 57 | 58 | class HF_AutoModel: 59 | """ 60 | Initialize a huggingface model and handle generation. 61 | 62 | Attributes: 63 | - is_encoder_decoder: If true, load with Seq2SeqLM and 64 | do not remove prefix of generate results. 65 | """ 66 | def __init__( 67 | self, 68 | config: Union[HF_Model_Config, dict], 69 | gen_config: Union[HF_Gen_Conf, dict] = {} 70 | ): 71 | self.config = config if isinstance(config, HF_Model_Config) else HF_Model_Config(**config) 72 | self.gen_config = gen_config if isinstance(gen_config, HF_Gen_Conf) else HF_Gen_Conf(**gen_config) 73 | 74 | self.init_tokenizer() 75 | 76 | self._model = None 77 | # the model will be initialized when accessed 78 | 79 | def init_tokenizer(self): 80 | config = self.config 81 | self.tokenizer = AutoTokenizer.from_pretrained( 82 | config.base_model, trust_remote_code = config.trust_remote_code 83 | ) 84 | # update generation config 85 | # fix pad_token_id issue: https://github.com/huggingface/transformers/issues/25353 86 | self.gen_config.eos_token_id = self.tokenizer.eos_token_id 87 | self.gen_config.pad_token_id = ( 88 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None 89 | else self.tokenizer.eos_token_id 90 | ) 91 | 92 | @property 93 | def model(self): 94 | if self._model is None: 95 | self.init_model() 96 | return self._model 97 | 98 | def init_model(self): 99 | """Determin is_encoder_decoder, Load base model and peft model if any""" 100 | config = self.config 101 | kws = self.get_init_kws() 102 | print(f'Initialize base model: {config.base_model}') 103 | # determin is_encoder_decoder 104 | hf_cfg = AutoConfig.from_pretrained(config.base_model, trust_remote_code = config.trust_remote_code) 105 | self.is_encoder_decoder = hf_cfg.is_encoder_decoder 106 | # initialize CausalLM or Seq2SeqLM 107 | if self.is_encoder_decoder: 108 | auto_cls = AutoModelForSeq2SeqLM 109 | else: 110 | has_auto_map = hasattr(hf_cfg, "auto_map") 111 | if not has_auto_map or 'AutoModelForCausalLM' in hf_cfg.auto_map: 112 | auto_cls = AutoModelForCausalLM 113 | elif 'AutoModelForSeq2SeqLM' in hf_cfg.auto_map: 114 | # to handle prefix-lm 
e.g. chatglm 115 | auto_cls = AutoModelForSeq2SeqLM 116 | else: 117 | raise ValueError(f'Cannot determin the Auto class') 118 | base_model = auto_cls.from_pretrained(config.base_model, **kws) 119 | # load peft model 120 | if config.peft_model: 121 | print(f'Load peft model from {config.peft_model}') 122 | model = PeftModel.from_pretrained(base_model, config.peft_model) 123 | else: 124 | model = base_model 125 | self._base_model = base_model 126 | self._model = model 127 | 128 | def reload_peft_model(self, model_id): 129 | """Reload another peft model""" 130 | # update model config 131 | self.config.peft_model = model_id 132 | if self._model is None: 133 | self.init_model() 134 | else: 135 | print(f'Reload peft model from {model_id}') 136 | self._model = PeftModel.from_pretrained( 137 | self._base_model, model_id, **self.get_init_kws() 138 | ) 139 | 140 | def get_init_kws(self): 141 | config = self.config 142 | return dict( 143 | trust_remote_code = config.trust_remote_code, 144 | torch_dtype = config.torch_dtype, 145 | device_map = config.device_map 146 | ) 147 | 148 | def generate(self, input_text: str, num_output: Optional[int] = None, return_ids = False)->List[str]: 149 | """Generate conditioned on one input""" 150 | num_output = num_output or self.gen_config.num_return_sequences or 1 151 | # tokenize 152 | enc = self.tokenizer([input_text], return_tensors = 'pt') 153 | inputs = {k:v.cuda() for k,v in enc.items()} 154 | 155 | # generation config 156 | hf_gen_cfg = GenerationConfig(**self.gen_config.to_dict()) 157 | # kws = self.get_generation_kws() 158 | if self.config.save_memory: 159 | hf_gen_cfg.num_return_sequences = 1 160 | # kws['num_return_sequences'] = 1 161 | output_ids = [self.model.generate(generation_config = hf_gen_cfg, **inputs)[0] for _ in range(num_output)] 162 | else: 163 | hf_gen_cfg.num_return_sequences = num_output 164 | # kws['num_return_sequences'] = num_output 165 | output_ids = self.model.generate(generation_config = hf_gen_cfg, **inputs) 166 | 167 | # print(output_ids) 168 | pure_output_ids = [self.remove_prefix(k, inputs['input_ids']) for k in output_ids] 169 | choices = self.tokenizer.batch_decode(pure_output_ids, skip_special_tokens=True) 170 | 171 | return (choices, output_ids) if return_ids else choices 172 | 173 | def get_hf_generation_config(self): 174 | arg_names = ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'top_k', 'repetition_penalty'] 175 | kws = {k:v for k,v in self.gen_config.to_dict().items() if k in arg_names} 176 | return GenerationConfig(**kws) 177 | 178 | def get_generation_kws(self): 179 | arg_names = ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'top_k', 'repetition_penalty'] 180 | kws = {k:v for k,v in self.gen_config.to_dict().items() if k in arg_names} 181 | return kws 182 | 183 | def remove_prefix(self, tensor, input_ids): 184 | """Remove the input ids from the output ids""" 185 | if self.is_encoder_decoder: 186 | return tensor 187 | else: 188 | return tensor[len(input_ids[0]):] -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 |

# ⚖️ A Comprehensive Evaluation of LLMs on Legal Judgment Prediction

7 | 8 |
9 | 10 | ![](https://img.shields.io/badge/Task-Text%20Classification-orange) 11 | ![](https://img.shields.io/badge/Code%20License-MIT-green) 12 |
13 | 14 |

15 | 16 | [📜 Paper] • 17 | [🐱 GitHub] 18 |
19 | Quick Start • 20 | Citation 21 |

22 | 23 |

24 | Repo for "A Comprehensive Evaluation of Large Language Models on Legal Judgment Prediction"
published at EMNLP Findings 2023 25 |

26 | 27 | ## 💡 Introduction 28 | To comprehensively evaluate the legal capability of large language models, we propose baseline solutions and evaluate them on the task of *legal judgment prediction*. 29 | 30 | **Motivation** 31 | Existing benchmarks, e.g., [lm_eval_harness](https://github.com/EleutherAI/lm-evaluation-harness), mainly adopt a **perplexity**-based approach that selects the most probable option as the prediction for classification tasks. *However*, LMs typically interact with humans through open-ended generation, so it is critical to directly evaluate the contents generated by greedy decoding or sampling. 32 | 33 | **Evaluation on LM Generated Contents** 34 | We propose an automatic evaluation pipeline that directly evaluates the generated contents for **classification** tasks. 35 | 1. Prompt LMs with the task instruction to generate class labels. The generated contents may not strictly match the standard label names. 36 | 2. A parser then maps the generated contents to labels based on text similarity scores (see the sketch after the setting list below). 37 | 38 | **LM + Retrieval System** 39 | 40 | To assess how LMs perform with retrieved information in the legal domain, additional information, e.g., label candidates and similar cases as demonstrations, is included in the prompts. Combining these two kinds of additional information yields four prompt sub-settings: 41 | - (**free, zero shot**): No additional information. Only the task instruction. 42 | - (**free, few shot**): Task instruction + demonstrations 43 | - (**multi, zero shot**): Task instruction + label candidates (options) 44 | - (**multi, few shot**): Task instruction + label candidates + demonstrations 45 | 46 |
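The sketch below illustrates step 2 with the BM25 matching idea used in `lm_eval/parser.py`; the label set and the generated sentence are toy examples made up for illustration (the real charge names come from `data_hub/ljp/charge2id.json`).

```Python
import jieba
import numpy as np
from rank_bm25 import BM25Okapi

# Toy label set; the real charge names are loaded from data_hub/ljp/charge2id.json
labels = ['盗窃', '抢劫', '故意伤害']
bm25 = BM25Okapi([list(jieba.cut(lab, cut_all=True)) for lab in labels])

# A free-form model output that does not exactly match any label name
generation = '该犯罪嫌疑人的行为构成盗窃罪'
scores = bm25.get_scores(list(jieba.cut(generation, cut_all=True)))
print(labels[int(np.argmax(scores))])  # expected: 盗窃
```

When several outputs are sampled for one case, `BM25_Parser` averages the BM25 scores over all outputs before taking the arg-max.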

![Examples of the four prompt sub-settings](resources/static/fig_setting_example_v2.jpg)

## 🔥 Leaderboard

| rank | model | score | free-0shot | free-1shot | free-2shot | free-3shot | free-4shot | multi-0shot | multi-1shot | multi-2shot | multi-3shot | multi-4shot |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | gpt4 | 63.05 | 50.52 | 62.72 | 67.54 | 68.61 | 71.02 | 62.31 | 70.42 | 71.81 | 73.24 | 74.00 |
| 2 | chatgpt | 58.13 | 43.14 | 58.42 | 61.86 | 64.40 | 66.16 | 60.67 | 63.51 | 66.85 | 69.59 | 66.62 |
| 3 | chatglm_6b | 47.74 | 41.89 | 50.30 | 47.76 | 48.59 | 48.67 | 53.74 | 49.26 | 47.56 | 47.61 | 45.32 |
| 4 | bloomz_7b | 44.14 | 46.90 | 53.28 | 51.06 | 50.90 | 49.26 | 50.68 | 29.25 | 27.92 | 25.27 | 23.37 |
| 5 | vicuna_13b | 39.83 | 25.50 | 48.85 | 47.64 | 49.49 | 39.82 | 44.70 | 41.73 | 41.48 | 35.03 | 21.61 |
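
As a worked example of the `score` column (see the note below for its definition), gpt4 obtains $(50.52 + 67.54 + 62.31 + 71.81)/4 = 63.05$.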
150 | 151 | Note: 152 | - **Metric**: Macro-F1 153 | - $score = (free\text{-}0shot + free\text{-}2shot + multi\text{-}0shot + multi\text{-}2shot)/4$ 154 | - OpenAI model names: *gpt-3.5-turbo-0301*, *gpt-4-0314* 155 | 156 | ## 🚀 Quick Start 157 | ### ⚙️ Install 158 | ```Bash 159 | git clone https://github.com/srhthu/LM-CompEval-Legal.git 160 | 161 | # Enter the repo 162 | cd LM-CompEval-Legal 163 | 164 | pip install -r requirements.txt 165 | 166 | bash download_data.sh 167 | # Download evaluation dataset to data_hub/ljp 168 | # Download model generated results to runs/paper_version 169 | ``` 170 | The data is availabel at [Google Drive](https://drive.google.com/drive/folders/1UWi9F4vtBORsnCUqDDkyehdH4IyDfnME?usp=share_link) 171 | 172 | ### Evaluate Models 173 | 174 | There are totally 10 `sub_tasks`: `{free,multi}-{0..4}`. 175 | 176 | Evaluate a **Huggingface** model on all sub_tasks: 177 | ```Bash 178 | CUDA_VISIBLE_DEVICES=0 python main.py \ 179 | --config ./config/default_hf.json \ 180 | --output_dir ./runs/test/ \ 181 | --model_type hf \ 182 | --model 183 | ``` 184 | 185 | Evaluate a **OpenAI** model on all sub_tasks: 186 | ```Bash 187 | CUDA_VISIBLE_DEVICES=0 python main.py \ 188 | --config ./config/default_openai.json \ 189 | --output_dir ./runs/test/ \ 190 | --model_type openai \ 191 | --model 192 | ``` 193 | 194 | To evaluate some of the whole settings, add one more argument, e.g., 195 | ```Bash 196 | --sub_tasks 'free-0shot,free-2shot,multi-0shot,multi-2shot' 197 | ``` 198 | 199 | The huggingface paths of the evaluated models in the paper are 200 | - ChatGLM: `THUDM/chatglm-6b` 201 | - BLOOMZ: `bigscience/bloomz-7b1-mt` 202 | - Vicuna: `lmsys/vicuna-13b-delta-v1.1` 203 | 204 | **Features**: 205 | > - If the evaluation process is interupted, just run it again with the same parameters. The process saves model outputs immediately and will skip previous finished samples when resuming. 206 | > - Samples that trigger a GPU out-of-memory error will be skipped. You can change the configurations and run the process again. (See suggested GPU configurations below) 207 | 208 | **Suggested GPU configurations** 209 | - 7B model 210 | - 1 GPU with RAM around 24G (RTX 3090, A5000) 211 | - If total RAM >=**32G**, e.g., 2\*RTX3090 or 1\*V100(32G), add the `--speed` argument for faster inference. 212 | - 13B model 213 | - 2 GPU with RAM >= 24G (e.g., 2\*V100) 214 | - If total RAM>=**64G**, e.g., 3\*RTX3090 or 2\*V100, add the `--speed` argument for faster inference 215 | > When context is long, e.g., in multi-4shot setting, 1 GPU of 24G RAM may be insufficient for 7B model. You have to eigher increase the number of GPUs or decrease the demonstration length (default to 500) by modifying the *demo_max_len* parameter in `config/default_hf.json` 216 | 217 | ### Create Result table 218 | After evaluating some models locally, the leaderboard can be generated in csv format: 219 | 220 | ```Bash 221 | python scripts/get_result_table.py \ 222 | --exp_dir runs/paper_version \ 223 | --metric f1 \ 224 | --save_path resources/paper_version_f1.csv 225 | ``` 226 | 227 | ## Citation 228 | 229 | ``` 230 | @misc{shui2023comprehensive, 231 | title={A Comprehensive Evaluation of Large Language Models on Legal Judgment Prediction}, 232 | author={Ruihao Shui and Yixin Cao and Xiang Wang and Tat-Seng Chua}, 233 | year={2023}, 234 | eprint={2310.11761}, 235 | archivePrefix={arXiv}, 236 | primaryClass={cs.CL} 237 | } 238 | ``` 239 | 240 | 264 | --------------------------------------------------------------------------------