├── .env.example
├── .venv
│   ├── lib64
│   ├── bin
│   │   ├── python
│   │   ├── python3.10
│   │   └── python3
│   └── pyvenv.cfg
├── financerag
│   ├── generate
│   │   ├── __init__.py
│   │   └── openai.py
│   ├── rerank
│   │   ├── __init__.py
│   │   └── cross_encoder.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   └── protocols.py
│   ├── retrieval
│   │   ├── __init__.py
│   │   ├── sent_encoder.py
│   │   ├── bm25.py
│   │   └── dense.py
│   ├── __init__.py
│   └── tasks
│       ├── __init__.py
│       ├── FinDERTask.py
│       ├── FinQABenchTask.py
│       ├── FinQATask.py
│       ├── ConvFinQATask.py
│       ├── FinanceBenchTask.py
│       ├── MultiHierttTask.py
│       ├── TATQATask.py
│       ├── TaskMetadata.py
│       └── BaseTask.py
├── src
│   ├── main.py
│   └── __init__.py
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── modules.xml
│   └── linq-rag-FinanceRAG.iml
├── activate.sh
├── requirements.txt
├── activate.bat
├── LICENSE
├── README.md
├── GAR
│   ├── run_hybrid_search.py
│   ├── Selection_agent.py
│   ├── main.py
│   ├── run_finetuning.py
│   ├── Finetuned_DPO_agent.py
│   ├── DPO_agent.py
│   ├── Embedder_Finetuning.py
│   ├── hybrid_search.py
│   └── DPO_agent_Finetuning.py
└── .gitignore
/.env.example: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.venv/lib64: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.venv/bin/python: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.venv/pyvenv.cfg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.venv/bin/python3.10: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.venv/bin/python3: -------------------------------------------------------------------------------- 1 | /usr/bin/python3 -------------------------------------------------------------------------------- /financerag/generate/__init__.py: -------------------------------------------------------------------------------- 1 | from .openai import OpenAIGenerator 2 | -------------------------------------------------------------------------------- /financerag/rerank/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_encoder import CrossEncoderReranker 2 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | def main(): 2 | """ 3 | Main function 4 | """ 5 | print("GAR Project Started") 6 | 7 | if __name__ == "__main__": 8 | main() -------------------------------------------------------------------------------- /financerag/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .loader import HFDataLoader 2 | from .protocols import CrossEncoder, Encoder, Generator, Lexical, Reranker, Retrieval 3 | -------------------------------------------------------------------------------- /financerag/retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | from .bm25 import BM25Retriever 2 | from .dense import DenseRetrieval 3 | from 
.sent_encoder import SentenceTransformerEncoder 4 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | # Package initialization file 2 | 3 | # Version information 4 | __version__ = '0.1.0' 5 | 6 | # Expose the main functions at the package level 7 | from .main import main 8 | 9 | # Define the public functions 10 | __all__ = ['main'] -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /financerag/__init__.py: -------------------------------------------------------------------------------- 1 | from .retrieval import ( 2 | DenseRetrieval, 3 | BM25Retriever, 4 | SentenceTransformerEncoder 5 | ) 6 | 7 | from .rerank import ( 8 | CrossEncoderReranker 9 | ) 10 | 11 | from .generate import ( 12 | OpenAIGenerator 13 | ) -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/linq-rag-FinanceRAG.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /financerag/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .BaseTask import BaseTask 2 | from .ConvFinQATask import ConvFinQA 3 | from .FinDERTask import FinDER 4 | from .FinQABenchTask import FinQABench 5 | from .FinQATask import FinQA 6 | from .FinanceBenchTask import FinanceBench 7 | from .MultiHierttTask import MultiHiertt 8 | from .TATQATask import TATQA 9 | from .TaskMetadata import TaskMetadata 10 | -------------------------------------------------------------------------------- /activate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! -d "venv" ]; then 4 | echo "Creating virtual environment..." 5 | python3 -m venv venv 6 | source venv/bin/activate 7 | pip install -r requirements.txt 8 | else 9 | source venv/bin/activate 10 | fi 11 | 12 | 13 | set -a 14 | source .env 15 | set +a 16 | 17 | echo "Virtual environment activated and environment variables set!" 
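# Usage note: run this script with `source activate.sh` (not `./activate.sh`) so that
# the activated virtual environment and the variables exported from .env persist in
# the calling shell; executing it directly applies them only inside a child shell
# that is discarded when the script exits.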
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | datasets~=3.0.0 2 | pydantic~=2.9.2 3 | pytrec_eval~=0.5 4 | typing_extensions~=4.12.2 5 | numpy==1.26.4 6 | torch~=2.2.2 7 | torchvision 8 | nltk~=3.9.1 9 | sentence-transformers==3.1.1 10 | transformers==4.45.2 11 | openai~=1.47.1 12 | pandas==2.2.2 13 | tqdm==4.67.1 14 | python-dotenv 15 | scikit-learn 16 | huggingface_hub 17 | rank_bm25 18 | flash-attn -------------------------------------------------------------------------------- /activate.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set VENV_PATH=C:\Users\shw41\GAR\venv 4 | 5 | if exist "%VENV_PATH%" ( 6 | echo Found existing .venv at %VENV_PATH% 7 | call %VENV_PATH%\Scripts\activate.bat 8 | ) else ( 9 | echo Virtual environment .venv not found! 10 | echo Creating virtual environment... 11 | python -m venv %VENV_PATH% 12 | call %VENV_PATH%\Scripts\activate.bat 13 | pip install -r requirements.txt 14 | ) 15 | 16 | set PYTHONPATH=%PYTHONPATH%;%CD% 17 | set VIRTUAL_ENV=%VENV_PATH% 18 | set PATH=%VIRTUAL_ENV%\Scripts;%PATH% 19 | 20 | echo Virtual environment activated and environment variables set! -------------------------------------------------------------------------------- /financerag/tasks/FinDERTask.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from .BaseTask import BaseTask 4 | from .TaskMetadata import TaskMetadata 5 | 6 | 7 | class FinDER(BaseTask): 8 | def __init__(self): 9 | self.metadata: TaskMetadata = TaskMetadata( 10 | name="FinDER", 11 | description="Prepared for competition from Linq", 12 | reference=None, 13 | dataset={ 14 | "path": "Linq-AI-Research/FinanceRAG", 15 | "subset": "FinDER", 16 | }, 17 | type="RAG", 18 | category="s2p", 19 | modalities=["text"], 20 | date=None, 21 | domains=["Report"], 22 | task_subtypes=[ 23 | "Financial retrieval", 24 | "Question answering", 25 | ], 26 | license=None, 27 | annotations_creators="expert-annotated", 28 | dialect=[], 29 | sample_creation="human-generated", 30 | bibtex_citation=None, 31 | ) 32 | super().__init__(self.metadata) 33 | 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 JiH00nKw0n 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /financerag/tasks/FinQABenchTask.py: -------------------------------------------------------------------------------- 1 | from .BaseTask import BaseTask 2 | from .TaskMetadata import TaskMetadata 3 | 4 | 5 | class FinQABench(BaseTask): 6 | def __init__(self): 7 | self.metadata: TaskMetadata = TaskMetadata( 8 | name="FinQABench", 9 | description="FinQABench: A New QA Benchmark for Finance applications", 10 | reference="https://huggingface.co/datasets/lighthouzai/finqabench", 11 | dataset={ 12 | "path": "Linq-AI-Research/FinanceRAG", 13 | "subset": "FinQABench", 14 | }, 15 | type="RAG", 16 | category="s2p", 17 | modalities=["text"], 18 | date=None, 19 | domains=["Report"], 20 | task_subtypes=[ 21 | "Financial retrieval", 22 | "Question answering", 23 | ], 24 | license="apache-2.0", 25 | annotations_creators="LM-generated and reviewed", 26 | dialect=[], 27 | sample_creation="LM-generated and verified", 28 | bibtex_citation=None, 29 | ) 30 | super().__init__(self.metadata) 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GAR 2 | 3 | This project is a RAG-based (Retrieval-Augmented Generation) answer generation and refinement pipeline specifically designed for question answering in the finance domain. 4 | 5 | ## Project Structure 6 | 7 | GAR/ 8 | ├── GAR/ # Hybrid Search implementation 9 | │ ├── Embedder_Finetuning.py 10 | │ ├── hybrid_search.py 11 | │ ├── run_finetuning.py 12 | │ ├── run_hybrid_search.py 13 | │ └── main.py # Main pipeline execution 14 | └── requirements.txt # List of required packages 15 | 16 | 17 | ## Pipeline Description 18 | - `Embedder_Finetuning.py`: Fine-tunes the retrieval model 19 | - `hybrid_search.py`: Runs hybrid search and finds the optimal alpha for each task 20 | - `run_finetuning.py`: Entry point that runs Embedder_Finetuning.py 21 | - `run_hybrid_search.py`: Entry point that runs hybrid_search.py 22 | 23 | ## Installation 24 | 25 | 1. Clone the repository 26 | git clone https://github.com/seohyunwoo-0407/GAR.git 27 | cd GAR 28 | 29 | 30 | 2. Create and activate a virtual environment 31 | python -m venv .venv 32 | source .venv/bin/activate # Linux/Mac 33 | .venv\Scripts\activate # Windows 34 | 35 | 3. 
Install required packages 36 | pip install -r requirements.txt 37 | -------------------------------------------------------------------------------- /GAR/run_hybrid_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | import nltk 5 | from financerag.tasks import FinDER, FinQABench, FinanceBench, TATQA, FinQA, ConvFinQA, MultiHiertt 6 | from financerag.retrieval import DenseRetrieval, SentenceTransformerEncoder 7 | from sentence_transformers import SentenceTransformer, CrossEncoder 8 | import pandas as pd 9 | from hybrid_search import HybridSearcher 10 | 11 | logging.basicConfig(level=logging.INFO) 12 | 13 | nltk.download('punkt') 14 | hybrid_searcher = HybridSearcher() 15 | 16 | tasks, names = hybrid_searcher.setup_task() 17 | 18 | retrieval_model = hybrid_searcher.retrieval_model_setup() 19 | reranker_model = hybrid_searcher.reranker_model_setup() 20 | 21 | output_dir = "C:/Users/shw41/GAR/data/task1" 22 | 23 | for task, name in zip(tasks, names): 24 | optimal_alpha = hybrid_searcher.tune_alpha(task, retrieval_model) 25 | hybrid_retrieval_results = hybrid_searcher.get_hybrid_score(task, optimal_alpha, retrieval_model) 26 | reranked_results = hybrid_searcher.get_reranker_score(task, hybrid_retrieval_results, reranker_model) 27 | ndcg_score = hybrid_searcher.get_final_ndcg(tasks, names) 28 | 29 | print(f"Task: {name}, NDCG Score: {ndcg_score}") 30 | print(f"Optimal Alpha: {optimal_alpha}") 31 | task.save_results(output_dir, name, reranked_results) 32 | 33 | hybrid_searcher.merge_csv_results(output_dir) 34 | -------------------------------------------------------------------------------- /GAR/Selection_agent.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import os 3 | from dotenv import load_dotenv 4 | import json 5 | load_dotenv() 6 | client = openai.OpenAI( 7 | api_key = os.getenv('OPENAI_API_KEY') 8 | ) 9 | model_name = 'gpt-4o-mini' 10 | def load_jsonl(path): 11 | with open(path, 'r', encoding='utf-8') as file: 12 | data = [json.loads(line) for line in file] 13 | return data 14 | def save_to_json(data, path): 15 | with open(path, 'w', encoding='utf-8') as file: 16 | for i in data: 17 | file.write(json.dumps(i, ensure_ascii=False) + '\n') 18 | def evaluate_corpus_relevance(query_line, corpus_list): 19 | prompt = f"""Given the following query and 10 corpus passages, determine if each corpus is relevant for answering the query. 20 | Respond ONLY with 10 letters (T or F) in sequence, without any additional text. 21 | T means the corpus is relevant, F means it is not relevant. 
22 | Query: {query_line} 23 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)]) 24 | response = client.chat.completions.create( 25 | model=model_name, 26 | messages=[{"role": "user", "content": prompt}], 27 | temperature=0 28 | ) 29 | relevance = dict(response)['choices'][0].message.content 30 | selected_corpus = [corpus for corpus, is_relevant in zip(corpus_list, relevance) if is_relevant == 'T'] 31 | combined_corpus = "\n".join([f"###DOCUMENT_{i+1}\n{corpus}\n" for i, corpus in enumerate(selected_corpus)]) 32 | print(f"relevance: {relevance}") 33 | return combined_corpus -------------------------------------------------------------------------------- /financerag/tasks/FinQATask.py: -------------------------------------------------------------------------------- 1 | from .BaseTask import BaseTask 2 | from .TaskMetadata import TaskMetadata 3 | 4 | 5 | class FinQA(BaseTask): 6 | def __init__(self): 7 | self.metadata: TaskMetadata = TaskMetadata( 8 | name="FinQA", 9 | description="FinQA: A Dataset of Numerical Reasoning over Financial Data", 10 | reference="https://github.com/czyssrs/FinQA", 11 | dataset={ 12 | "path": "Linq-AI-Research/FinanceRAG", 13 | "subset": "FinQA", 14 | }, 15 | type="RAG", 16 | category="s2p", 17 | modalities=["text"], 18 | date=None, 19 | domains=["Report"], 20 | task_subtypes=[ 21 | "Financial retrieval", 22 | "Question answering", 23 | ], 24 | license="mit", 25 | annotations_creators="expert-annotated", 26 | dialect=[], 27 | sample_creation="human-generated", 28 | bibtex_citation=""" 29 | @article{chen2021finqa, 30 | title={FinQA: A Dataset of Numerical Reasoning over Financial Data}, 31 | author={Chen, Zhiyu and Chen, Wenhu and Smiley, Charese and Shah, Sameena and Borova, Iana and Langdon, Dylan and Moussa, Reema and Beane, Matt and Huang, Ting-Hao and Routledge, Bryan and Wang, William Yang}, 32 | journal={Proceedings of EMNLP 2021}, 33 | year={2021} 34 | } 35 | """, 36 | ) 37 | super().__init__(self.metadata) 38 | -------------------------------------------------------------------------------- /financerag/tasks/ConvFinQATask.py: -------------------------------------------------------------------------------- 1 | from .BaseTask import BaseTask 2 | from .TaskMetadata import TaskMetadata 3 | 4 | 5 | class ConvFinQA(BaseTask): 6 | def __init__(self): 7 | self.metadata: TaskMetadata = TaskMetadata( 8 | name="ConvFinQA", 9 | description="ConvFinQA: Exploring the Chain of Numerical Reasoning in Conversational Finance Question Answering", 10 | reference="https://github.com/czyssrs/ConvFinQA", 11 | dataset={ 12 | "path": "Linq-AI-Research/FinanceRAG", 13 | "subset": "ConvFinQA", 14 | }, 15 | type="RAG", 16 | category="s2p", 17 | modalities=["text"], 18 | date=None, 19 | domains=["Report"], 20 | task_subtypes=[ 21 | "Financial retrieval", 22 | "Question answering", 23 | ], 24 | license="mit", 25 | annotations_creators="expert-annotated", 26 | dialect=[], 27 | sample_creation="human-generated", 28 | bibtex_citation=""" 29 | @article{chen2022convfinqa, 30 | title={ConvFinQA: Exploring the Chain of Numerical Reasoning in Conversational Finance Question Answering}, 31 | author={Chen, Zhiyu and Li, Shiyang and Smiley, Charese and Ma, Zhiqiang and Shah, Sameena and Wang, William Yang}, 32 | journal={Proceedings of EMNLP 2022}, 33 | year={2022} 34 | } 35 | """, 36 | ) 37 | super().__init__(self.metadata) 38 | -------------------------------------------------------------------------------- /financerag/tasks/FinanceBenchTask.py: 
-------------------------------------------------------------------------------- 1 | from .BaseTask import BaseTask 2 | from .TaskMetadata import TaskMetadata 3 | 4 | 5 | class FinanceBench(BaseTask): 6 | def __init__(self): 7 | self.metadata: TaskMetadata = TaskMetadata( 8 | name="FinanceBench", 9 | description="FinanceBench: A New Benchmark for Financial Question Answering", 10 | reference="https://github.com/patronus-ai/financebench", 11 | dataset={ 12 | "path": "Linq-AI-Research/FinanceRAG", 13 | "subset": "FinanceBench", 14 | }, 15 | type="RAG", 16 | category="s2p", 17 | modalities=["text"], 18 | date=None, 19 | domains=["Report"], 20 | task_subtypes=[ 21 | "Financial retrieval", 22 | "Question answering", 23 | ], 24 | license=None, 25 | annotations_creators="expert-annotated", 26 | dialect=[], 27 | sample_creation="human-generated", 28 | bibtex_citation=""" 29 | @misc{islam2023financebench, 30 | title={FinanceBench: A New Benchmark for Financial Question Answering}, 31 | author={Pranab Islam and Anand Kannappan and Douwe Kiela and Rebecca Qian and Nino Scherrer and Bertie Vidgen}, 32 | year={2023}, 33 | eprint={2311.11944}, 34 | archivePrefix={arXiv}, 35 | primaryClass={cs.CL} 36 | } 37 | """, 38 | ) 39 | super().__init__(metadata=self.metadata) -------------------------------------------------------------------------------- /GAR/main.py: -------------------------------------------------------------------------------- 1 | # main.py 2 | import os 3 | import pandas as pd 4 | from Selection_agent import SelectionAgent 5 | from DPO_agent import DPO_Agent 6 | from DPO_agent_Finetuning import dpo_agent_finetuning 7 | from Finetuned_DPO_agent import Finetuned_DPO_agent 8 | from openai import OpenAI 9 | def ensure_dir(file_path): 10 | directory = os.path.dirname(file_path) 11 | if not os.path.exists(directory): 12 | os.makedirs(directory) 13 | def main(): 14 | API_KEY = os.getenv('OPENAI_API_KEY') 15 | 16 | dataset_path = 'data/raw/dataset.csv' 17 | selected_docs_path = 'data/processed/selected_docs.csv' 18 | dpo_responses_path = 'data/processed/dpo_responses.csv' 19 | final_answers_path = 'data/processed/final_answers.csv' 20 | 21 | # 1. Selection Agent 22 | selection_agent = SelectionAgent(API_KEY) 23 | df = pd.read_csv(dataset_path) 24 | selection_agent.process_data(df, selected_docs_path) 25 | 26 | # 2. DPO Agent 27 | dpo_agent = DPO_Agent(API_KEY) 28 | selected_df = pd.read_csv(selected_docs_path) 29 | dpo_agent.process_data(selected_df, dpo_responses_path) 30 | 31 | # 3. DPO Finetuning 32 | finetuning_agent = dpo_agent_finetuning(API_KEY) 33 | train_path = 'data/models/fine_tuned/train.jsonl' 34 | eval_path = 'data/models/fine_tuned/eval.jsonl' 35 | finetuning_agent.save_from_csv_to_jsonl(dpo_responses_path, selected_docs_path, train_path) 36 | finetuning_agent.split_jsonl(train_path, train_path, eval_path) 37 | 38 | # 4. 
Finetuned DPO Agent 39 | finetuned_agent = Finetuned_DPO_agent(API_KEY) 40 | finetuned_agent.process_finetuning(df, final_answers_path) 41 | 42 | if __name__ == '__main__': 43 | main() -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 41 | -------------------------------------------------------------------------------- /financerag/tasks/MultiHierttTask.py: -------------------------------------------------------------------------------- 1 | from .BaseTask import BaseTask 2 | from .TaskMetadata import TaskMetadata 3 | 4 | 5 | class MultiHiertt(BaseTask): 6 | def __init__(self): 7 | self.metadata: TaskMetadata = TaskMetadata( 8 | name="MultiHiertt", 9 | description="MultiHiertt: Numerical Reasoning over Multi Hierarchical Tabular and Textual Data", 10 | reference="https://github.com/psunlpgroup/MultiHiertt", 11 | dataset={ 12 | "path": "Linq-AI-Research/FinanceRAG", 13 | "subset": "MultiHiertt", 14 | }, 15 | type="RAG", 16 | category="s2p", 17 | modalities=["text"], 18 | date=None, 19 | domains=["Report"], 20 | task_subtypes=[ 21 | "Financial retrieval", 22 | "Question answering", 23 | ], 24 | license="mit", 25 | annotations_creators="expert-annotated", 26 | dialect=[], 27 | sample_creation="human-generated", 28 | bibtex_citation=""" 29 | @inproceedings{zhao-etal-2022-multihiertt, 30 | title = "{M}ulti{H}iertt: Numerical Reasoning over Multi Hierarchical Tabular and Textual Data", 31 | author = "Zhao, Yilun and 32 | Li, Yunxiang and 33 | Li, Chenying and 34 | Zhang, Rui", 35 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 36 | month = may, 37 | year = "2022", 38 | address = "Dublin, Ireland", 39 | publisher = "Association for Computational Linguistics", 40 | url = "https://aclanthology.org/2022.acl-long.454", 41 | pages = "6588--6600", 42 | } 43 | """, 44 | ) 45 | super().__init__(self.metadata) -------------------------------------------------------------------------------- /GAR/run_finetuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from sentence_transformers import SentenceTransformer 4 | from torch.utils.data import DataLoader 5 | from Embedder_Finetuning import EmbedderFinetuner 6 | 7 | def run_finetuning(): 8 | load_dotenv() 9 | 10 | 11 | names = ['FinDER', 'FinQABench', 'FinanceBench', 'FinQA', 'TATQA', 'ConvFinQA', 'MultiHeirtt'] 12 | markdown_dir = "Markdown" 13 | output_path = "./fine-tune-0203" 14 | hf_token = os.getenv("HUGGINGFACE_API_KEY") 15 | repo_id = os.getenv("repo_id") 16 | repo_owner = os.getenv("repo_owner") 17 | epochs = 2 18 | learning_rate = 2e-5 19 | batch_size = 4 20 | warmup_ratio = 0.1 21 | 22 | 23 | finetuner = EmbedderFinetuner() 24 | corpus_df, queries_df, relevant_docs_data, all_data = finetuner.load_datasets(names, markdown_dir) 25 | train_data, val_data, train_rel, val_rel = finetuner.split_train_val(all_data, relevant_docs_data) 26 | 27 | 28 | train_samples = finetuner.create_train_samples(train_data) 29 | print(f"Training samples: {len(train_samples)}") 30 | eval_examples = finetuner.create_eval_examples(val_data) 31 | eval_loader = DataLoader(eval_examples, shuffle=False, batch_size=batch_size) 32 | 33 | 34 | corpus_dict, queries_dict = finetuner.prepare_corpus_queries(corpus_df, queries_df, val_rel) 35 | 
relevant_docs = finetuner.build_relevant_docs(val_rel) 36 | ir_evaluator = finetuner.create_ir_evaluator(queries_dict, corpus_dict, relevant_docs, batch_size=batch_size) 37 | 38 | 39 | model = SentenceTransformer( 40 | 'NovaSearch/stella_en_1.5B_v5', 41 | trust_remote_code=True, 42 | config_kwargs={"use_memory_efficient_attention": True, "unpad_inputs": False} 43 | ) 44 | 45 | 46 | print("Evaluating before fine-tuning:") 47 | ir_evaluator(model) 48 | 49 | 50 | finetuner.train_model(model, train_samples, ir_evaluator, output_path, epochs, learning_rate, warmup_ratio, batch_size) 51 | 52 | 53 | finetuner.upload_model_to_hub(output_path, repo_id, hf_token, repo_owner) 54 | 55 | 56 | finetuner.clear_gpu_memory() 57 | 58 | if __name__ == "__main__": 59 | run_finetuning() -------------------------------------------------------------------------------- /financerag/tasks/TATQATask.py: -------------------------------------------------------------------------------- 1 | from .BaseTask import BaseTask 2 | from .TaskMetadata import TaskMetadata 3 | 4 | 5 | class TATQA(BaseTask): 6 | def __init__(self): 7 | self.metadata: TaskMetadata = TaskMetadata( 8 | name="TAT-QA", 9 | description="TAT-QA: A Question Answering Benchmark on a Hybrid of Tabular and Textual Content in Finance", 10 | reference="https://github.com/NExTplusplus/TAT-QA", 11 | dataset={ 12 | "path": "Linq-AI-Research/FinanceRAG", 13 | "subset": "TATQA", 14 | }, 15 | type="RAG", 16 | category="s2p", 17 | modalities=["text"], 18 | date=None, 19 | domains=["Report"], 20 | task_subtypes=[ 21 | "Financial retrieval", 22 | "Question answering", 23 | ], 24 | license="mit", 25 | annotations_creators="human-annotated", 26 | dialect=[], 27 | sample_creation="human-generated", 28 | bibtex_citation=""" 29 | @inproceedings{zhu-etal-2021-tat, 30 | title = "{TAT}-{QA}: A Question Answering Benchmark on a Hybrid of Tabular and Textual Content in Finance", 31 | author = "Zhu, Fengbin and 32 | Lei, Wenqiang and 33 | Huang, Youcheng and 34 | Wang, Chao and 35 | Zhang, Shuo and 36 | Lv, Jiancheng and 37 | Feng, Fuli and 38 | Chua, Tat-Seng", 39 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)", 40 | month = aug, 41 | year = "2021", 42 | address = "Online", 43 | publisher = "Association for Computational Linguistics", 44 | url = "https://aclanthology.org/2021.acl-long.254", 45 | doi = "10.18653/v1/2021.acl-long.254", 46 | pages = "3277--3287" 47 | } 48 | """, 49 | ) 50 | super().__init__(self.metadata) 51 | -------------------------------------------------------------------------------- /financerag/retrieval/sent_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Literal, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | from sentence_transformers import SentenceTransformer 5 | from torch import Tensor 6 | 7 | from financerag.common import Encoder 8 | 9 | 10 | # Adopted by https://github.com/beir-cellar/beir/blob/main/beir/retrieval/models/sentence_bert.py 11 | class SentenceTransformerEncoder(Encoder): 12 | 13 | def __init__( 14 | self, 15 | model_name_or_path: Union[str, Tuple[str, str]], 16 | query_prompt: Optional[str] = None, 17 | doc_prompt: Optional[str] = None, 18 | **kwargs 19 | ): 20 | if isinstance(model_name_or_path, str): 21 | self.q_model = SentenceTransformer(model_name_or_path, **kwargs) 22 | self.doc_model = 
self.q_model 23 | elif isinstance(model_name_or_path, Tuple): 24 | self.q_model = SentenceTransformer(model_name_or_path[0], **kwargs) 25 | self.doc_model = SentenceTransformer(model_name_or_path[1], **kwargs) 26 | else: 27 | raise TypeError 28 | self.query_prompt = query_prompt 29 | self.doc_prompt = doc_prompt 30 | 31 | def encode_queries( 32 | self, queries: List[str], batch_size: int = 16, **kwargs 33 | ) -> Union[np.ndarray, Tensor]: 34 | if self.query_prompt is not None: 35 | queries = [self.query_prompt + query for query in queries] 36 | return self.q_model.encode(queries, batch_size=batch_size, **kwargs) 37 | 38 | def encode_corpus( 39 | self, 40 | corpus: Union[ 41 | List[Dict[Literal["title", "text"], str]], 42 | Dict[Literal["title", "text"], List], 43 | ], 44 | batch_size: int = 8, 45 | **kwargs 46 | ) -> Union[np.ndarray, Tensor]: 47 | if isinstance(corpus, dict): 48 | sentences = [ 49 | ( 50 | (corpus["title"][i] + " " + corpus["text"][i]).strip() 51 | if "title" in corpus 52 | else corpus["text"][i].strip() 53 | ) 54 | for i in range(len(corpus["text"])) 55 | ] 56 | else: 57 | sentences = [ 58 | ( 59 | (doc["title"] + " " + doc["text"]).strip() 60 | if "title" in doc 61 | else doc["text"].strip() 62 | ) 63 | for doc in corpus 64 | ] 65 | if self.doc_prompt is not None: 66 | sentences = [self.doc_prompt + s for s in sentences] 67 | return self.doc_model.encode(sentences, batch_size=batch_size, **kwargs) 68 | -------------------------------------------------------------------------------- /GAR/Finetuned_DPO_agent.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import pandas as pd 3 | import numpy as np 4 | from tqdm import tqdm 5 | import json 6 | import random 7 | import dotenv 8 | import os 9 | dotenv.load_dotenv() 10 | def save_from_csv_to_jsonl(csv_file, jsonl_file): 11 | df = pd.read_csv(csv_file) 12 | with open(jsonl_file, 'w', encoding='utf-8') as f: 13 | for index, row in tqdm(df.iterrows(), total=df.shape[0]): 14 | json_obj = { 15 | "input": { 16 | "messages": [ 17 | {"role": "user", "content": row['query']} 18 | ], 19 | "tools": [], 20 | "parallel_toolcalls": True 21 | }, 22 | "preferred_output": [ 23 | {"role": "assistant", "content": row['response1']} 24 | ], 25 | "non_preferred_output": [ 26 | {"role": "assistant", "content": row['response2']} 27 | ] 28 | } 29 | f.write(json.dumps(json_obj, ensure_ascii=False) + '\n') 30 | path = "/content/drive/MyDrive/FinanceRAG-GAR/본선/data_files/dpo_data_to67.csv" 31 | saving_path = "/content/drive/MyDrive/FinanceRAG-GAR/본선/data_files/dpo_data_to67.jsonl" 32 | save_from_csv_to_jsonl(path, saving_path) 33 | """ 34 | query | response1 | response2 | 35 | { 36 | "input": { 37 | "messages": [ 38 | { 39 | "role": "user", 40 | "content": "Hello, can you tell me how cold San Francisco is today?" 41 | } 42 | ], 43 | "tools": [], 44 | "parallel_tool_calls": true 45 | }, 46 | "preferred_output": [ 47 | { 48 | "role": "assistant", 49 | "content": "Today in San Francisco, it is not quite cold as expected. Morning clouds will give away to sunshine, with a high near 68°F (20°C) and a low around 57°F (14°C)." 50 | } 51 | ], 52 | "non_preferred_output": [ 53 | { 54 | "role": "assistant", 55 | "content": "It is not particularly cold in San Francisco today." 56 | } 57 | ] 58 | } 59 | ... 
60 | """ 61 | model_name = 'gpt-4o-mini' # 'gpt-4o' 62 | def load_jsonl(path): 63 | with open(path, 'r', encoding='utf-8') as file: 64 | data = [json.loads(line) for line in file] 65 | return data 66 | training_data = load_jsonl(saving_path) 67 | client = openai.OpenAI( 68 | api_key = os.getenv('OPENAI_API_KEY') 69 | ) 70 | response = client.fine_tuning.jobs.create( 71 | training_file=training_data, 72 | model=model_name, 73 | method={ 74 | "type": "dpo", 75 | "dpo": { 76 | "hyperparameters": {"beta": 0.1} 77 | } 78 | } 79 | ) 80 | response -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /financerag/retrieval/bm25.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Callable, Dict, List, Literal, Optional 3 | 4 | import numpy as np 5 | from nltk.tokenize import word_tokenize 6 | 7 | from financerag.common import Lexical, Retrieval 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def tokenize_list(input_list: List[str]) -> List[List[str]]: 13 | """ 14 | Tokenizes a list of strings using the `nltk.word_tokenize` function. 15 | 16 | Args: 17 | input_list (`List[str]`): 18 | A list of input strings to be tokenized. 19 | 20 | Returns: 21 | `List[List[str]]`: 22 | A list where each element is a list of tokens corresponding to an input string. 23 | """ 24 | return list(map(word_tokenize, input_list)) 25 | 26 | 27 | class BM25Retriever(Retrieval): 28 | """ 29 | A retrieval class that utilizes a lexical model (e.g., BM25) to search for the most relevant documents 30 | from a given corpus based on the input queries. This retriever tokenizes the queries and uses the provided 31 | lexical model to compute relevance scores between the queries and documents in the corpus. 32 | 33 | Methods: 34 | - retrieve: Searches for relevant documents based on the given queries, returning the top-k results. 35 | """ 36 | 37 | def __init__(self, model: Lexical, tokenizer: Callable[[List[str]], List[List[str]]] = tokenize_list): 38 | """ 39 | Initializes the `BM25Retriever` class with a lexical model and a tokenizer function. 40 | 41 | Args: 42 | model (`Lexical`): 43 | A lexical model (e.g., BM25) implementing the `Lexical` protocol, responsible for calculating relevance scores. 44 | tokenizer (`Callable[[List[str]], List[List[str]]]`, *optional*): 45 | A function that tokenizes the input queries. Defaults to `tokenize_list`, which uses `nltk.word_tokenize`. 
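        Example (minimal sketch, assuming `rank_bm25`'s `BM25Okapi` as the lexical model,
        nltk's 'punkt' data available for `tokenize_list`, and illustrative documents):

            from rank_bm25 import BM25Okapi

            # illustrative corpus; any {"title", "text"} mapping keyed by doc ID works
            corpus = {
                "d1": {"title": "10-K", "text": "Revenue grew 12% year over year."},
                "d2": {"title": "10-Q", "text": "Operating margin declined slightly."},
            }
            tokenized_corpus = tokenize_list(
                [(doc["title"] + " " + doc["text"]).lower() for doc in corpus.values()]
            )
            retriever = BM25Retriever(model=BM25Okapi(tokenized_corpus))
            results = retriever.retrieve(corpus=corpus, queries={"q1": "revenue growth"}, top_k=2)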
46 | """ 47 | self.model: Lexical = model 48 | self.tokenizer: Callable[[List[str]], List[List[str]]] = tokenizer 49 | self.results: Optional[Dict[str, Any]] = {} 50 | 51 | def retrieve( 52 | self, 53 | corpus: Dict[str, Dict[Literal["title", "text"], str]], 54 | queries: Dict[str, str], 55 | top_k: Optional[int] = None, 56 | score_function: Optional[str] = None, 57 | return_sorted: bool = False, 58 | **kwargs 59 | ) -> Dict[str, Dict[str, float]]: 60 | """ 61 | Searches the corpus for the most relevant documents based on the given queries. The retrieval process involves 62 | tokenizing the queries, calculating relevance scores using the lexical model, and returning the top-k results 63 | for each query. 64 | 65 | Args: 66 | corpus (`Dict[str, Dict[Literal["title", "text"], str]]`): 67 | A dictionary representing the corpus, where each key is a document ID, and each value is another dictionary 68 | containing document fields such as 'id', 'title', and 'text'. 69 | queries (`Dict[str, str]`): 70 | A dictionary containing query IDs and corresponding query texts. 71 | top_k (`Optional[int]`, *optional*): 72 | The number of top documents to return for each query. If not provided, all documents are returned. Defaults to `None`. 73 | return_sorted (`bool`, *optional*): 74 | Whether to return the results sorted by score. Defaults to `False`. 75 | **kwargs: 76 | Additional keyword arguments passed to the lexical model during scoring. 77 | 78 | Returns: 79 | `Dict[str, Dict[str, float]]`: 80 | A dictionary where each key is a query ID, and the value is another dictionary mapping document IDs to relevance scores. 81 | """ 82 | query_ids = list(queries.keys()) 83 | self.results = {qid: {} for qid in query_ids} 84 | 85 | logger.info("Tokenizing queries with lower cases") 86 | query_lower_tokens = self.tokenizer([queries[qid].lower() for qid in queries]) 87 | 88 | corpus_ids = list(corpus.keys()) 89 | 90 | for qid, query in zip(query_ids, query_lower_tokens): 91 | scores = self.model.get_scores(query) 92 | top_k_result = np.argsort(scores)[::-1][:top_k] 93 | for idx in top_k_result: 94 | self.results[qid][corpus_ids[idx]] = scores[idx] 95 | 96 | return self.results -------------------------------------------------------------------------------- /GAR/DPO_agent.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from openai import OpenAI 3 | import os 4 | import json 5 | import numpy as np 6 | import openai 7 | import pandas as pd 8 | from dotenv import load_dotenv 9 | load_dotenv() 10 | class DPOTrainer: 11 | def __init__(self, model_name): 12 | load_dotenv() 13 | self.model_name = model_name 14 | self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) 15 | def generate_output(self, system_prompt, user_prompt, temp=0, returnRaw=False): 16 | response = self.client.chat.completions.create( 17 | model=self.model_name, 18 | messages=[ 19 | {"role": "system", "content": system_prompt}, 20 | {"role": "user", "content": user_prompt} 21 | ], 22 | temperature=temp 23 | ) 24 | text = response.choices[0].message.content 25 | return response if returnRaw else text 26 | def evaluate_corpus_relevance(self, query_line, corpus_list): 27 | relevancy_instruction = """ 28 | You are an expert financial advisor and evaluator focused on improving responses. 
29 | Your task is to enhance answers based on detailed evaluation scores while: 30 | - Maintaining factual accuracy with the provided documents 31 | - Ensuring responses are clear and well-structured for financial contexts 32 | - Providing comprehensive answers that address all aspects of the query 33 | - Using professional financial terminology appropriately 34 | You are given the pair of Query, Corpus (same query) 35 | Out of the 10 documents, only provide the list of indices of those that are RELEVANT (e.g. the content is somehow needed to answer the question), from 0~9. 36 | Example : [0, 2, 8, 9] 37 | """ 38 | user_prompt = f""" 39 | Query: {query_line} 40 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)]) 41 | try: 42 | relevancy = eval(self.generate_output(relevancy_instruction, user_prompt, temp=0)) 43 | except Exception as error: 44 | print(error) 45 | relevancy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 46 | if relevancy == []: 47 | relevancy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 48 | return relevancy 49 | def answer_query(self, query_line, corpus_list, temp=0, returnRaw=False): 50 | financial_prompt = """ 51 | You are an expert financial advisor and evaluator focused on improving responses. 52 | Your task is to enhance answers based on detailed evaluation scores while: 53 | - Maintaining factual accuracy with the provided documents 54 | - Ensuring responses are clear and well-structured for financial contexts 55 | - Providing comprehensive answers that address all aspects of the query 56 | - Using professional financial terminology appropriately 57 | """ 58 | query_prompt = f""" 59 | Query: {query_line} 60 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)]) 61 | query_prompt += "\nDo not add any closing pleasantries or phrases like 'please feel free to ask!'" 62 | response = self.generate_output(financial_prompt, query_prompt, temp=temp, returnRaw=returnRaw) 63 | return response 64 | def process_data(self, input_csv_path, output_csv_path): 65 | df = pd.read_csv(input_csv_path) 66 | dpo_data = pd.DataFrame({'query': [], 'response1': [], 'response2': []}) 67 | for i in tqdm(range(0, len(df), 10)): 68 | chunk = df[i:i+10] 69 | query_line = chunk['Query'].iloc[0] 70 | corpus_list = chunk['Corpus'].tolist() 71 | corpus_rel = self.evaluate_corpus_relevance(query_line, corpus_list) 72 | try: 73 | relevant_corpus = chunk.iloc[corpus_rel] 74 | except: 75 | relevant_corpus = chunk 76 | relevant_corpus = relevant_corpus['Corpus'].tolist() 77 | response1 = self.answer_query(query_line, relevant_corpus, temp=0) 78 | response2 = self.answer_query(query_line, relevant_corpus, temp=0.5) 79 | dpo_data.loc[i] = [query_line, response1, response2] 80 | dpo_data.to_csv(output_csv_path) 81 | return dpo_data 82 | if __name__ == "__main__": 83 | model_name = "gpt-4o-mini" # or "ft:gpt-4o-mini-2024-07-18:personal::AkFnSqzI" 84 | trainer = DPOTrainer(model_name) 85 | 86 | input_path = '/content/drive/MyDrive/FinanceRAG-GAR/본선/data_files/sampled_data.csv' 87 | output_path = '/content/drive/MyDrive/FinanceRAG-GAR/본선/data_files/sampled_answer_4o.csv' 88 | 89 | trainer.process_data(input_path, output_path) -------------------------------------------------------------------------------- /financerag/generate/openai.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing 3 | import os 4 | from multiprocessing import Pool 5 | from typing import Any, Dict, List, Tuple, cast 
6 | 7 | import openai 8 | from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam 9 | 10 | from financerag.common.protocols import Generator 11 | 12 | openai.api_key = os.getenv("OPENAI_API_KEY") 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class OpenAIGenerator(Generator): 18 | """ 19 | A class that interfaces with the OpenAI API to generate responses using a specified model. It implements the 20 | `Generator` protocol and supports generating responses in parallel using multiple processes. 21 | 22 | Args: 23 | model_name (`str`): 24 | The name of the OpenAI model to use for generating responses (e.g., "gpt-4", "gpt-3.5-turbo"). 25 | """ 26 | 27 | def __init__(self, model_name: str): 28 | """ 29 | Initializes the OpenAIGenerator with the specified model name. 30 | 31 | Args: 32 | model_name (`str`): 33 | The OpenAI model name used to generate responses. 34 | """ 35 | self.model_name: str = model_name 36 | self.results: Dict = {} 37 | 38 | def _process_query( 39 | self, args: Tuple[str, List[ChatCompletionMessageParam], Dict[str, Any]] 40 | ) -> Tuple[str, str]: 41 | """ 42 | Internal method to process a single query using the OpenAI model. It sends the query and messages to the 43 | OpenAI API and retrieves the response. 44 | 45 | Args: 46 | args (`Tuple[str, List[ChatCompletionMessageParam], Dict[str, Any]]`): 47 | Contains the query ID, a list of messages (query), and additional arguments for the model. 48 | 49 | Returns: 50 | `Tuple[str, str]`: 51 | A tuple containing the query ID and the generated response. 52 | """ 53 | q_id, messages, kwargs = args 54 | temperature = kwargs.pop("temperature", 1.0) 55 | top_p = kwargs.pop("top_p", 1.0) 56 | stream = kwargs.pop("stream", False) 57 | max_tokens = kwargs.pop("max_tokens", 10000) 58 | presence_penalty = kwargs.pop("presence_penalty", 0.0) 59 | frequency_penalty = kwargs.pop("frequency_penalty", 0.0) 60 | 61 | client = openai.OpenAI() 62 | response = client.chat.completions.create( 63 | model=self.model_name, 64 | messages=messages, 65 | temperature=temperature, 66 | top_p=top_p, 67 | stream=stream, 68 | max_tokens=max_tokens, 69 | presence_penalty=presence_penalty, 70 | frequency_penalty=frequency_penalty, 71 | ) 72 | return q_id, response.choices[0].message.content 73 | 74 | def generation( 75 | self, 76 | messages: Dict[str, List[Dict[str, str]]], 77 | num_processes: int = multiprocessing.cpu_count(), # Number of parallel processes 78 | **kwargs, 79 | ) -> Dict[str, str]: 80 | """ 81 | Generate responses for the given messages using the OpenAI model. This method supports parallel processing 82 | using multiprocessing to speed up the generation process for multiple queries. 83 | 84 | Args: 85 | messages (`Dict[str, List[Dict[str, str]]]`): 86 | A dictionary where the keys are query IDs, and the values are lists of dictionaries representing the 87 | messages (queries). 88 | num_processes (`int`, *optional*, defaults to `multiprocessing.cpu_count()`): 89 | The number of processes to use for parallel generation. 90 | **kwargs: 91 | Additional keyword arguments for the OpenAI model (e.g., temperature, top_p, max_tokens). 92 | 93 | Returns: 94 | `Dict[str, str]`: 95 | A dictionary where each key is a query ID, and the value is the generated response. 96 | """ 97 | logger.info( 98 | f"Starting generation for {len(messages)} queries using {num_processes} processes..." 
99 | ) 100 | 101 | # Prepare arguments for multiprocessing 102 | query_args = [ 103 | (q_id, cast(list[ChatCompletionMessageParam], msg), kwargs.copy()) 104 | for q_id, msg in messages.items() 105 | ] 106 | 107 | # Use multiprocessing Pool for parallel generation 108 | with Pool(processes=num_processes) as pool: 109 | results = pool.map(self._process_query, query_args) 110 | 111 | # Collect results 112 | self.results = {q_id: content for q_id, content in results} 113 | 114 | logger.info( 115 | f"Generation completed for all queries. Collected {len(self.results)} results." 116 | ) 117 | 118 | return self.results -------------------------------------------------------------------------------- /financerag/rerank/cross_encoder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Optional 3 | 4 | from financerag.common import CrossEncoder, Reranker 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | # Adapted from https://github.com/beir-cellar/beir/blob/main/beir/reranking/rerank.py 10 | class CrossEncoderReranker(Reranker): 11 | """ 12 | A reranker class that utilizes a cross-encoder model from the `sentence-transformers` library 13 | to rerank search results based on query-document pairs. This class implements a reranking 14 | mechanism using cross-attention, where each query-document pair is passed through the 15 | cross-encoder model to compute relevance scores. 16 | 17 | The cross-encoder model expects two inputs (query and document) and directly computes a 18 | score indicating the relevance of the document to the query. The model follows the 19 | `CrossEncoder` protocol, ensuring it is compatible with `sentence-transformers` cross-encoder models. 20 | 21 | Methods: 22 | rerank: 23 | Takes in a corpus, queries, and initial retrieval results, and reranks 24 | the top-k documents using the cross-encoder model. 25 | """ 26 | 27 | def __init__(self, model: CrossEncoder): 28 | """ 29 | Initializes the `CrossEncoderReranker` class with a cross-encoder model. 30 | 31 | Args: 32 | model (`CrossEncoder`): 33 | A cross-encoder model implementing the `CrossEncoder` protocol from the `sentence-transformers` library. 34 | """ 35 | self.model: CrossEncoder = model 36 | self.results: Dict[str, Dict[str, float]] = {} 37 | 38 | def rerank( 39 | self, 40 | corpus: Dict[str, Dict[str, str]], 41 | queries: Dict[str, str], 42 | results: Dict[str, Dict[str, float]], 43 | top_k: int, 44 | batch_size: Optional[int] = None, 45 | **kwargs 46 | ) -> Dict[str, Dict[str, float]]: 47 | """ 48 | Reranks the top-k documents for each query based on cross-encoder model predictions. 49 | 50 | Args: 51 | corpus (`Dict[str, Dict[str, str]]`): 52 | A dictionary representing the corpus, where each key is a document ID and each value is a dictionary 53 | containing the title and text fields of the document. 54 | queries (`Dict[str, str]`): 55 | A dictionary containing query IDs as keys and the corresponding query texts as values. 56 | results (`Dict[str, Dict[str, float]]`): 57 | A dictionary containing query IDs and the initial retrieval results. Each query ID is mapped to another 58 | dictionary where document IDs are keys and initial retrieval scores are values. 59 | top_k (`int`): 60 | The number of top documents to rerank for each query. 61 | batch_size (`Optional[int]`, *optional*): 62 | The batch size used when passing the query-document pairs through the cross-encoder model. 63 | Defaults to None. 
64 | **kwargs: 65 | Additional arguments passed to the cross-encoder model during prediction. 66 | 67 | Returns: 68 | `Dict[str, Dict[str, float]]`: 69 | A dictionary containing query IDs as keys and dictionaries of reranked document IDs and their scores as values. 70 | """ 71 | sentence_pairs, pair_ids = [], [] 72 | 73 | for query_id in results: 74 | if len(results[query_id]) > top_k: 75 | for doc_id, _ in sorted( 76 | results[query_id].items(), key=lambda item: item[1], reverse=True 77 | )[:top_k]: 78 | pair_ids.append([query_id, doc_id]) 79 | corpus_text = ( 80 | corpus[doc_id].get("title", "") 81 | + " " 82 | + corpus[doc_id].get("text", "") 83 | ).strip() 84 | sentence_pairs.append([queries[query_id], corpus_text]) 85 | 86 | else: 87 | for doc_id in results[query_id]: 88 | pair_ids.append([query_id, doc_id]) 89 | corpus_text = ( 90 | corpus[doc_id].get("title", "") 91 | + " " 92 | + corpus[doc_id].get("text", "") 93 | ).strip() 94 | sentence_pairs.append([queries[query_id], corpus_text]) 95 | 96 | #### Starting to Rerank using cross-attention 97 | logger.info(f"Starting To Rerank Top-{top_k}....") 98 | rerank_scores = [ 99 | float(score) 100 | for score in self.model.predict( 101 | sentences=sentence_pairs, batch_size=batch_size, **kwargs 102 | ) 103 | ] 104 | 105 | #### Reranker results 106 | self.results = {query_id: {} for query_id in results} 107 | for pair, score in zip(pair_ids, rerank_scores): 108 | query_id, doc_id = pair[0], pair[1] 109 | self.results[query_id][doc_id] = score 110 | 111 | return self.results -------------------------------------------------------------------------------- /financerag/tasks/TaskMetadata.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/abstasks/TaskMetadata.py 2 | from __future__ import annotations 3 | 4 | import logging 5 | from datetime import date 6 | from typing import Any, Dict, Optional, Union 7 | 8 | from pydantic import AnyUrl, BaseModel, BeforeValidator, TypeAdapter, field_validator 9 | from typing_extensions import Annotated, Literal 10 | 11 | TASK_SUBTYPE = Literal[ 12 | "Financial retrieval", 13 | "Question answering", 14 | ] 15 | 16 | TASK_DOMAIN = Literal["Report",] 17 | 18 | SAMPLE_CREATION_METHOD = Literal[ 19 | "found", 20 | "human-generated", 21 | "LM-generated and verified", 22 | ] 23 | 24 | TASK_TYPE = Literal["RAG",] 25 | 26 | TASK_CATEGORY = Literal["s2p",] # Sentence-to-paragraph 27 | 28 | ANNOTATOR_TYPE = Literal[ 29 | "expert-annotated", 30 | "human-annotated", 31 | "derived", 32 | "LM-generated", 33 | "LM-generated and reviewed", # reviewed by humans 34 | ] 35 | 36 | http_url_adapter = TypeAdapter(AnyUrl) 37 | STR_URL = Annotated[ 38 | str, BeforeValidator(lambda value: str(http_url_adapter.validate_python(value))) 39 | ] # Allows the type to be a string, but ensures that the string is a URL 40 | 41 | pastdate_adapter = TypeAdapter(date) 42 | STR_DATE = Annotated[ 43 | str, BeforeValidator(lambda value: str(pastdate_adapter.validate_python(value))) 44 | ] # Allows the type to be a string, but ensures that the string is a valid date 45 | 46 | SPLIT_NAME = str 47 | HFSubset = str 48 | 49 | LICENSES = ( # this list can be extended as needed 50 | Literal[ # we use lowercase for the licenses similar to the huggingface datasets 51 | "not specified", # or none found 52 | "mit", 53 | "cc-by-2.0", 54 | "cc-by-3.0", 55 | "cc-by-4.0", 56 | "cc-by-sa-3.0", 57 | "cc-by-sa-4.0", 58 | "cc-by-nc-4.0", 59 | 
"cc-by-nc-sa-3.0", 60 | "cc-by-nc-sa-4.0", 61 | "cc-by-nc-nd-4.0", 62 | "openrail", 63 | "openrail++", 64 | "odc-by", 65 | "afl-3.0", 66 | "apache-2.0", 67 | "cc-by-nd-2.1-jp", 68 | "cc0-1.0", 69 | "bsd-3-clause", 70 | "gpl-3.0", 71 | "cdla-sharing-1.0", 72 | "mpl-2.0", 73 | ] 74 | ) 75 | 76 | METRIC_NAME = str 77 | METRIC_VALUE = Union[int, float, Dict[str, Any]] 78 | 79 | logger = logging.getLogger(__name__) 80 | 81 | 82 | class TaskMetadata(BaseModel): 83 | """ 84 | Metadata for a task. 85 | 86 | Args: 87 | dataset: A dictionary containing the arguments to pass to `datasets.load_dataset` to load the dataset for the task. Must include 'path' and *Optional*ly 'revision'. 88 | Refer to https://huggingface.co/docs/datasets/v3.0.0/en/package_reference/loading_methods for more details. 89 | name: The name of the task. 90 | description: A description of the task. 91 | type: (*Optional*) The type of the task, such as "Retrieval" or "Generation". Corresponds to the TASK_TYPE literal. 92 | modalities: The input modality of the dataset. In this case, it is set to ["text"], meaning the dataset deals with textual data. 93 | category: (*Optional*) The category of the task, e.g., "s2p" (sentence-to-paragraph). Corresponds to the TASK_CATEGORY literal. 94 | reference: (*Optional*) A URL to documentation or a published paper about the task. Must be a valid URL. 95 | date: (*Optional*) A tuple containing the start and end dates when the dataset was collected, ensuring the data reflects a certain time frame. 96 | domains: (*Optional*) The domain(s) of the data, e.g., "Report". Defined as TASK_DOMAIN literals. 97 | task_subtypes: (*Optional*) Subtypes of the task, providing more specific details (e.g., "Financial retrieval", "Question answering"). 98 | license: (*Optional*) The license under which the dataset is released. Uses a predefined list of licenses (e.g., "cc-by-4.0"), but custom licenses can be provided via URLs. 99 | annotations_creators: (*Optional*) The type of annotators who created or verified the dataset annotations, such as "expert-annotated" or "LM-generated and reviewed". 100 | dialect: (*Optional*) The dialect of the data, if applicable. Ideally specified as a BCP-47 language tag. Empty if no dialects are present. 101 | sample_creation: (*Optional*) The method used to create the dataset samples, such as "found", "human-generated", or "LM-generated and verified". 102 | bibtex_citation: (*Optional*) The BibTeX citation for the dataset. Should be provided if available; otherwise, it is an empty string. 103 | 104 | Methods: 105 | validate_metadata: Validates that the necessary metadata fields (like dataset path and revision) are specified. 106 | is_filled: Checks if all required metadata fields are filled in the TaskMetadata instance. 107 | intext_citation: Generates an in-text citation based on the BibTeX entry provided. If no BibTeX is available, returns an empty string. 108 | 109 | Validators: 110 | _check_dataset_path_is_specified: Ensures that the dataset dictionary contains the 'path' key. 111 | _check_dataset_revision_is_specified: Ensures that the dataset dictionary contains the 'revision' key or provides a warning if it's missing. 
112 | """ 113 | 114 | dataset: dict 115 | 116 | name: str 117 | description: str 118 | type: Optional[TASK_TYPE] = None 119 | modalities: list[Literal["text"]] = ["text"] 120 | category: Optional[TASK_CATEGORY] = None 121 | reference: Optional[STR_URL] = None 122 | 123 | date: Optional[tuple[STR_DATE, STR_DATE]] = None 124 | domains: Optional[list[TASK_DOMAIN]] = None 125 | task_subtypes: Optional[list[TASK_SUBTYPE]] = None 126 | license: Optional[LICENSES | STR_URL] = None 127 | 128 | annotations_creators: Optional[ANNOTATOR_TYPE] = None 129 | dialect: Optional[list[str]] = None 130 | 131 | sample_creation: Optional[SAMPLE_CREATION_METHOD] = None 132 | bibtex_citation: Optional[str] = None 133 | 134 | @field_validator("dataset") 135 | def _check_dataset_path_is_specified( 136 | cls, dataset: dict[str, Any] 137 | ) -> dict[str, Any]: 138 | if "path" not in dataset: 139 | raise ValueError("Dataset path must be specified") 140 | return dataset 141 | 142 | @field_validator("dataset") 143 | def _check_dataset_subset_is_specified( 144 | cls, dataset: dict[str, Any] 145 | ) -> dict[str, Any]: 146 | if "subset" not in dataset: 147 | raise ValueError("Dataset subset must be specified") 148 | return dataset 149 | 150 | def is_filled(self) -> bool: 151 | """Check if all the metadata fields are filled.""" 152 | return all( 153 | getattr(self, field_name) is not None for field_name in self.model_fields 154 | ) 155 | 156 | @property 157 | def intext_citation(self, include_cite: bool = True) -> str: 158 | """Create an in-text citation for the dataset.""" 159 | cite = "" 160 | if self.bibtex_citation: 161 | cite = f"{self.bibtex_citation.split(',')[0].split('{')[1]}" 162 | if include_cite and cite: 163 | # check for whitespace in the citation 164 | if " " in cite: 165 | logger.warning( 166 | "Citation contains whitespace. Please ensure that the citation is correctly formatted." 
167 | ) 168 | return f"\\cite{{{cite}}}" 169 | return cite 170 | -------------------------------------------------------------------------------- /GAR/Embedder_Finetuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import pandas as pd 4 | from sklearn.model_selection import train_test_split 5 | from sentence_transformers import SentenceTransformer, InputExample 6 | from sentence_transformers.evaluation import InformationRetrievalEvaluator 7 | from torch.utils.data import DataLoader 8 | from sentence_transformers.datasets import NoDuplicatesDataLoader 9 | from sentence_transformers import losses 10 | from huggingface_hub import HfApi, login 11 | import torch 12 | import gc 13 | 14 | class EmbedderFinetuner: 15 | def __init__(self): 16 | self.model = None 17 | 18 | @staticmethod 19 | def format_query(query: str) -> str: 20 | return f'Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: {query}' 21 | 22 | @staticmethod 23 | def format_text(title: str, text: str) -> str: 24 | return f"Title: {title}\nText: {text}" 25 | 26 | def load_datasets(self, names: list, markdown_dir: str = "Markdown"): 27 | corpus_df = pd.DataFrame() 28 | queries_df = pd.DataFrame() 29 | relevant_docs_data = pd.DataFrame() 30 | all_data = pd.DataFrame() 31 | 32 | for name in names: 33 | print(f"Loading data for {name} ...") 34 | qrels_path = f'{name}_qrels.tsv' 35 | queries_path = os.path.join(markdown_dir, name, 'queries.jsonl') 36 | corpus_path = os.path.join(markdown_dir, name, 'corpus.jsonl') 37 | 38 | qrels = pd.read_csv(qrels_path, sep='\t') 39 | queries_temp = pd.read_json(queries_path, lines=True) 40 | corpus_temp = pd.read_json(corpus_path, lines=True) 41 | 42 | corpus_df = pd.concat([corpus_df, corpus_temp], ignore_index=True) 43 | queries_df = pd.concat([queries_df, queries_temp], ignore_index=True) 44 | relevant_docs_data = pd.concat([relevant_docs_data, qrels], ignore_index=True) 45 | 46 | queries_temp.rename(columns={'_id': 'query_id', 'title': 'title_queries', 'text': 'text_queries'}, inplace=True) 47 | corpus_temp.rename(columns={'_id': 'corpus_id', 'title': 'title_corpus', 'text': 'text_corpus'}, inplace=True) 48 | 49 | data = qrels.merge(queries_temp, on='query_id').merge(corpus_temp, on='corpus_id') 50 | all_data = pd.concat([all_data, data], ignore_index=True) 51 | 52 | return corpus_df, queries_df, relevant_docs_data, all_data 53 | 54 | def split_train_val(self, all_data: pd.DataFrame, relevant_docs_data: pd.DataFrame, 55 | test_size: float = 0.2, random_state: int = 42): 56 | train_rel, val_rel = train_test_split(relevant_docs_data, test_size=test_size, random_state=random_state) 57 | train_data = all_data[all_data['query_id'].isin(train_rel['query_id'])] 58 | val_data = all_data[all_data['query_id'].isin(val_rel['query_id'])] 59 | return train_data, val_data, train_rel, val_rel 60 | 61 | def create_train_samples(self, train_data: pd.DataFrame) -> list: 62 | train_samples = [] 63 | for _, row in train_data.iterrows(): 64 | sample = InputExample( 65 | texts=[ 66 | self.format_query(self.format_text(row['title_queries'], row['text_queries'])), 67 | self.format_text(row['title_corpus'], row['text_corpus']) 68 | ] 69 | ) 70 | train_samples.append(sample) 71 | return train_samples 72 | 73 | def add_ir_doc(self, df: pd.DataFrame) -> pd.DataFrame: 74 | irrelevant_docs = [] 75 | for _, row in df.iterrows(): 76 | text = row['text_corpus'] 77 | candidates = df[df['text_corpus'] != text] 78 | 
if len(candidates) > 0: 79 | irrelevant = candidates.sample(n=1) 80 | irrelevant_docs.append(self.format_text(irrelevant.iloc[0]['title_corpus'], irrelevant.iloc[0]['text_corpus'])) 81 | else: 82 | irrelevant_docs.append("") 83 | df = df.copy() 84 | df['irrelevant_docs'] = irrelevant_docs 85 | return df 86 | 87 | def create_eval_examples(self, val_data: pd.DataFrame) -> list: 88 | examples = [] 89 | val_data_ir = self.add_ir_doc(val_data) 90 | for _, row in val_data_ir.iterrows(): 91 | examples.append(InputExample( 92 | texts=[ 93 | self.format_query(self.format_text(row['title_queries'], row['text_queries'])), 94 | self.format_text(row['title_corpus'], row['text_corpus']) 95 | ], 96 | label=1.0 97 | )) 98 | examples.append(InputExample( 99 | texts=[ 100 | self.format_query(self.format_text(row['title_queries'], row['text_queries'])), 101 | row['irrelevant_docs'] 102 | ], 103 | label=0.0 104 | )) 105 | return examples 106 | 107 | def prepare_corpus_queries(self, corpus_df: pd.DataFrame, queries_df: pd.DataFrame, 108 | val_rel: pd.DataFrame, random_sample: int = 3000): 109 | corpus_df['text'] = corpus_df.apply(lambda row: self.format_text(row['title'], row['text']), axis=1) 110 | corpus_df = corpus_df.drop(columns=['title']) 111 | 112 | queries_df['text'] = queries_df.apply(lambda row: self.format_query(self.format_text(row['title'], row['text'])), axis=1) 113 | queries_df = queries_df.drop(columns=['title']) 114 | 115 | required_corpus_ids = set(map(str, val_rel["corpus_id"])) 116 | all_ids = corpus_df["_id"].tolist() 117 | additional_ids = set(random.sample(all_ids, k=random_sample)) if len(all_ids) >= random_sample else set(all_ids) 118 | required_corpus_ids |= additional_ids 119 | 120 | corpus_df = corpus_df.loc[corpus_df["_id"].astype(str).isin(required_corpus_ids)] 121 | corpus_dict = dict(zip(corpus_df["_id"].astype(str), corpus_df["text"])) 122 | queries_dict = dict(zip(queries_df["_id"].astype(str), queries_df["text"])) 123 | 124 | return corpus_dict, queries_dict 125 | 126 | def build_relevant_docs(self, val_rel: pd.DataFrame) -> dict: 127 | relevant_docs = {} 128 | for qid, corpus_id in zip(val_rel["query_id"], val_rel["corpus_id"]): 129 | qid_str = str(qid) 130 | corpus_str = str(corpus_id) 131 | if qid_str not in relevant_docs: 132 | relevant_docs[qid_str] = set() 133 | relevant_docs[qid_str].add(corpus_str) 134 | return relevant_docs 135 | 136 | def create_ir_evaluator(self, queries: dict, corpus: dict, relevant_docs: dict, 137 | batch_size: int = 4, name: str = "Evaluate") -> InformationRetrievalEvaluator: 138 | return InformationRetrievalEvaluator( 139 | queries=queries, 140 | corpus=corpus, 141 | relevant_docs=relevant_docs, 142 | name=name, 143 | batch_size=batch_size 144 | ) 145 | 146 | def train_model(self, model: SentenceTransformer, train_samples: list, evaluator: InformationRetrievalEvaluator, 147 | output_path: str, epochs: int = 2, learning_rate: float = 2e-5, 148 | warmup_ratio: float = 0.1, batch_size: int = 4): 149 | self.model = model 150 | loader = NoDuplicatesDataLoader(train_samples, batch_size=batch_size) 151 | loss = losses.MultipleNegativesRankingLoss(model) 152 | warmup_steps = int(len(loader) * epochs * warmup_ratio) 153 | 154 | model.fit( 155 | train_objectives=[(loader, loss)], 156 | evaluator=evaluator, 157 | epochs=epochs, 158 | warmup_steps=warmup_steps, 159 | output_path=output_path, 160 | optimizer_params={'lr': learning_rate}, 161 | show_progress_bar=True, 162 | use_amp=True, 163 | evaluation_steps=len(loader), 164 | save_best_model=True, 165 | ) 
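# Illustrative sketch of how the methods above are meant to chain together (kept as a
# comment so nothing runs at import time; the dataset name, markdown directory, and base
# model below are assumptions, not fixed by this repository):
#
#   finetuner = EmbedderFinetuner()
#   corpus_df, queries_df, rel_df, all_data = finetuner.load_datasets(["FinDER"], markdown_dir="Markdown")
#   train_data, val_data, train_rel, val_rel = finetuner.split_train_val(all_data, rel_df)
#   train_samples = finetuner.create_train_samples(train_data)
#   corpus_dict, queries_dict = finetuner.prepare_corpus_queries(corpus_df, queries_df, val_rel)
#   relevant_docs = finetuner.build_relevant_docs(val_rel)
#   evaluator = finetuner.create_ir_evaluator(queries_dict, corpus_dict, relevant_docs)
#   base_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
#   finetuner.train_model(base_model, train_samples, evaluator, output_path="output/finetuned-embedder")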
166 | 167 | def upload_model_to_hub(self, save_path: str, repo_id: str, hf_token: str, repo_owner: str = None): 168 | login(token=hf_token) 169 | api = HfApi() 170 | try: 171 | api.create_repo(repo_id=repo_id) 172 | except Exception as e: 173 | print(f"Repo creation: {e}") 174 | full_repo_id = repo_id if repo_owner is None else f"{repo_owner}/{repo_id}" 175 | api.upload_folder( 176 | folder_path=save_path, 177 | repo_id=full_repo_id, 178 | repo_type="model", 179 | ) 180 | 181 | def clear_gpu_memory(self): 182 | torch.cuda.empty_cache() 183 | gc.collect() 184 | -------------------------------------------------------------------------------- /GAR/hybrid_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import logging 4 | import random 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from tqdm import tqdm 9 | from torch.nn.functional import softmax 10 | from rank_bm25 import BM25Okapi 11 | from huggingface_hub import login 12 | 13 | 14 | from sentence_transformers import SentenceTransformer, CrossEncoder 15 | from financerag.rerank import CrossEncoderReranker 16 | from financerag.retrieval import DenseRetrieval, SentenceTransformerEncoder 17 | from financerag.tasks import FinDER, FinQABench, FinanceBench, TATQA, FinQA, ConvFinQA, MultiHiertt 18 | from financerag.common import Retrieval 19 | 20 | logging.basicConfig(level=logging.INFO) 21 | 22 | 23 | 24 | def tokenize_list(input_list): 25 | from nltk.tokenize import word_tokenize 26 | return list(map(word_tokenize, input_list)) 27 | 28 | 29 | class BM25Retriever(Retrieval): 30 | def __init__(self, model, tokenizer=tokenize_list): 31 | self.model = model 32 | self.tokenizer = tokenizer 33 | self.results = {} 34 | 35 | def retrieve(self, corpus, queries, top_k=None, score_function=None, return_sorted=False, **kwargs): 36 | query_ids = list(queries.keys()) 37 | self.results = {qid: {} for qid in query_ids} 38 | query_lower_tokens = self.tokenizer([queries[qid].lower() for qid in queries]) 39 | corpus_ids = list(corpus.keys()) 40 | 41 | for qid, query in tqdm(zip(query_ids, query_lower_tokens), total=len(query_ids), desc="BM25 Retrieval"): 42 | scores = self.model.get_scores(query) 43 | top_k_result = np.argsort(scores)[::-1][:top_k] 44 | for idx in top_k_result: 45 | self.results[qid][corpus_ids[idx]] = scores[idx] 46 | return self.results 47 | 48 | 49 | class HybridSearcher: 50 | def __init__(self): 51 | self.retrieval_results = {} 52 | 53 | @staticmethod 54 | def softmax_normalize(retrieval_results): 55 | for query_id, results in retrieval_results.items(): 56 | scores = torch.tensor(list(results.values()), dtype=torch.float32) 57 | probs = softmax(scores, dim=0).tolist() 58 | doc_ids = list(results.keys()) 59 | for i, doc_id in enumerate(doc_ids): 60 | results[doc_id] = probs[i] 61 | return retrieval_results 62 | 63 | def setup_task(self): 64 | finder_task=FinDER() 65 | finqabench_task=FinQABench() 66 | finqa_task=FinQA() 67 | financebench_task=FinanceBench() 68 | tatqa_task=TATQA() 69 | convfinqa_task=ConvFinQA() 70 | multihiertt_task=MultiHiertt() 71 | tasks=[finder_task, finqabench_task, finqa_task, financebench_task, tatqa_task, convfinqa_task, multihiertt_task] 72 | names=["Finder", "FinQABench", "FinQA", "FinanceBench", "TATQA", "ConvFinQA", "MultiHiertt"] 73 | 74 | for i, task in enumerate(tasks): 75 | task.metadata.dataset['path']='thomaskim1130/FinanceRAG-Lingua' 76 | task.metadata.dataset['qrels']= f'markdown-{names[i]}' 77 | 
task.load_data() 78 | return tasks, names 79 | 80 | def retrieval_model_setup(self): 81 | login(token=os.getenv('HF_TOKEN')) 82 | stella=SentenceTransformer( 83 | model_name_or_path='thomaskim1130/stella_en_1.5B_v5-FinanceRAG-TAT-MH-v2', 84 | trust_remote_code=True, 85 | query_prompt='Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: ', 86 | doc_prompt='', 87 | config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False} 88 | ) 89 | retrieval_stella=DenseRetrieval( 90 | model=stella, 91 | ) 92 | 93 | return retrieval_stella 94 | 95 | def reranker_model_setup(self): 96 | reranker=CrossEncoderReranker( 97 | model=CrossEncoder('BAAI/bge-reranker-v2-e3', 98 | trust_remote_code=True, 99 | config_args={"torch_dtype": torch.bfloat16, "attn_implementation": "sdpa", "device_map":"auto", "offload_folder":"offload"} 100 | ), 101 | ) 102 | return reranker 103 | 104 | def get_sparse_score(self, task): 105 | document_list = task.corpus 106 | query_list = task.queries 107 | tokenized_corpus = tokenize_list([doc["title"].lower() + ' ' + doc["text"].lower() 108 | for doc in document_list.values()]) 109 | bm25 = BM25Okapi(tokenized_corpus) 110 | retriever = BM25Retriever(bm25) 111 | top_k = len(document_list) 112 | sparse_retrieval_results = retriever.retrieve(corpus=document_list, queries=query_list, top_k=top_k) 113 | return sparse_retrieval_results 114 | 115 | def get_dense_score(self, task, retrieval_model): 116 | document_list = task.corpus 117 | query_list = task.queries 118 | dense_retrieval_results = retrieval_model.retrieve( 119 | corpus=document_list, 120 | queries=query_list, 121 | top_k=100 122 | ) 123 | return dense_retrieval_results 124 | 125 | def get_hybrid_score(self, task, alpha, retrieval_model): 126 | sparse_alpha = 1 - alpha 127 | hybrid_retrieval_result = {} 128 | dense_retrieval_result = self.get_dense_score(task, retrieval_model) 129 | dense_retrieval_result = self.softmax_normalize(dense_retrieval_result) 130 | 131 | for query_id, results in dense_retrieval_result.items(): 132 | for doc_id, dense_score_val in results.items(): 133 | if query_id not in hybrid_retrieval_result: 134 | hybrid_retrieval_result[query_id] = {} 135 | hybrid_retrieval_result[query_id][doc_id] = alpha * dense_score_val 136 | 137 | sparse_retrieval_result = self.get_sparse_score(task) 138 | sparse_retrieval_result = self.softmax_normalize(sparse_retrieval_result) 139 | 140 | for query_id, results in sparse_retrieval_result.items(): 141 | for doc_id, sparse_score_val in results.items(): 142 | if doc_id in hybrid_retrieval_result.get(query_id, {}): 143 | hybrid_retrieval_result[query_id][doc_id] += sparse_alpha * sparse_score_val 144 | 145 | return hybrid_retrieval_result 146 | 147 | def evaluate_hybrid(self, task, qrels_dict, hybrid_retrieval_result): 148 | result = task.evaluate(qrels_dict, hybrid_retrieval_result, [10]) 149 | return result[0]['NDCG@10'] 150 | 151 | def tune_alpha(self, task, retrieval_model): 152 | alpha_values = np.linspace(0, 1, 41) 153 | ndcg_values = [] 154 | qrels_path = f'Dataset/{task}_qrels.tsv' 155 | qrels_df = pd.read_csv(qrels_path, sep='\t') 156 | qrels_dict = qrels_df.groupby('query_id').apply(lambda x: dict(zip(x['corpus_id'], x['score']))).to_dict() 157 | 158 | for alpha in alpha_values: 159 | hybrid_retrieval_result = self.get_hybrid_score(task, alpha=alpha, retrieval_model=retrieval_model) 160 | ndcg_value = self.evaluate_hybrid(task, qrels_dict, hybrid_retrieval_result) 161 | ndcg_values.append(ndcg_value) 162 | 
max_ndcg_index = np.argmax(ndcg_values) 163 | optimal_alpha = alpha_values[max_ndcg_index] 164 | return optimal_alpha 165 | 166 | def get_reranker_score(self, task, hybrid_retrieval_result, reranker_model): 167 | reranker_result = task.rerank( 168 | reranker=reranker_model, 169 | results=hybrid_retrieval_result, 170 | top_k=10, 171 | batch_size=32 172 | ) 173 | return reranker_result 174 | 175 | def get_ndcg_score(self, task, name): 176 | qrels = pd.read_csv(f'Dataset/{name}_qrels.tsv', sep='\t').groupby('query_id').apply(lambda x: dict(zip(x['corpus_id'], x['score']))).to_dict() 177 | return task.evaluate(qrels, task.retrieve_results, [10])[0]['NDCG@10'] 178 | 179 | def get_final_ndcg(self, tasks, names): 180 | result = 0 181 | task_lengths = [] 182 | 183 | for n, task in enumerate(tasks): 184 | task_lengths.append(len(task.queries)) 185 | print(f"{names[n]} : {len(task.queries)} Queries") 186 | result += self.get_ndcg_score(task, names[n])*task_lengths[-1] 187 | 188 | result /= sum(task_lengths) 189 | return result 190 | 191 | @staticmethod 192 | def merge_csv_results(output_dir): 193 | import glob 194 | csv_files = glob.glob(os.path.join(output_dir, "*", "*.csv")) 195 | dataframes = [] 196 | for file in csv_files: 197 | df = pd.read_csv(file) 198 | dataframes.append(df) 199 | merged_df = pd.concat(dataframes, ignore_index=True) 200 | merged_csv_path = os.path.join(output_dir, 'merged_output.csv') 201 | merged_df.to_csv(merged_csv_path, index=False) 202 | logging.info(f"Merged CSV saved to {merged_csv_path}") 203 | -------------------------------------------------------------------------------- /financerag/common/loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Optional, Tuple, cast 4 | 5 | from datasets import Dataset, Value, load_dataset 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | class HFDataLoader: 10 | """ 11 | A Hugging Face Dataset loader for corpus and query data. Supports loading datasets from local files 12 | (in JSONL format) or directly from a Hugging Face repository. 13 | 14 | Args: 15 | hf_repo (`str`, *optional*): 16 | The Hugging Face repository containing the dataset. If provided, it overrides the 17 | data folder, prefix, and *_file arguments. 18 | data_folder (`str`, *optional*): 19 | Path to the folder containing the dataset files when loading from local files. 20 | subset (`str`, *optional*): 21 | The subset of the dataset to load (e.g., "FinDER", "FinQA"). Used in both local and HF repo loading. 22 | prefix (`str`, *optional*): 23 | A prefix to add to the file names (e.g., "train_", "test_"). 24 | corpus_file (`str`, defaults to `"corpus.jsonl"`): 25 | The filename for the corpus when loading locally. 26 | query_file (`str`, defaults to `"queries.jsonl"`): 27 | The filename for the queries when loading locally. 28 | keep_in_memory (`bool`, defaults to `False`): 29 | Whether to keep the dataset in memory. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | hf_repo: Optional[str] = None, 35 | data_folder: Optional[str] = None, 36 | subset: Optional[str] = None, 37 | prefix: Optional[str] = None, 38 | corpus_file: str = "corpus.jsonl", 39 | query_file: str = "queries.jsonl", 40 | keep_in_memory: bool = False, 41 | ): 42 | """ 43 | Initializes the HFDataLoader class. 44 | 45 | Args: 46 | hf_repo (`str`, *optional*): 47 | The Hugging Face repository containing the dataset. 
48 | data_folder (`str`, *optional*): 49 | Path to the folder containing the dataset files when loading from local files. 50 | subset (`str`, *optional*): 51 | The subset of the dataset to load. 52 | prefix (`str`, *optional*): 53 | A prefix to add to the file names. 54 | corpus_file (`str`, defaults to `"corpus.jsonl"`): 55 | The filename for the corpus when loading locally. 56 | query_file (`str`, defaults to `"queries.jsonl"`): 57 | The filename for the queries when loading locally. 58 | keep_in_memory (`bool`, defaults to `False`): 59 | Whether to keep the dataset in memory. 60 | """ 61 | self.corpus: Optional[Dataset] = None 62 | self.queries: Optional[Dataset] = None 63 | self.hf_repo = hf_repo 64 | self.subset = subset 65 | if hf_repo: 66 | logger.warning( 67 | "A Hugging Face repository is provided. This will override the data_folder, prefix, and *_file arguments." 68 | ) 69 | else: 70 | if (data_folder is None) or (subset is None): 71 | raise ValueError( 72 | "A Hugging Face repository or local directory is required." 73 | ) 74 | 75 | if prefix: 76 | query_file = prefix + "_" + query_file 77 | 78 | self.corpus_file = ( 79 | (Path(data_folder) / subset / corpus_file).as_posix() 80 | if data_folder 81 | else corpus_file 82 | ) 83 | self.query_file = ( 84 | (Path(data_folder) / subset / query_file).as_posix() 85 | if data_folder 86 | else query_file 87 | ) 88 | self.streaming = False 89 | self.keep_in_memory = keep_in_memory 90 | 91 | @staticmethod 92 | def check(file_in: str, ext: str): 93 | """ 94 | Check if the given file exists and has the correct extension. 95 | 96 | Args: 97 | file_in (`str`): The path of the file to check. 98 | ext (`str`): The expected file extension. 99 | 100 | Raises: 101 | `ValueError`: If the file does not exist or if the extension does not match. 102 | """ 103 | if not Path(file_in).exists(): 104 | raise ValueError( 105 | "File {} not present! Please provide an accurate file.".format(file_in) 106 | ) 107 | 108 | if not file_in.endswith(ext): 109 | raise ValueError( 110 | "File {} must have the extension {}".format(file_in, ext) 111 | ) 112 | 113 | def load(self) -> Tuple[Dataset, Dataset]: 114 | """ 115 | Loads both the corpus and query datasets. If the datasets are not already loaded, 116 | they are loaded from the specified source (either local files or Hugging Face repository). 117 | 118 | Returns: 119 | `Tuple[Dataset, Dataset]`: A tuple containing the loaded corpus and queries datasets. 120 | """ 121 | if not self.hf_repo: 122 | self.check(file_in=self.corpus_file, ext="jsonl") 123 | self.check(file_in=self.query_file, ext="jsonl") 124 | 125 | if self.corpus is None: 126 | logger.info("Loading Corpus...") 127 | self._load_corpus() 128 | self.corpus = cast(Dataset, self.corpus) 129 | logger.info("Loaded %d Documents.", len(self.corpus)) 130 | logger.info("Corpus Example: %s", self.corpus[0]) 131 | 132 | if self.queries is None: 133 | logger.info("Loading Queries...") 134 | self._load_queries() 135 | self.queries = cast(Dataset, self.queries) 136 | 137 | logger.info("Loaded %d Queries.", len(self.queries)) 138 | logger.info("Query Example: %s", self.queries[0]) 139 | 140 | return self.corpus, self.queries 141 | 142 | def load_corpus(self) -> Dataset: 143 | """ 144 | Loads the corpus dataset. If the corpus is already loaded, returns the existing dataset. 145 | 146 | Returns: 147 | `Dataset`: The loaded corpus dataset. 
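Example (illustrative; the repository name is a hypothetical placeholder):

    loader = HFDataLoader(hf_repo="some-org/finance-rag-data", subset="FinDER")
    corpus = loader.load_corpus()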
148 | """ 149 | if not self.hf_repo: 150 | self.check(file_in=self.corpus_file, ext="jsonl") 151 | 152 | if (self.corpus is None) or (not len(self.corpus)): 153 | logger.info("Loading Corpus...") 154 | self._load_corpus() 155 | self.corpus = cast(Dataset, self.corpus) 156 | logger.info("Loaded %d Documents.", len(self.corpus)) 157 | logger.info("Corpus Example: %s", self.corpus[0]) 158 | 159 | return self.corpus 160 | 161 | def _load_corpus(self): 162 | """ 163 | Internal method to load the corpus dataset from either local files or Hugging Face repository. 164 | The dataset is processed by renaming and removing unnecessary columns. 165 | """ 166 | if self.hf_repo: 167 | corpus_ds = load_dataset( 168 | path=self.hf_repo, 169 | name=self.subset, 170 | split="corpus", 171 | keep_in_memory=self.keep_in_memory, 172 | streaming=self.streaming, 173 | ) 174 | else: 175 | corpus_ds = load_dataset( 176 | "json", 177 | data_files=self.corpus_file, 178 | streaming=self.streaming, 179 | keep_in_memory=self.keep_in_memory, 180 | ) 181 | 182 | corpus_ds = cast(Dataset, corpus_ds) 183 | corpus_ds = corpus_ds.cast_column("_id", Value("string")) 184 | corpus_ds = corpus_ds.rename_column("_id", "id") 185 | corpus_ds = corpus_ds.remove_columns( 186 | [ 187 | col 188 | for col in corpus_ds.column_names 189 | if col not in ["id", "text", "title"] 190 | ] 191 | ) 192 | self.corpus = corpus_ds 193 | 194 | def _load_queries(self): 195 | """ 196 | Internal method to load the queries dataset from either local files or Hugging Face repository. 197 | The dataset is processed by renaming and removing unnecessary columns. 198 | """ 199 | if self.hf_repo: 200 | queries_ds = load_dataset( 201 | path=self.hf_repo, 202 | name=self.subset, 203 | split="queries", 204 | keep_in_memory=self.keep_in_memory, 205 | streaming=self.streaming, 206 | ) 207 | else: 208 | queries_ds = load_dataset( 209 | "json", 210 | data_files=self.query_file, 211 | streaming=self.streaming, 212 | keep_in_memory=self.keep_in_memory, 213 | ) 214 | queries_ds = cast(Dataset, queries_ds) 215 | queries_ds = queries_ds.cast_column("_id", Value("string")) 216 | queries_ds = queries_ds.rename_column("_id", "id") 217 | queries_ds = queries_ds.remove_columns( 218 | [col for col in queries_ds.column_names if col not in ["id", "text"]] 219 | ) 220 | self.queries = queries_ds -------------------------------------------------------------------------------- /financerag/common/protocols.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict, List, Literal, Optional, Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | __all__ = [ 8 | "Encoder", 9 | "Lexical", 10 | "Retrieval", 11 | "CrossEncoder", 12 | "Reranker", 13 | "Generator", 14 | ] 15 | 16 | 17 | class Lexical(abc.ABC): 18 | """ 19 | Abstract class for lexical models that defines an interface for calculating relevance scores 20 | between a query and a set of documents. This abstract class is designed to be implemented by 21 | classes that calculate document-query relevance using lexical methods such as BM25 or 22 | other term-based approaches. 23 | """ 24 | 25 | @abc.abstractmethod 26 | def get_scores(self, query: List[str], **kwargs) -> List[float]: 27 | """ 28 | Calculates relevance scores for a given query against a set of documents. 29 | 30 | Args: 31 | query (`List[str]`): 32 | A tokenized query in the form of a list of words. This represents the query 33 | to be evaluated for relevance against the documents. 
34 | 35 | Returns: 36 | `List[float]`: 37 | A list of relevance scores, where each score corresponds to the relevance of 38 | a document in the indexed corpus to the provided query. 39 | """ 40 | raise NotImplementedError 41 | 42 | 43 | class Encoder(abc.ABC): 44 | """ 45 | Abstract class for dense encoders, providing methods to encode texts, queries, and corpora into dense vectors. 46 | """ 47 | 48 | @abc.abstractmethod 49 | def encode_queries( 50 | self, queries: List[str], **kwargs 51 | ) -> Union[torch.Tensor, np.ndarray]: 52 | """ 53 | Encodes a list of queries into dense vector representations. 54 | 55 | Args: 56 | queries (`List[str]`): 57 | A list of query strings to encode. 58 | **kwargs: 59 | Additional arguments passed to the encoder. 60 | 61 | Returns: 62 | `Union[torch.Tensor, np.ndarray]`: 63 | Encoded queries as a tensor or numpy array. 64 | """ 65 | raise NotImplementedError 66 | 67 | def encode_corpus( 68 | self, 69 | corpus: Union[ 70 | List[Dict[Literal["title", "text"], str]], 71 | Dict[Literal["title", "text"], List], 72 | ], 73 | **kwargs 74 | ) -> Union[torch.Tensor, np.ndarray]: 75 | """ 76 | Encodes a list of corpus documents into dense vector representations. 77 | 78 | Args: 79 | corpus (`Union[List[Dict[Literal["title", "text"], str]], Dict[Literal["title", "text"], List]]`): 80 | A list or dictionary of corpus documents to encode. 81 | **kwargs: 82 | Additional arguments passed to the encoder. 83 | 84 | Returns: 85 | `Union[torch.Tensor, np.ndarray]`: 86 | Encoded corpus documents as a tensor or numpy array. 87 | """ 88 | raise NotImplementedError 89 | 90 | 91 | class Retrieval(abc.ABC): 92 | """ 93 | Abstract class for retrieval modules, providing a method to search for the most relevant documents based on queries. 94 | """ 95 | 96 | @abc.abstractmethod 97 | def retrieve( 98 | self, 99 | corpus: Dict[str, Dict[Literal["title", "text"], str]], 100 | queries: Dict[str, str], 101 | top_k: Optional[int] = None, 102 | score_function: Optional[str] = None, 103 | **kwargs 104 | ) -> Dict[str, Dict[str, float]]: 105 | """ 106 | Searches the corpus for the most relevant documents to the given queries. 107 | 108 | Args: 109 | corpus (`Dict[str, Dict[Literal["title", "text"], str]]`): 110 | A dictionary where each key is a document ID and each value is another dictionary containing document fields 111 | (e.g., {'text': str, 'title': str}). 112 | queries (`Dict[str, str]`): 113 | A dictionary where each key is a query ID and each value is the query text. 114 | top_k (`Optional[int]`, *optional*): 115 | The number of top documents to return for each query. If None, return all documents. Defaults to None. 116 | score_function (`Optional[str]`, *optional*): 117 | The scoring function to use when ranking the documents (e.g., 'cosine', 'dot', etc.). Defaults to None. 118 | **kwargs: 119 | Additional arguments passed to the search method. 120 | 121 | Returns: 122 | `Dict[str, Dict[str, float]]`: 123 | A dictionary where each key is a query ID, and each value is another dictionary mapping document IDs to 124 | relevance scores (e.g., {'doc1': 0.9, 'doc2': 0.8}). 125 | """ 126 | raise NotImplementedError 127 | 128 | 129 | class CrossEncoder(abc.ABC): 130 | """ 131 | Abstract class for rerankers, providing methods to predict sentence similarity and rank documents based on queries. 
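A minimal illustrative sketch of a concrete model satisfying this protocol (assumes the
optional `sentence-transformers` package and a publicly available cross-encoder checkpoint):

    from sentence_transformers import CrossEncoder as STCrossEncoder

    model = STCrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    scores = model.predict([("query text", "candidate passage")], batch_size=16)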
132 | """ 133 | 134 | @abc.abstractmethod 135 | def predict( 136 | self, 137 | sentences: Union[ 138 | List[Tuple[str, str]], List[List[str]], Tuple[str, str], List[str] 139 | ], 140 | batch_size: Optional[int] = None, 141 | **kwargs 142 | ) -> Union[torch.Tensor, np.ndarray]: 143 | """ 144 | Predicts similarity or relevance scores for pairs of sentences or lists of sentences. 145 | 146 | Args: 147 | sentences (`Union[List[Tuple[str, str]], List[List[str]], Tuple[str, str], List[str]]`): 148 | Sentences to predict similarity scores for. Can be a list of sentence pairs, list of sentence lists, 149 | a single sentence pair, or a list of sentences. 150 | batch_size (`Optional[int]`, *optional*): 151 | Batch size for prediction. Defaults to None. 152 | 153 | Returns: 154 | `Union[torch.Tensor, np.ndarray]`: 155 | Predicted similarity or relevance scores as a tensor or numpy array. 156 | """ 157 | raise NotImplementedError 158 | 159 | 160 | class Reranker(abc.ABC): 161 | """ 162 | Abstract class for reranking modules that defines methods to rerank search results based on queries. 163 | """ 164 | 165 | @abc.abstractmethod 166 | def rerank( 167 | self, 168 | corpus: Dict[str, Dict[str, str]], 169 | queries: Dict[str, str], 170 | results: Dict[str, Dict[str, float]], 171 | top_k: int, 172 | batch_size: Optional[int] = None, 173 | **kwargs 174 | ) -> Dict[str, Dict[str, float]]: 175 | """ 176 | Reranks the search results based on the given queries and the initial ranking scores. 177 | 178 | Args: 179 | corpus (`Dict[str, Dict[str, str]]`): 180 | A dictionary where keys are document IDs and values are dictionaries containing 181 | document metadata, such as content or other features. 182 | queries (`Dict[str, str]`): 183 | A dictionary where keys are query IDs and values are the corresponding query texts. 184 | results (`Dict[str, Dict[str, float]]`): 185 | A dictionary where keys are query IDs and values are dictionaries mapping document 186 | IDs to their initial relevance scores. 187 | top_k (`int`): 188 | The number of top documents to rerank. 189 | batch_size (`Optional[int]`, *optional*): 190 | The batch size to use during reranking. Useful for models that process data in 191 | batches. Defaults to None. 192 | **kwargs: 193 | Additional keyword arguments for custom configurations in the reranking process. 194 | 195 | Returns: 196 | `Dict[str, Dict[str, float]]`: 197 | The reranked relevance scores, returned as a dictionary mapping query IDs to dictionaries of document IDs and their scores. 198 | """ 199 | raise NotImplementedError 200 | 201 | 202 | class Generator(abc.ABC): 203 | """ 204 | Abstract class for text generators, providing methods for generating text completions in a chat-like interface. 205 | """ 206 | 207 | @abc.abstractmethod 208 | def generation( 209 | self, messages: Dict[str, List[Dict[str, str]]], **kwargs 210 | ) -> Dict[str, str]: 211 | """ 212 | Generates a chat completion based on a sequence of messages. 213 | 214 | Args: 215 | messages (`Dict[str, List[Dict[str, str]]]`): 216 | A list of message dictionaries per `query_id`. 217 | Each dictionary in list must contain: 218 | - 'role' (str): The role of the speaker (e.g., 'user' or 'system'). 219 | - 'content' (str): The content of the message. 220 | **kwargs: 221 | Additional arguments passed to the generator. 222 | 223 | Returns: 224 | `Dict[str, str]`: 225 | A dictionary containing the generated response, where each key is the `query_id` and the value is the generated text. 
226 | """ 227 | raise NotImplementedError -------------------------------------------------------------------------------- /financerag/retrieval/dense.py: -------------------------------------------------------------------------------- 1 | import heapq 2 | import logging 3 | from typing import Any, Callable, Dict, Literal, Optional 4 | 5 | import torch 6 | 7 | from financerag.common.protocols import Encoder, Retrieval 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | # Copied from https://github.com/beir-cellar/beir/blob/main/beir/retrieval/search/dense/util.py 13 | @torch.no_grad() 14 | def cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 15 | """ 16 | Computes the cosine similarity between two tensors. 17 | 18 | Args: 19 | a (`torch.Tensor`): 20 | Tensor representing query embeddings. 21 | b (`torch.Tensor`): 22 | Tensor representing corpus embeddings. 23 | 24 | Returns: 25 | `torch.Tensor`: 26 | Cosine similarity scores for all pairs. 27 | """ 28 | a = _ensure_tensor(a) 29 | b = _ensure_tensor(b) 30 | return torch.mm( 31 | torch.nn.functional.normalize(a, p=2, dim=1), 32 | torch.nn.functional.normalize(b, p=2, dim=1).transpose(0, 1), 33 | ) 34 | 35 | 36 | @torch.no_grad() 37 | def dot_score(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: 38 | """ 39 | Computes the dot-product score between two tensors. 40 | 41 | Args: 42 | a (`torch.Tensor`): 43 | Tensor representing query embeddings. 44 | b (`torch.Tensor`): 45 | Tensor representing corpus embeddings. 46 | 47 | Returns: 48 | `torch.Tensor`: 49 | Dot-product scores for all pairs. 50 | """ 51 | a = _ensure_tensor(a) 52 | b = _ensure_tensor(b) 53 | return torch.mm(a, b.transpose(0, 1)) 54 | 55 | 56 | def _ensure_tensor(x: Any) -> torch.Tensor: 57 | """ 58 | Ensures the input is a torch.Tensor, converting if necessary. 59 | 60 | Args: 61 | x (`Any`): 62 | Input to be checked. 63 | 64 | Returns: 65 | `torch.Tensor`: 66 | Converted tensor. 67 | """ 68 | if not isinstance(x, torch.Tensor): 69 | x = torch.tensor(x) 70 | if len(x.shape) == 1: 71 | x = x.unsqueeze(0) 72 | return x 73 | 74 | 75 | # Adapted from https://github.com/beir-cellar/beir/blob/main/beir/retrieval/search/dense/exact_search.py 76 | class DenseRetrieval(Retrieval): 77 | """ 78 | Encoder-based dense retrieval that performs similarity-based search over a corpus. 79 | 80 | This class uses dense embeddings from an encoder model to compute similarity scores (e.g., cosine similarity or 81 | dot product) between query embeddings and corpus embeddings. It retrieves the top-k most relevant documents 82 | based on these scores. 83 | """ 84 | 85 | def __init__( 86 | self, 87 | model: Encoder, 88 | batch_size: int = 64, 89 | score_functions: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] | None = None, 90 | corpus_chunk_size: int = 50000 91 | ): 92 | """ 93 | Initializes the DenseRetrieval class. 94 | 95 | Args: 96 | model (`Encoder`): 97 | An encoder model implementing the `Encoder` protocol, responsible for encoding queries and corpus documents. 98 | batch_size (`int`, *optional*, defaults to `64`): 99 | The batch size to use when encoding queries and corpus documents. 100 | score_functions (`Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]`, *optional*): 101 | A dictionary mapping score function names (e.g., "cos_sim", "dot") to functions that compute similarity 102 | scores between query and corpus embeddings. Defaults to cosine similarity and dot product. 
103 | corpus_chunk_size (`int`, *optional*, defaults to `50000`): 104 | The number of documents to process in each batch when encoding the corpus. 105 | """ 106 | self.model: Encoder = model 107 | self.batch_size: int = batch_size 108 | if score_functions is None: 109 | score_functions = {"cos_sim": cos_sim, "dot": dot_score} 110 | self.score_functions: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = score_functions 111 | self.corpus_chunk_size: int = corpus_chunk_size 112 | self.results: Dict = {} 113 | 114 | def retrieve( 115 | self, 116 | corpus: Dict[str, Dict[Literal["title", "text"], str]], 117 | queries: Dict[str, str], 118 | top_k: Optional[int] = None, 119 | score_function: Literal["cos_sim", "dot"] | None = "cos_sim", 120 | return_sorted: bool = False, 121 | **kwargs, 122 | ) -> Dict[str, Dict[str, float]]: 123 | """ 124 | Retrieves the top-k most relevant documents from the corpus based on the given queries. 125 | 126 | This method encodes the queries and corpus documents, computes similarity scores using the specified scoring 127 | function, and retrieves the top-k most relevant documents for each query. 128 | 129 | Args: 130 | corpus (`Dict[str, Dict[Literal["title", "text"], str]]`): 131 | A dictionary where each key is a document ID, and each value contains document metadata 132 | such as 'title' and 'text'. 133 | queries (`Dict[str, str]`): 134 | A dictionary where each key is a query ID and each value is the query text. 135 | top_k (`Optional[int]`, *optional*): 136 | The number of top documents to return for each query. If `None`, returns all documents. 137 | return_sorted (`bool`, *optional*, defaults to `False`): 138 | Whether to return the results sorted by score. 139 | score_function (`Literal["cos_sim", "dot"]`, *optional*, defaults to `"cos_sim"`): 140 | The scoring function to use, either 'cos_sim' for cosine similarity or 'dot' for dot product. 141 | **kwargs: 142 | Additional arguments passed to the encoder model. 143 | 144 | Returns: 145 | `Dict[str, Dict[str, float]]`: 146 | A dictionary where each key is a query ID, and the value is another dictionary mapping document 147 | IDs to their similarity scores. 148 | """ 149 | if score_function not in self.score_functions: 150 | raise ValueError( 151 | f"Score function: {score_function} must be either 'cos_sim' for cosine similarity or 'dot' for dot product." 152 | ) 153 | 154 | logger.info("Encoding queries...") 155 | query_ids = list(queries.keys()) 156 | self.results = {qid: {} for qid in query_ids} 157 | query_texts = [queries[qid] for qid in queries] 158 | query_embeddings = self.model.encode_queries( 159 | query_texts, batch_size=self.batch_size, **kwargs 160 | ) 161 | 162 | logger.info("Sorting corpus by document length...") 163 | sorted_corpus_ids = sorted( 164 | corpus, 165 | key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), 166 | reverse=True, 167 | ) 168 | 169 | logger.info("Encoding corpus in batches... This may take a while.") 170 | result_heaps = { 171 | qid: [] for qid in query_ids 172 | } # Keep only the top-k docs for each query 173 | 174 | corpus_list = [corpus[cid] for cid in sorted_corpus_ids] 175 | 176 | for batch_num, start_idx in enumerate( 177 | range(0, len(corpus), self.corpus_chunk_size) 178 | ): 179 | logger.info( 180 | f"Encoding batch {batch_num + 1}/{len(range(0, len(corpus_list), self.corpus_chunk_size))}..." 
181 | ) 182 | end_idx = min(start_idx + self.corpus_chunk_size, len(corpus_list)) 183 | 184 | # Encode chunk of corpus 185 | sub_corpus_embeddings = self.model.encode_corpus( 186 | corpus_list[start_idx:end_idx], batch_size=self.batch_size, **kwargs 187 | ) 188 | 189 | # Compute similarities using either cosine similarity or dot product 190 | cos_scores = self.score_functions[score_function]( 191 | query_embeddings, sub_corpus_embeddings 192 | ) 193 | cos_scores[torch.isnan(cos_scores)] = -1 194 | 195 | # Get top-k values 196 | if top_k is None: 197 | top_k = len(cos_scores[1]) 198 | 199 | cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk( 200 | cos_scores, 201 | min(top_k + 1, len(cos_scores[1])), 202 | dim=1, 203 | largest=True, 204 | sorted=return_sorted, 205 | ) 206 | 207 | cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist() 208 | cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist() 209 | 210 | for query_itr in range(len(query_embeddings)): 211 | query_id = query_ids[query_itr] 212 | for sub_corpus_id, score in zip( 213 | cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr] 214 | ): 215 | corpus_id = sorted_corpus_ids[start_idx + sub_corpus_id] 216 | if corpus_id != query_id: 217 | if len(result_heaps[query_id]) < top_k: 218 | heapq.heappush(result_heaps[query_id], (score, corpus_id)) 219 | else: 220 | heapq.heappushpop( 221 | result_heaps[query_id], (score, corpus_id) 222 | ) 223 | 224 | for qid in result_heaps: 225 | for score, corpus_id in result_heaps[qid]: 226 | self.results[qid][corpus_id] = score 227 | 228 | return self.results -------------------------------------------------------------------------------- /GAR/DPO_agent_Finetuning.py: -------------------------------------------------------------------------------- 1 | # dpo_agent_finetuning.py 2 | 3 | import json 4 | import numpy as np 5 | import openai 6 | from tqdm import tqdm 7 | import pandas as pd 8 | import random 9 | 10 | class dpo_agent_finetuning: 11 | def __init__(self, api_key, default_model="gpt-4o-mini", ft_model="ft:gpt-4o-mini-2024-07-18:personal::AkFnSqzI", top_probs=10): 12 | 13 | self.api_key = api_key 14 | self.default_model = default_model 15 | self.ft_model = ft_model 16 | self.top_probs = top_probs 17 | self.client = openai.OpenAI(api_key=self.api_key) 18 | 19 | def save_from_csv_to_jsonl(self, answer_csv, data_csv, jsonl_file): 20 | answer_df = pd.read_csv(answer_csv) 21 | data_df = pd.read_csv(data_csv) 22 | 23 | financial_prompt = """ 24 | You are an expert financial advisor and evaluator focused on improving responses. 
25 | Your task is to enhance answers based on detailed evaluation scores while: 26 | - Maintaining factual accuracy with the provided documents 27 | - Ensuring responses are clear and well-structured for financial contexts 28 | - Providing comprehensive answers that address all aspects of the query 29 | - Using professional financial terminology appropriately 30 | """ 31 | 32 | with open(jsonl_file, 'w', encoding='utf-8') as f: 33 | for index, row in tqdm(answer_df.iterrows(), total=answer_df.shape[0]): 34 | chunk = data_df[index:index+10] 35 | corpus_list = chunk['Corpus'].tolist() 36 | 37 | query_prompt = f""" 38 | Query: {row['query']} 39 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)]) 40 | query_prompt += "\nDo not add any closing pleasantries or phrases like 'please feel free to ask!'" 41 | 42 | choice = random.choice([1, 1, 1, 2]) 43 | json_obj = { 44 | "messages": [ 45 | {"role": "system", "content": financial_prompt}, 46 | {"role": "user", "content": query_prompt}, 47 | {"role": "assistant", "content": row[f'response{choice}']} 48 | ] 49 | } 50 | f.write(json.dumps(json_obj, ensure_ascii=False) + "\n") 51 | 52 | def split_jsonl(self, jsonl_file, train_path, eval_path, split_ratio=0.8): 53 | 54 | with open(jsonl_file, "r", encoding="utf-8") as f: 55 | data = [json.loads(line) for line in f] 56 | random.shuffle(data) 57 | split_index = int(len(data) * split_ratio) 58 | train_data = data[:split_index] 59 | eval_data = data[split_index:] 60 | with open(train_path, "w", encoding="utf-8") as f_train: 61 | for record in train_data: 62 | f_train.write(json.dumps(record, ensure_ascii=False) + "\n") 63 | with open(eval_path, "w", encoding="utf-8") as f_eval: 64 | for record in eval_data: 65 | f_eval.write(json.dumps(record, ensure_ascii=False) + "\n") 66 | 67 | def load_jsonl(self, path): 68 | 69 | with open(path, 'r', encoding='utf-8') as file: 70 | data = [json.loads(line) for line in file] 71 | return data 72 | 73 | def calculate_weighted_score(self, response, scores): 74 | 75 | token_probs = [ 76 | np.exp(float(dict(response)['choices'][0].logprobs.content[0].top_logprobs[i].logprob)) 77 | for i in range(self.top_probs) 78 | ] 79 | weighted_scores = sum([token_probs[i] * scores[i] for i in range(self.top_probs)]) 80 | return weighted_scores 81 | 82 | def g_eval(self, query, response, document_list): 83 | 84 | prompt_final_score = f""" 85 | You are an evaluation assistant tasked with assessing the quality of a generated answer. 86 | You will be given one query, one generated answer, and a summarized document related to the query. 87 | Follow these steps to evaluate the answer based on the criteria below. 88 | 89 | 90 | 1. Relevance (0~5): Does the answer directly address the query and is contextually relevant? 91 | 2. Alignment (0~5): Does the answer align with the facts provided in the documents? 92 | 3. Clarity (0~5): Is the answer easy to understand and free of confusion or unnecessary complexity? 93 | 4. Completeness (0~5): Does the answer address all parts of the question thoroughly? 94 | 5. Coherence (0~5): Does the answer follow the following criteria?: Collective properties of all sentences. We match this dimension with DUC's quality problems with structure and consistency. 95 | 96 | 97 | 1. Assign and record each five scores, each score is on a scale of 0 to 5, where 0 is the lowest and 5 is the highest, based on each . 98 | 2. Get the final score by adding up all the recorded scores. The final score must be between 0 and 25. 99 | 3. 
YOU MUST Return ONLY final score. You must not add any explanations. Do not contain sentences like 'score'. ONLY integer form of the final score, like 11, 28, 30, etc. 100 | 101 | : {query} 102 | 103 | : 104 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(document_list)]) + f""" 105 | 106 | : {dict(response)['choices'][0].message.content} 107 | 108 | : Do NOT add any other text or string, like 'twenty'. ONLY integer form, like 9, 14, 20 etc. 109 | """ 110 | cot_response_final_score = self.client.chat.completions.create( 111 | model=self.default_model, 112 | messages=[ 113 | {"role": "system", "content": "You are a detailed evaluator. Be sure to respond ONLY in int. (e.g., 11, 20, 19, etc.)"}, 114 | {"role": "user", "content": prompt_final_score} 115 | ], 116 | temperature=0, 117 | logprobs=True, 118 | top_logprobs=self.top_probs, 119 | ) 120 | measured_score = int(dict(cot_response_final_score)['choices'][0].message.content) 121 | token_scores = [] 122 | detailed_scores = {} 123 | 124 | for i in range(self.top_probs): 125 | top_logprob_item = dict(cot_response_final_score)['choices'][0].logprobs.content[0].top_logprobs[i] 126 | temp_token = top_logprob_item.token 127 | temp_probs = np.exp(float(top_logprob_item.logprob)) * 100 128 | detailed_scores[temp_token] = f"{temp_probs}%" 129 | try: 130 | token_scores.append(int(top_logprob_item.token)) 131 | except Exception as e: 132 | print(e) 133 | token_scores.append(20) 134 | weighted_score = self.calculate_weighted_score(cot_response_final_score, token_scores) 135 | return measured_score, weighted_score, str(detailed_scores) 136 | 137 | def improve_response(self, query, original_response, document_list, scores_list): 138 | 139 | improvement_prompt = f""" 140 | You need to improve the following answer based on the evaluation scores and criteria. 141 | 142 | : {query} 143 | 144 | : 145 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(document_list)]) + f""" 146 | 147 | : {dict(original_response)['choices'][0].message.content} 148 | 149 | 150 | 1. Relevance (0~5): Does the answer directly address the query and is contextually relevant? 151 | 2. Alignment (0~5): Does the answer align with the facts provided in the documents? 152 | 3. Clarity (0~5): Is the answer easy to understand and free of confusion or unnecessary complexity? 153 | 4. Completeness (0~5): Does the answer address all parts of the question thoroughly? 154 | 5. Coherence (0~5): Does the answer follow the following criteria?: Collective properties of all sentences. We match this dimension with DUC's quality problems with structure and consistency. 155 | 156 | Please provide an improved answer that: 157 | 1. Aims to achieve the highest possible score (25/25) 158 | 2. Focus on creating a response that would: 159 | - Maintain the strengths of the original answer 160 | - Ensure accuracy with the provided documents 161 | - Be clear and well-structured 162 | 3. Your goal is to increase the probability of achieving a perfect score of 25 163 | 164 | 165 | This is the score information which is the probability of the final score. This is the top 10 score information based on the probability. format: {{"score": "probability"}} 166 | {scores_list} 167 | 168 | Provide only the improved answer without any explanations. 
169 | """ 170 | improved_response = self.client.chat.completions.create( 171 | model=self.default_model, 172 | messages=[ 173 | {"role": "system", "content": """ 174 | You are an expert financial advisor and evaluator focused on improving responses. 175 | Your task is to enhance answers based on detailed evaluation scores while: 176 | - Maintaining factual accuracy with the provided documents 177 | - Focusing especially on areas that received low scores 178 | - Ensuring responses are clear and well-structured for financial contexts 179 | - Providing comprehensive answers that address all aspects of the query 180 | - Using professional financial terminology appropriately 181 | 182 | You should maintain the strengths of the original response while addressing its weaknesses. 183 | """ }, 184 | {"role": "user", "content": improvement_prompt} 185 | ], 186 | temperature=0 187 | ) 188 | return improved_response 189 | 190 | def evaluate_corpus_relevance(self, query_line, corpus_list): 191 | 192 | relevancy_instruction = """ 193 | You are an expert financial advisor and evaluator focused on improving responses. 194 | Your task is to enhance answers based on detailed evaluation scores while: 195 | - Maintaining factual accuracy with the provided documents 196 | - Ensuring responses are clear and well-structured for financial contexts 197 | - Providing comprehensive answers that address all aspects of the query 198 | - Using professional financial terminology appropriately 199 | 200 | You are given the pair of Query, Corpus (same query) 201 | Out of the 10 documents, only provide the list of indices of those that are RELEVANT (e.g. the content is somehow needed to answer the question), from 0~9. 202 | Example : [0, 2, 8, 9] 203 | """ 204 | user_prompt = f""" 205 | Query: {query_line} 206 | 207 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)]) 208 | try: 209 | relevancy = eval(self.generate_output(relevancy_instruction, user_prompt, temp=0)) 210 | except Exception as error: 211 | print(error) 212 | relevancy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 213 | if relevancy == []: 214 | relevancy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] 215 | return relevancy 216 | 217 | def answer_query(self, query_line, corpus_list, temp=0, return_raw=False, model_name=None): 218 | 219 | if model_name is None: 220 | model_name = self.default_model 221 | financial_prompt = """ 222 | You are an expert financial advisor and evaluator focused on improving responses. 
223 | Your task is to enhance answers based on detailed evaluation scores while: 224 | - Maintaining factual accuracy with the provided documents 225 | - Ensuring responses are clear and well-structured for financial contexts 226 | - Providing comprehensive answers that address all aspects of the query 227 | - Using professional financial terminology appropriately 228 | """ 229 | query_prompt = f""" 230 | Query: {query_line} 231 | 232 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)]) 233 | query_prompt += "\nDo not add any closing pleasantries or phrases like 'please feel free to ask!'" 234 | response = self.generate_output(financial_prompt, query_prompt, temp=temp, return_raw=return_raw) 235 | return response 236 | 237 | 238 | def process_finetuning(self, df, submission_output_path, start_index=39610): 239 | 240 | submission_df = pd.DataFrame({'query_id': [], 'response': []}) 241 | 242 | for i in tqdm(range(start_index, len(df), 10)): 243 | chunk = df[i:i+10] 244 | query_id = chunk['total_query_id'].iloc[0] 245 | query_line = chunk['total_query'].iloc[0] 246 | corpus_list = chunk['total_corpus'].tolist() 247 | 248 | corpus_rel = self.evaluate_corpus_relevance(query_line, corpus_list) 249 | try: 250 | relevant_corpus = chunk.iloc[corpus_rel] 251 | except Exception as e: 252 | print(e) 253 | relevant_corpus = chunk 254 | relevant_corpus = relevant_corpus['total_corpus'].tolist() 255 | 256 | response = self.answer_query(query_line, relevant_corpus, temp=0, return_raw=True, model_name=self.ft_model) 257 | 258 | measured_score, weighted_score, detailed_scores = self.g_eval(query_line, response, relevant_corpus) 259 | 260 | attempt_count = 0 261 | best_score = weighted_score 262 | best_response = response 263 | 264 | while weighted_score < 22 and attempt_count < 5: 265 | attempt_count += 1 266 | improved_response = self.improve_response(query_line, best_response, relevant_corpus, detailed_scores) 267 | improved_score, improved_weighted, improved_detailed = self.g_eval(query_line, improved_response, relevant_corpus) 268 | if improved_weighted > best_score: 269 | best_score = improved_weighted 270 | best_response = improved_response 271 | weighted_score = improved_weighted 272 | detailed_scores = improved_detailed 273 | else: 274 | break 275 | 276 | submission_df.loc[int(i/10)] = [query_id, best_response.choices[0].message.content] 277 | 278 | submission_df.to_csv(submission_output_path, index=False) 279 | print(f"Submission file saved to {submission_output_path}") 280 | 281 | def upload_finetune_files(self, train_path, eval_path): 282 | 283 | training_file = self.client.files.create( 284 | file=open(train_path, "rb"), 285 | purpose="fine-tune" 286 | ) 287 | eval_file = self.client.files.create( 288 | file=open(eval_path, "rb"), 289 | purpose="fine-tune" 290 | ) 291 | return training_file, eval_file 292 | 293 | def create_finetune_job(self, training_file, eval_file, model="gpt-4o-mini-2024-07-18", method=None): 294 | 295 | job_response = self.client.fine_tuning.jobs.create( 296 | training_file=training_file.id, 297 | model=model, 298 | validation_file=eval_file.id, 299 | # method 옵션이 필요하면 추가로 전달 가능 (예시 주석 참고) 300 | ) 301 | return job_response 302 | 303 | def retrieve_finetune_job(self, job_id): 304 | 305 | return self.client.fine_tuning.jobs.retrieve(job_id) 306 | -------------------------------------------------------------------------------- /financerag/tasks/BaseTask.py: -------------------------------------------------------------------------------- 1 
| import csv 2 | import json 3 | import logging 4 | import os 5 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple 6 | 7 | import pytrec_eval 8 | 9 | from financerag.common import Generator, HFDataLoader, Reranker, Retrieval 10 | from financerag.tasks.TaskMetadata import TaskMetadata 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | # Adapted from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/abstasks/AbsTask.py 16 | class BaseTask: 17 | """ 18 | Base class for handling tasks related to document retrieval, reranking, and generation in the finance domain. 19 | The class loads data, handles retrieval and reranking operations, and can generate results using a language model. 20 | 21 | Attributes: 22 | metadata (`TaskMetadata`): 23 | Metadata containing task-specific information, such as dataset paths and subsets. 24 | queries (`Optional[Dict[str, str]]`, defaults to `None`): 25 | A dictionary mapping query IDs to query text. 26 | corpus (`Optional[Dict[str, Dict[str, str]]]`, defaults to `None`): 27 | A dictionary mapping document IDs to a dictionary containing the document title and text. 28 | retrieve_results (`Optional[Dict]`, defaults to `None`): 29 | The results of the retrieval process. 30 | rerank_results (`Optional[Dict]`, defaults to `None`): 31 | The results of the reranking process. 32 | generate_results (`Optional[Dict]`, defaults to `None`): 33 | The results generated by the model. 34 | 35 | Methods: 36 | load_data(): 37 | Loads the dataset (queries and corpus) into memory from the provided metadata. 38 | retrieve(retriever: Retrieval, top_k: Optional[int] = 100, **kwargs): 39 | Performs document retrieval based on the given queries and corpus. 40 | rerank(reranker: Reranker, results: Optional[Dict] = None, top_k: Optional[int] = 100, batch_size: Optional[int] = None, **kwargs): 41 | Reranks the retrieved results using the given reranker model. 42 | generate(model: Generator, results: Optional[Dict] = None, prepare_messages: Optional[Callable] = None, **kwargs): 43 | Generates results based on the highest-scoring documents from the reranked or retrieved results. 44 | prepare_generation_inputs(results: Dict, prepare_messages: Callable) -> Dict[str, List[dict]]: 45 | Prepares the input format required for generating results by the model. 46 | save_results(top_k: int = 10, output_dir: Optional[str] = None) -> None: 47 | Saves the results (retrieval, reranking, and generated) to CSV and JSONL files. 48 | """ 49 | 50 | def __init__(self, metadata: TaskMetadata): 51 | """ 52 | Initializes the BaseTask class with metadata for loading and processing retrieval tasks. 53 | 54 | Args: 55 | metadata (`TaskMetadata`): 56 | Task-specific metadata that contains dataset information and configurations. 57 | """ 58 | self.metadata: TaskMetadata = metadata 59 | self.queries: Optional[Dict[str, str]] = None 60 | self.corpus: Optional[Dict[str, Dict[Literal["title", "text"], str]]] = None 61 | self.retrieve_results: Optional[Dict] = None 62 | self.rerank_results: Optional[Dict] = None 63 | self.generate_results: Optional[Dict] = None 64 | 65 | self.load_data() 66 | 67 | @property 68 | def metadata_dict(self) -> Dict[str, Any]: 69 | """ 70 | Converts the task metadata into a dictionary format. 71 | 72 | Returns: 73 | `Dict[str, Any]`: 74 | A dictionary representation of the task metadata. 
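Example (illustrative; `task` stands for any concrete BaseTask instance):

    dataset_args = task.metadata_dict["dataset"]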
75 | """ 76 | return dict(self.metadata) 77 | 78 | def load_data(self): 79 | """ 80 | Loads the corpus and queries from the specified dataset path and subset in the metadata. 81 | 82 | Raises: 83 | `ValueError`: 84 | If the dataset cannot be loaded from the specified path and subset. 85 | """ 86 | if (self.corpus is None) or (self.queries is None): 87 | dataset_path = self.metadata_dict["dataset"]["path"] 88 | subset = self.metadata_dict["dataset"]["subset"] 89 | 90 | corpus, queries = HFDataLoader( 91 | hf_repo=dataset_path, 92 | subset=subset, 93 | keep_in_memory=False, 94 | ).load() 95 | 96 | self.queries = {query["id"]: query["text"] for query in queries} 97 | self.corpus = { 98 | doc["id"]: {"title": doc["title"], "text": doc["text"]} 99 | for doc in corpus 100 | } 101 | 102 | def retrieve( 103 | self, retriever: Retrieval, top_k: Optional[int] = 100, **kwargs 104 | ) -> Dict[str, Dict[str, float]]: 105 | """ 106 | Performs document retrieval using the provided retriever model. 107 | 108 | Args: 109 | retriever (`Retrieval`): 110 | The retrieval model to use for retrieving documents. 111 | top_k (`Optional[int]`, defaults to `100`): 112 | The number of top results to return for each query. 113 | **kwargs: 114 | Additional keyword arguments for the retriever. 115 | 116 | Returns: 117 | `Dict[str, Dict[str, float]]`: 118 | A dictionary where the key is the query ID and the value is another dictionary 119 | mapping document IDs to their retrieval scores. 120 | 121 | Raises: 122 | `TypeError`: 123 | If the `retriever` is not a subclass of `Retrieval`. 124 | `ValueError`: 125 | If the data (corpus or queries) is not loaded before retrieval. 126 | """ 127 | if not issubclass(type(retriever), Retrieval): 128 | raise TypeError(f"{type(retriever)} must be a subclass of the `Retrieval` class") 129 | 130 | if (self.corpus is None) or (self.queries is None): 131 | raise ValueError("Data has not been loaded.") 132 | 133 | self.retrieve_results = retriever.retrieve( 134 | queries=self.queries, corpus=self.corpus, top_k=top_k, **kwargs 135 | ) 136 | 137 | return self.retrieve_results 138 | 139 | def rerank( 140 | self, 141 | reranker: Reranker, 142 | results: Optional[Dict[str, Dict[str, float]]] = None, 143 | top_k: int = 100, 144 | batch_size: Optional[int] = None, 145 | **kwargs, 146 | ) -> Dict[str, Dict[str, float]]: 147 | """ 148 | Reranks the retrieved results using the provided reranker model. 149 | 150 | Args: 151 | reranker (`Reranker`): 152 | The reranker model to use for reranking the retrieved results. 153 | results (`Optional[Dict]`, defaults to `None`): 154 | The initial results to rerank. If not provided, the method uses the retrieval results. 155 | top_k (`Optional[int]`, defaults to `100`): 156 | The number of top results to return after reranking. 157 | batch_size (`Optional[int]`, defaults to `None`): 158 | The batch size to use for reranking. 159 | **kwargs: 160 | Additional keyword arguments for the reranker. 161 | 162 | Returns: 163 | `Dict[str, Dict[str, float]]`: 164 | A dictionary where the key is the query ID and the value is another dictionary 165 | mapping document IDs to reranked scores. 166 | 167 | Raises: 168 | `TypeError`: 169 | If the `reranker` is not a subclass of `Reranker`. 170 | `ValueError`: 171 | If the data (corpus or queries) is not loaded before reranking or both `results` and `retrieve_results` are None. 
172 | """ 173 | if not issubclass(type(reranker), Reranker): 174 | raise TypeError(f"{type(reranker)} must be a subclass of the `Reranker` class") 175 | 176 | if (self.corpus is None) or (self.queries is None): 177 | raise ValueError("Data has not been loaded.") 178 | 179 | if results is None: 180 | if self.retrieve_results is not None: 181 | results = self.retrieve_results 182 | else: 183 | raise ValueError("Both `results` and `retrieve_results` are None; run `retrieve()` first or pass `results` explicitly.") 184 | 185 | self.rerank_results = reranker.rerank( 186 | queries=self.queries, 187 | corpus=self.corpus, 188 | results=results, 189 | top_k=top_k, 190 | batch_size=batch_size, 191 | **kwargs, 192 | ) 193 | 194 | return self.rerank_results 195 | 196 | def generate( 197 | self, 198 | model: Generator, 199 | results: Optional[Dict] = None, 200 | prepare_messages: Optional[Callable] = None, 201 | **kwargs, 202 | ) -> Dict[str, str]: 203 | """ 204 | Generates responses based on the highest-scoring documents from either the reranked or retrieved results. 205 | 206 | Args: 207 | model (`Generator`): 208 | The model used to generate responses. 209 | results (`Optional[Dict]`, defaults to `None`): 210 | The results to generate responses from. If not provided, uses reranked or retrieved results. 211 | prepare_messages (`Optional[Callable]`, defaults to `None`): 212 | A function to prepare messages for the generation model. If not provided, a default message 213 | preparation function is used. 214 | **kwargs: 215 | Additional keyword arguments for the generation process. 216 | 217 | Returns: 218 | `Dict[str, str]`: 219 | A dictionary where the key is the query ID and the value is the generated response. 220 | 221 | Raises: 222 | `TypeError`: 223 | If the `model` is not a subclass of `Generator`. 224 | `AssertionError`: 225 | If neither rerank_results nor retrieve_results is available for generating responses. 226 | """ 227 | if not issubclass(type(model), Generator): 228 | raise TypeError(f"{type(model)} must be a subclass of the `Generator` class") 229 | 230 | if prepare_messages is None: 231 | logger.info( 232 | "No prepare_messages function provided. " 233 | "Using default message preparation function, which selects the highest-scoring document for each query." 234 | ) 235 | 236 | def default_messages( 237 | query: str, documents: List[Tuple[str, float]] 238 | ) -> List[Dict]: 239 | first_document = max(documents, key=lambda x: x[1])[0] 240 | messages = [ 241 | {"role": "system", "content": "You are a helpful assistant."}, 242 | { 243 | "role": "user", 244 | "content": f"Document: {first_document}" 245 | f"\nGenerate an answer to the question from the document." 246 | f"\nQuestion: {query}", 247 | }, 248 | ] 249 | return messages 250 | 251 | prepare_messages = default_messages 252 | 253 | if results is None: 254 | results = ( 255 | self.rerank_results 256 | if self.rerank_results is not None 257 | else self.retrieve_results 258 | ) 259 | assert results is not None, ( 260 | "Neither rerank_results nor retrieve_results is available. " 261 | "One of them must be provided." 262 | ) 263 | 264 | messages_dict = self.prepare_generation_inputs(results, prepare_messages) 265 | self.generate_results = model.generation(messages_dict, **kwargs) 266 | 267 | return self.generate_results 268 | 269 | def prepare_generation_inputs( 270 | self, results, prepare_messages 271 | ) -> Dict[str, List[dict]]: 272 | """ 273 | Prepares the input messages required for the generation model.
274 | 275 | Args: 276 | results (`Dict`): 277 | The results from retrieval or reranking, which are used to generate responses. 278 | prepare_messages (`Callable`): 279 | A function that prepares the messages required for the generation model. 280 | 281 | Returns: 282 | `Dict[str, List[dict]]`: 283 | A dictionary where the key is the query ID and the value is a list of messages (dictionaries) 284 | that will be passed to the generation model. 285 | 286 | Raises: 287 | `ValueError`: 288 | If the data (corpus or queries) is not loaded. 289 | """ 290 | if (self.corpus is None) or (self.queries is None): 291 | raise ValueError("Data has not been loaded.") 292 | 293 | messages_dict: Dict[str, List[Dict[str, str]]] = {} 294 | logger.info("Preparing generation inputs for %d queries.", len(results)) 295 | for query_id, result in results.items(): 296 | query = self.queries[query_id] 297 | documents = [ 298 | (self.corpus[doc_id], score) for doc_id, score in result.items() 299 | ] 300 | messages = prepare_messages(query, documents) 301 | messages_dict[query_id] = messages 302 | 303 | logger.info("Successfully prepared generation inputs for all queries.") 304 | return messages_dict 305 | 306 | def save_results(self, top_k: int = 10, output_dir: Optional[str] = None) -> None: 307 | """ 308 | Saves the top retrieval or reranking results, along with any generated responses, to CSV and JSONL files. 309 | 310 | Args: 311 | top_k (`int`, defaults to `10`): 312 | The number of top results to save for each query. 313 | output_dir (`Optional[str]`, defaults to `None`): 314 | The directory where the results should be saved. If not provided, results are not saved. 315 | 316 | Saves: 317 | - Top `top_k` retrieval or reranked results in CSV format. 318 | - Generated responses in JSONL format. 319 | """ 320 | # If no output directory is provided, stop saving.
321 | if output_dir is None: 322 | return 323 | # Create the output directory if it does not exist 324 | output_dir = os.path.join(output_dir, self.metadata.name) 325 | os.makedirs(output_dir, exist_ok=True) 326 | 327 | logger.info(f"Output directory set to: {output_dir}") 328 | 329 | # Path to save the CSV file 330 | csv_file_path = os.path.join(output_dir, "results.csv") 331 | logger.info(f"Saving top {top_k} results to CSV file: {csv_file_path}") 332 | 333 | # Determine whether to use rerank results or retrieve results 334 | final_result = ( 335 | self.rerank_results 336 | if self.rerank_results is not None 337 | else self.retrieve_results 338 | ) 339 | 340 | # Process the final result if it's not None 341 | if final_result is not None: 342 | with open(csv_file_path, mode="w", newline="") as csv_file: 343 | writer = csv.writer(csv_file) 344 | # Write the header to the CSV file 345 | writer.writerow(["query_id", "corpus_id"]) 346 | logger.info("Writing header ['query_id', 'corpus_id'] to CSV.") 347 | 348 | # For each query_id, save the top_k corpus_ids sorted by score 349 | for q_id, doc_scores in final_result.items(): 350 | # Sort doc_scores by score and select top_k documents 351 | sorted_docs = sorted( 352 | doc_scores.items(), key=lambda item: item[1], reverse=True 353 | )[:top_k] 354 | 355 | # Write the query_id and corpus_id to the CSV 356 | for doc_id, _ in sorted_docs: 357 | writer.writerow([q_id, doc_id]) 358 | 359 | logger.info(f"Top {top_k} results saved successfully to {csv_file_path}") 360 | 361 | # Save generate_results to JSON Lines format 362 | if self.generate_results is not None: 363 | jsonl_file_path = os.path.join(output_dir, "output.jsonl") 364 | logger.info(f"Saving generate_results to JSONL file: {jsonl_file_path}") 365 | 366 | with open(jsonl_file_path, "w") as f: 367 | f.writelines( 368 | json.dumps({"query_id": q_id, "answer": answer}) + "\n" 369 | for q_id, answer in self.generate_results.items() 370 | ) 371 | 372 | logger.info(f"generate_results saved successfully to {jsonl_file_path}") 373 | 374 | # adapted from https://github.com/beir-cellar/beir/blob/main/beir/retrieval/evaluation.py 375 | @staticmethod 376 | def evaluate( 377 | qrels: Dict[str, Dict[str, int]], 378 | results: Dict[str, Dict[str, float]], 379 | k_values: List[int], 380 | ignore_identical_ids: bool = True 381 | ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]: 382 | 383 | if ignore_identical_ids: 384 | logger.info( 385 | 'For evaluation, identical query and document ids are ignored by default; ' 386 | 'explicitly set ``ignore_identical_ids=False`` to keep them.') 387 | popped = [] 388 | for qid, rels in results.items(): 389 | for pid in list(rels): 390 | if qid == pid: 391 | results[qid].pop(pid) # remove identical query-document pairs 392 | popped.append(pid) 393 | 394 | # Filter results to only keep queries that are present in qrels 395 | filtered_results = {qid: rels for qid, rels in results.items() if qid in qrels} 396 | 397 | # Initialize dictionaries for evaluation metrics 398 | ndcg = {} 399 | _map = {} 400 | recall = {} 401 | precision = {} 402 | 403 | # Initialize metric values for each k in k_values 404 | for k in k_values: 405 | ndcg[f"NDCG@{k}"] = 0.0 406 | _map[f"MAP@{k}"] = 0.0 407 | recall[f"Recall@{k}"] = 0.0 408 | precision[f"P@{k}"] = 0.0 409 | 410 | # Define strings for pytrec_eval evaluation 411 | map_string = "map_cut." + ",".join([str(k) for k in k_values]) 412 | ndcg_string = "ndcg_cut."
+ ",".join([str(k) for k in k_values]) 413 | recall_string = "recall." + ",".join([str(k) for k in k_values]) 414 | precision_string = "P." + ",".join([str(k) for k in k_values]) 415 | 416 | # Perform evaluation using pytrec_eval with filtered results 417 | evaluator = pytrec_eval.RelevanceEvaluator(qrels, 418 | {map_string, ndcg_string, recall_string, precision_string}) 419 | scores = evaluator.evaluate(filtered_results) 420 | 421 | # Aggregate the scores for each query and each k 422 | for query_id in scores.keys(): 423 | for k in k_values: 424 | ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)] 425 | _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)] 426 | recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)] 427 | precision[f"P@{k}"] += scores[query_id]["P_" + str(k)] 428 | 429 | # Compute the average scores for each k 430 | for k in k_values: 431 | ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(scores), 5) 432 | _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(scores), 5) 433 | recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(scores), 5) 434 | precision[f"P@{k}"] = round(precision[f"P@{k}"] / len(scores), 5) 435 | 436 | # Log the results for each metric 437 | for _eval in [ndcg, _map, recall, precision]: 438 | logger.info("\n") 439 | for k in _eval.keys(): 440 | logger.info("{}: {:.4f}".format(k, _eval[k])) 441 | 442 | return ndcg, _map, recall, precision --------------------------------------------------------------------------------
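A minimal end-to-end usage sketch for the BaseTask pipeline above (retrieve, rerank, generate, save_results). The BaseTask calls follow the signatures in BaseTask.py, but the task and component constructors (FinDER, SentenceTransformerEncoder, DenseRetrieval, CrossEncoderReranker, OpenAIGenerator) and the model names passed to them are assumptions, since those classes are defined in other files of the repository:

# Hypothetical usage example -- constructor arguments and model names are assumptions;
# only the BaseTask method calls follow the signatures defined in BaseTask.py above.
from financerag.tasks import FinDER
from financerag.retrieval import DenseRetrieval, SentenceTransformerEncoder
from financerag.rerank import CrossEncoderReranker
from financerag.generate import OpenAIGenerator

task = FinDER()  # BaseTask.__init__ loads the queries and corpus via load_data()

# Dense retrieval over the task corpus (encoder model name is illustrative).
encoder = SentenceTransformerEncoder("intfloat/e5-large-v2")
retrieval_results = task.retrieve(retriever=DenseRetrieval(model=encoder), top_k=100)

# Cross-encoder reranking of the retrieved candidates (model name is illustrative).
reranker = CrossEncoderReranker("cross-encoder/ms-marco-MiniLM-L-12-v2")
task.rerank(reranker=reranker, results=retrieval_results, top_k=100, batch_size=32)

# Generation: by default the highest-scoring document per query is turned into a prompt.
task.generate(model=OpenAIGenerator("gpt-4o-mini"))

# Writes <output_dir>/<task name>/results.csv (top-10 corpus_ids per query)
# and output.jsonl with {"query_id": ..., "answer": ...} per line.
task.save_results(top_k=10, output_dir="./outputs")

Note that save_results nests its output under output_dir/<task name> (self.metadata.name), so results written for different tasks into the same directory do not overwrite each other.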