├── .env.example
├── .venv
│   ├── lib64
│   ├── bin
│   │   ├── python
│   │   ├── python3.10
│   │   └── python3
│   └── pyvenv.cfg
├── financerag
│   ├── generate
│   │   ├── __init__.py
│   │   └── openai.py
│   ├── rerank
│   │   ├── __init__.py
│   │   └── cross_encoder.py
│   ├── common
│   │   ├── __init__.py
│   │   ├── loader.py
│   │   └── protocols.py
│   ├── retrieval
│   │   ├── __init__.py
│   │   ├── sent_encoder.py
│   │   ├── bm25.py
│   │   └── dense.py
│   ├── __init__.py
│   └── tasks
│       ├── __init__.py
│       ├── FinDERTask.py
│       ├── FinQABenchTask.py
│       ├── FinQATask.py
│       ├── ConvFinQATask.py
│       ├── FinanceBenchTask.py
│       ├── MultiHierttTask.py
│       ├── TATQATask.py
│       ├── TaskMetadata.py
│       └── BaseTask.py
├── src
│   ├── main.py
│   └── __init__.py
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── modules.xml
│   └── linq-rag-FinanceRAG.iml
├── activate.sh
├── requirements.txt
├── activate.bat
├── LICENSE
├── README.md
├── GAR
│   ├── run_hybrid_search.py
│   ├── Selection_agent.py
│   ├── main.py
│   ├── run_finetuning.py
│   ├── Finetuned_DPO_agent.py
│   ├── DPO_agent.py
│   ├── Embedder_Finetuning.py
│   ├── hybrid_search.py
│   └── DPO_agent_Finetuning.py
└── .gitignore
/.env.example:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.venv/lib64:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.venv/bin/python:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.venv/pyvenv.cfg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.venv/bin/python3.10:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.venv/bin/python3:
--------------------------------------------------------------------------------
1 | /usr/bin/python3
--------------------------------------------------------------------------------
/financerag/generate/__init__.py:
--------------------------------------------------------------------------------
1 | from .openai import OpenAIGenerator
2 |
--------------------------------------------------------------------------------
/financerag/rerank/__init__.py:
--------------------------------------------------------------------------------
1 | from .cross_encoder import CrossEncoderReranker
2 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | def main():
2 | """
3 | 메인 함수
4 | """
5 | print("GAR Project Started")
6 |
7 | if __name__ == "__main__":
8 | main()
--------------------------------------------------------------------------------
/financerag/common/__init__.py:
--------------------------------------------------------------------------------
1 | from .loader import HFDataLoader
2 | from .protocols import CrossEncoder, Encoder, Generator, Lexical, Reranker, Retrieval
3 |
--------------------------------------------------------------------------------
/financerag/retrieval/__init__.py:
--------------------------------------------------------------------------------
1 | from .bm25 import BM25Retriever
2 | from .dense import DenseRetrieval
3 | from .sent_encoder import SentenceTransformerEncoder
4 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | # 패키지 초기화 파일
2 |
3 | # 버전 정보
4 | __version__ = '0.1.0'
5 |
6 | # 주요 함수들을 패키지 레벨로 노출
7 | from .main import main
8 |
9 | # 공개할 함수들 정의
10 | __all__ = ['main']
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Editor-based HTTP Client requests
5 | /httpRequests/
6 | # Datasource local storage ignored files
7 | /dataSources/
8 | /dataSources.local.xml
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/financerag/__init__.py:
--------------------------------------------------------------------------------
1 | from .retrieval import (
2 | DenseRetrieval,
3 | BM25Retriever,
4 | SentenceTransformerEncoder
5 | )
6 |
7 | from .rerank import (
8 | CrossEncoderReranker
9 | )
10 |
11 | from .generate import (
12 | OpenAIGenerator
13 | )
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/linq-rag-FinanceRAG.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/financerag/tasks/__init__.py:
--------------------------------------------------------------------------------
1 | from .BaseTask import BaseTask
2 | from .ConvFinQATask import ConvFinQA
3 | from .FinDERTask import FinDER
4 | from .FinQABenchTask import FinQABench
5 | from .FinQATask import FinQA
6 | from .FinanceBenchTask import FinanceBench
7 | from .MultiHierttTask import MultiHiertt
8 | from .TATQATask import TATQA
9 | from .TaskMetadata import TaskMetadata
10 |
--------------------------------------------------------------------------------
/activate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ ! -d "venv" ]; then
4 | echo "Creating virtual environment..."
5 | python3 -m venv venv
6 | source venv/bin/activate
7 | pip install -r requirements.txt
8 | else
9 | source venv/bin/activate
10 | fi
11 |
12 |
13 | set -a
14 | source .env
15 | set +a
16 |
17 | echo "Virtual environment activated and environment variables set!"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets~=3.0.0
2 | pydantic~=2.9.2
3 | pytrec_eval~=0.5
4 | typing_extensions~=4.12.2
5 | numpy==1.26.4
6 | torch~=2.2.2
7 | torchvision
8 | nltk~=3.9.1
9 | sentence-transformers==3.1.1
10 | transformers==4.45.2
11 | openai~=1.47.1
12 | pandas==2.2.2
13 | tqdm==4.67.1
14 | python-dotenv
15 | scikit-learn
16 | huggingface_hub
17 | rank_bm25
18 | flash-attn
--------------------------------------------------------------------------------
/activate.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 |
3 | set VENV_PATH=C:\Users\shw41\GAR\venv
4 |
5 | if exist "%VENV_PATH%" (
6 | echo Found existing .venv at %VENV_PATH%
7 | call %VENV_PATH%\Scripts\activate.bat
8 | ) else (
9 | echo Virtual environment .venv not found!
10 | echo Creating virtual environment...
11 | python -m venv %VENV_PATH%
12 | call %VENV_PATH%\Scripts\activate.bat
13 | pip install -r requirements.txt
14 | )
15 |
16 | set PYTHONPATH=%PYTHONPATH%;%CD%
17 | set VIRTUAL_ENV=%VENV_PATH%
18 | set PATH=%VIRTUAL_ENV%\Scripts;%PATH%
19 |
20 | echo Virtual environment activated and environment variables set!
--------------------------------------------------------------------------------
/financerag/tasks/FinDERTask.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | from .BaseTask import BaseTask
4 | from .TaskMetadata import TaskMetadata
5 |
6 |
7 | class FinDER(BaseTask):
8 | def __init__(self):
9 | self.metadata: TaskMetadata = TaskMetadata(
10 | name="FinDER",
11 | description="Prepared for competition from Linq",
12 | reference=None,
13 | dataset={
14 | "path": "Linq-AI-Research/FinanceRAG",
15 | "subset": "FinDER",
16 | },
17 | type="RAG",
18 | category="s2p",
19 | modalities=["text"],
20 | date=None,
21 | domains=["Report"],
22 | task_subtypes=[
23 | "Financial retrieval",
24 | "Question answering",
25 | ],
26 | license=None,
27 | annotations_creators="expert-annotated",
28 | dialect=[],
29 | sample_creation="human-generated",
30 | bibtex_citation=None,
31 | )
32 | super().__init__(self.metadata)
33 |
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 JiH00nKw0n
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/financerag/tasks/FinQABenchTask.py:
--------------------------------------------------------------------------------
1 | from .BaseTask import BaseTask
2 | from .TaskMetadata import TaskMetadata
3 |
4 |
5 | class FinQABench(BaseTask):
6 | def __init__(self):
7 | self.metadata: TaskMetadata = TaskMetadata(
8 | name="FinQABench",
9 | description="FinQABench: A New QA Benchmark for Finance applications",
10 | reference="https://huggingface.co/datasets/lighthouzai/finqabench",
11 | dataset={
12 | "path": "Linq-AI-Research/FinanceRAG",
13 | "subset": "FinQABench",
14 | },
15 | type="RAG",
16 | category="s2p",
17 | modalities=["text"],
18 | date=None,
19 | domains=["Report"],
20 | task_subtypes=[
21 | "Financial retrieval",
22 | "Question answering",
23 | ],
24 | license="apache-2.0",
25 | annotations_creators="LM-generated and reviewed",
26 | dialect=[],
27 | sample_creation="LM-generated and verified",
28 | bibtex_citation=None,
29 | )
30 | super().__init__(self.metadata)
31 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GAR
2 |
3 | This project is a RAG-based (Retrieval-Augmented Generation) answer generation and refinement pipeline specifically designed for question answering in the finance domain.
4 |
5 | ## Project Structure
6 |
7 | GAR/
8 | ├── GAR/                      # Hybrid search pipeline and DPO agents
9 | │   ├── Embedder_Finetuning.py
10 | │   ├── hybrid_search.py
11 | │   ├── run_finetuning.py
12 | │   ├── run_hybrid_search.py
13 | │   └── main.py               # Main pipeline execution
14 | ├── financerag/               # Retrieval, reranking, generation, and task definitions
15 | └── requirements.txt          # List of required packages
16 |
17 | ## Pipeline Description
18 | Embedder_Finetuning.py: Fine-tunes the retrieval (embedding) model
19 | hybrid_search.py: Runs hybrid search and finds the optimal alpha for each task
20 | run_finetuning.py: Entry point that runs Embedder_Finetuning.py
21 | run_hybrid_search.py: Entry point that runs hybrid_search.py
22 |
23 | ## Installation
24 |
25 | 1. Clone the repository
26 |    git clone https://github.com/seohyunwoo-0407/GAR.git
27 |    cd GAR
28 |
29 | 2. Create and activate a virtual environment
30 |    python -m venv .venv
31 |    source .venv/bin/activate   # Linux/Mac
32 |    .venv\Scripts\activate      # Windows
33 |
34 | 3. Install required packages
35 |    pip install -r requirements.txt
36 |
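37 | ## Running the Pipeline
38 |
39 | A minimal sketch of running the two entry-point scripts. It assumes the repository root is on PYTHONPATH (activate.bat sets this) and that API keys such as OPENAI_API_KEY and HUGGINGFACE_API_KEY are available via .env:
40 |
41 | 1. Fine-tune the embedding model
42 |    python GAR/run_finetuning.py
43 |
44 | 2. Run hybrid search with cross-encoder reranking (per-task results are saved and then merged)
45 |    python GAR/run_hybrid_search.py
46 |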
--------------------------------------------------------------------------------
/GAR/run_hybrid_search.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | import nltk
5 | from financerag.tasks import FinDER, FinQABench, FinanceBench, TATQA, FinQA, ConvFinQA, MultiHiertt
6 | from financerag.retrieval import DenseRetrieval, SentenceTransformerEncoder
7 | from sentence_transformers import SentenceTransformer, CrossEncoder
8 | import pandas as pd
9 | from hybrid_search import HybridSearcher
10 |
11 | logging.basicConfig(level=logging.INFO)
12 |
13 | nltk.download('punkt')
14 | hybrid_searcher = HybridSearcher()
15 |
16 | tasks, names = hybrid_searcher.setup_task()
17 |
18 | retrieval_model = hybrid_searcher.retrieval_model_setup()
19 | reranker_model = hybrid_searcher.reranker_model_setup()
20 |
21 | output_dir = "C:/Users/shw41/GAR/data/task1"
22 |
23 | for task, name in zip(tasks, names):
24 | optimal_alpha = hybrid_searcher.tune_alpha(task, retrieval_model)
25 | hybrid_retrieval_results = hybrid_searcher.get_hybrid_score(task, optimal_alpha, retrieval_model)
26 | reranked_results = hybrid_searcher.get_reranker_score(task, hybrid_retrieval_results, reranker_model)
27 | ndcg_score = hybrid_searcher.get_final_ndcg(tasks, names)
28 |
29 | print(f"Task: {name}, NDCG Score: {ndcg_score}")
30 | print(f"Optimal Alpha: {optimal_alpha}")
31 | task.save_results(output_dir, name, reranked_results)
32 |
33 | hybrid_searcher.merge_csv_results(output_dir)
34 |
--------------------------------------------------------------------------------
/GAR/Selection_agent.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import os
3 | from dotenv import load_dotenv
4 | import json
5 | load_dotenv()
6 | client = openai.OpenAI(
7 | api_key = os.getenv('OPENAI_API_KEY')
8 | )
9 | model_name = 'gpt-4o-mini'
10 | def load_jsonl(path):
11 | with open(path, 'r', encoding='utf-8') as file:
12 | data = [json.loads(line) for line in file]
13 | return data
14 | def save_to_json(data, path):
15 | with open(path, 'w', encoding='utf-8') as file:
16 | for i in data:
17 | file.write(json.dumps(i, ensure_ascii=False) + '\n')
18 | def evaluate_corpus_relevance(query_line, corpus_list):
19 | prompt = f"""Given the following query and 10 corpus passages, determine if each corpus is relevant for answering the query.
20 | Respond ONLY with 10 letters (T or F) in sequence, without any additional text.
21 | T means the corpus is relevant, F means it is not relevant.
22 | Query: {query_line}
23 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)])
24 | response = client.chat.completions.create(
25 | model=model_name,
26 | messages=[{"role": "user", "content": prompt}],
27 | temperature=0
28 | )
29 | relevance = dict(response)['choices'][0].message.content
30 | selected_corpus = [corpus for corpus, is_relevant in zip(corpus_list, relevance) if is_relevant == 'T']
31 | combined_corpus = "\n".join([f"###DOCUMENT_{i+1}\n{corpus}\n" for i, corpus in enumerate(selected_corpus)])
32 | print(f"relevance: {relevance}")
33 | return combined_corpus
--------------------------------------------------------------------------------
/financerag/tasks/FinQATask.py:
--------------------------------------------------------------------------------
1 | from .BaseTask import BaseTask
2 | from .TaskMetadata import TaskMetadata
3 |
4 |
5 | class FinQA(BaseTask):
6 | def __init__(self):
7 | self.metadata: TaskMetadata = TaskMetadata(
8 | name="FinQA",
9 | description="FinQA: A Dataset of Numerical Reasoning over Financial Data",
10 | reference="https://github.com/czyssrs/FinQA",
11 | dataset={
12 | "path": "Linq-AI-Research/FinanceRAG",
13 | "subset": "FinQA",
14 | },
15 | type="RAG",
16 | category="s2p",
17 | modalities=["text"],
18 | date=None,
19 | domains=["Report"],
20 | task_subtypes=[
21 | "Financial retrieval",
22 | "Question answering",
23 | ],
24 | license="mit",
25 | annotations_creators="expert-annotated",
26 | dialect=[],
27 | sample_creation="human-generated",
28 | bibtex_citation="""
29 | @article{chen2021finqa,
30 | title={FinQA: A Dataset of Numerical Reasoning over Financial Data},
31 | author={Chen, Zhiyu and Chen, Wenhu and Smiley, Charese and Shah, Sameena and Borova, Iana and Langdon, Dylan and Moussa, Reema and Beane, Matt and Huang, Ting-Hao and Routledge, Bryan and Wang, William Yang},
32 | journal={Proceedings of EMNLP 2021},
33 | year={2021}
34 | }
35 | """,
36 | )
37 | super().__init__(self.metadata)
38 |
--------------------------------------------------------------------------------
/financerag/tasks/ConvFinQATask.py:
--------------------------------------------------------------------------------
1 | from .BaseTask import BaseTask
2 | from .TaskMetadata import TaskMetadata
3 |
4 |
5 | class ConvFinQA(BaseTask):
6 | def __init__(self):
7 | self.metadata: TaskMetadata = TaskMetadata(
8 | name="ConvFinQA",
9 | description="ConvFinQA: Exploring the Chain of Numerical Reasoning in Conversational Finance Question Answering",
10 | reference="https://github.com/czyssrs/ConvFinQA",
11 | dataset={
12 | "path": "Linq-AI-Research/FinanceRAG",
13 | "subset": "ConvFinQA",
14 | },
15 | type="RAG",
16 | category="s2p",
17 | modalities=["text"],
18 | date=None,
19 | domains=["Report"],
20 | task_subtypes=[
21 | "Financial retrieval",
22 | "Question answering",
23 | ],
24 | license="mit",
25 | annotations_creators="expert-annotated",
26 | dialect=[],
27 | sample_creation="human-generated",
28 | bibtex_citation="""
29 | @article{chen2022convfinqa,
30 | title={ConvFinQA: Exploring the Chain of Numerical Reasoning in Conversational Finance Question Answering},
31 | author={Chen, Zhiyu and Li, Shiyang and Smiley, Charese and Ma, Zhiqiang and Shah, Sameena and Wang, William Yang},
32 | journal={Proceedings of EMNLP 2022},
33 | year={2022}
34 | }
35 | """,
36 | )
37 | super().__init__(self.metadata)
38 |
--------------------------------------------------------------------------------
/financerag/tasks/FinanceBenchTask.py:
--------------------------------------------------------------------------------
1 | from .BaseTask import BaseTask
2 | from .TaskMetadata import TaskMetadata
3 |
4 |
5 | class FinanceBench(BaseTask):
6 | def __init__(self):
7 | self.metadata: TaskMetadata = TaskMetadata(
8 | name="FinanceBench",
9 | description="FinanceBench: A New Benchmark for Financial Question Answering",
10 | reference="https://github.com/patronus-ai/financebench",
11 | dataset={
12 | "path": "Linq-AI-Research/FinanceRAG",
13 | "subset": "FinanceBench",
14 | },
15 | type="RAG",
16 | category="s2p",
17 | modalities=["text"],
18 | date=None,
19 | domains=["Report"],
20 | task_subtypes=[
21 | "Financial retrieval",
22 | "Question answering",
23 | ],
24 | license=None,
25 | annotations_creators="expert-annotated",
26 | dialect=[],
27 | sample_creation="human-generated",
28 | bibtex_citation="""
29 | @misc{islam2023financebench,
30 | title={FinanceBench: A New Benchmark for Financial Question Answering},
31 | author={Pranab Islam and Anand Kannappan and Douwe Kiela and Rebecca Qian and Nino Scherrer and Bertie Vidgen},
32 | year={2023},
33 | eprint={2311.11944},
34 | archivePrefix={arXiv},
35 | primaryClass={cs.CL}
36 | }
37 | """,
38 | )
39 | super().__init__(metadata=self.metadata)
--------------------------------------------------------------------------------
/GAR/main.py:
--------------------------------------------------------------------------------
1 | # main.py
2 | import os
3 | import pandas as pd
4 | from Selection_agent import SelectionAgent
5 | from DPO_agent import DPO_Agent
6 | from DPO_agent_Finetuning import dpo_agent_finetuning
7 | from Finetuned_DPO_agent import Finetuned_DPO_agent
8 | from openai import OpenAI
9 | def ensure_dir(file_path):
10 | directory = os.path.dirname(file_path)
11 | if not os.path.exists(directory):
12 | os.makedirs(directory)
13 | def main():
14 | API_KEY = os.getenv('OPENAI_API_KEY')
15 |
16 | dataset_path = 'data/raw/dataset.csv'
17 | selected_docs_path = 'data/processed/selected_docs.csv'
18 | dpo_responses_path = 'data/processed/dpo_responses.csv'
19 | final_answers_path = 'data/processed/final_answers.csv'
20 |
21 | # 1. Selection Agent
22 | selection_agent = SelectionAgent(API_KEY)
23 | df = pd.read_csv(dataset_path)
24 | selection_agent.process_data(df, selected_docs_path)
25 |
26 | # 2. DPO Agent
27 | dpo_agent = DPO_Agent(API_KEY)
28 | selected_df = pd.read_csv(selected_docs_path)
29 | dpo_agent.process_data(selected_df, dpo_responses_path)
30 |
31 | # 3. DPO Finetuning
32 | finetuning_agent = dpo_agent_finetuning(API_KEY)
33 | train_path = 'data/models/fine_tuned/train.jsonl'
34 | eval_path = 'data/models/fine_tuned/eval.jsonl'
35 | finetuning_agent.save_from_csv_to_jsonl(dpo_responses_path, selected_docs_path, train_path)
36 | finetuning_agent.split_jsonl(train_path, train_path, eval_path)
37 |
38 | # 4. Finetuned DPO Agent
39 | finetuned_agent = Finetuned_DPO_agent(API_KEY)
40 | finetuned_agent.process_finetuning(df, final_answers_path)
41 |
42 | if __name__ == '__main__':
43 | main()
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
18 |
19 |
20 |
31 |
32 |
33 |
39 |
40 |
41 |
--------------------------------------------------------------------------------
/financerag/tasks/MultiHierttTask.py:
--------------------------------------------------------------------------------
1 | from .BaseTask import BaseTask
2 | from .TaskMetadata import TaskMetadata
3 |
4 |
5 | class MultiHiertt(BaseTask):
6 | def __init__(self):
7 | self.metadata: TaskMetadata = TaskMetadata(
8 | name="MultiHiertt",
9 | description="MultiHiertt: Numerical Reasoning over Multi Hierarchical Tabular and Textual Data",
10 | reference="https://github.com/psunlpgroup/MultiHiertt",
11 | dataset={
12 | "path": "Linq-AI-Research/FinanceRAG",
13 | "subset": "MultiHiertt",
14 | },
15 | type="RAG",
16 | category="s2p",
17 | modalities=["text"],
18 | date=None,
19 | domains=["Report"],
20 | task_subtypes=[
21 | "Financial retrieval",
22 | "Question answering",
23 | ],
24 | license="mit",
25 | annotations_creators="expert-annotated",
26 | dialect=[],
27 | sample_creation="human-generated",
28 | bibtex_citation="""
29 | @inproceedings{zhao-etal-2022-multihiertt,
30 | title = "{M}ulti{H}iertt: Numerical Reasoning over Multi Hierarchical Tabular and Textual Data",
31 | author = "Zhao, Yilun and
32 | Li, Yunxiang and
33 | Li, Chenying and
34 | Zhang, Rui",
35 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
36 | month = may,
37 | year = "2022",
38 | address = "Dublin, Ireland",
39 | publisher = "Association for Computational Linguistics",
40 | url = "https://aclanthology.org/2022.acl-long.454",
41 | pages = "6588--6600",
42 | }
43 | """,
44 | )
45 | super().__init__(self.metadata)
--------------------------------------------------------------------------------
/GAR/run_finetuning.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dotenv import load_dotenv
3 | from sentence_transformers import SentenceTransformer
4 | from torch.utils.data import DataLoader
5 | from Embedder_Finetuning import EmbedderFinetuner
6 |
7 | def run_finetuning():
8 | load_dotenv()
9 |
10 |
11 |     names = ['FinDER', 'FinQABench', 'FinanceBench', 'FinQA', 'TATQA', 'ConvFinQA', 'MultiHiertt']
12 | markdown_dir = "Markdown"
13 | output_path = "./fine-tune-0203"
14 | hf_token = os.getenv("HUGGINGFACE_API_KEY")
15 | repo_id = os.getenv("repo_id")
16 | repo_owner = os.getenv("repo_owner")
17 | epochs = 2
18 | learning_rate = 2e-5
19 | batch_size = 4
20 | warmup_ratio = 0.1
21 |
22 |
23 | finetuner = EmbedderFinetuner()
24 | corpus_df, queries_df, relevant_docs_data, all_data = finetuner.load_datasets(names, markdown_dir)
25 | train_data, val_data, train_rel, val_rel = finetuner.split_train_val(all_data, relevant_docs_data)
26 |
27 |
28 | train_samples = finetuner.create_train_samples(train_data)
29 | print(f"Training samples: {len(train_samples)}")
30 | eval_examples = finetuner.create_eval_examples(val_data)
31 | eval_loader = DataLoader(eval_examples, shuffle=False, batch_size=batch_size)
32 |
33 |
34 | corpus_dict, queries_dict = finetuner.prepare_corpus_queries(corpus_df, queries_df, val_rel)
35 | relevant_docs = finetuner.build_relevant_docs(val_rel)
36 | ir_evaluator = finetuner.create_ir_evaluator(queries_dict, corpus_dict, relevant_docs, batch_size=batch_size)
37 |
38 |
39 | model = SentenceTransformer(
40 | 'NovaSearch/stella_en_1.5B_v5',
41 | trust_remote_code=True,
42 | config_kwargs={"use_memory_efficient_attention": True, "unpad_inputs": False}
43 | )
44 |
45 |
46 | print("Evaluating before fine-tuning:")
47 | ir_evaluator(model)
48 |
49 |
50 | finetuner.train_model(model, train_samples, ir_evaluator, output_path, epochs, learning_rate, warmup_ratio, batch_size)
51 |
52 |
53 | finetuner.upload_model_to_hub(output_path, repo_id, hf_token, repo_owner)
54 |
55 |
56 | finetuner.clear_gpu_memory()
57 |
58 | if __name__ == "__main__":
59 | run_finetuning()
--------------------------------------------------------------------------------
/financerag/tasks/TATQATask.py:
--------------------------------------------------------------------------------
1 | from .BaseTask import BaseTask
2 | from .TaskMetadata import TaskMetadata
3 |
4 |
5 | class TATQA(BaseTask):
6 | def __init__(self):
7 | self.metadata: TaskMetadata = TaskMetadata(
8 | name="TAT-QA",
9 | description="TAT-QA: A Question Answering Benchmark on a Hybrid of Tabular and Textual Content in Finance",
10 | reference="https://github.com/NExTplusplus/TAT-QA",
11 | dataset={
12 | "path": "Linq-AI-Research/FinanceRAG",
13 | "subset": "TATQA",
14 | },
15 | type="RAG",
16 | category="s2p",
17 | modalities=["text"],
18 | date=None,
19 | domains=["Report"],
20 | task_subtypes=[
21 | "Financial retrieval",
22 | "Question answering",
23 | ],
24 | license="mit",
25 | annotations_creators="human-annotated",
26 | dialect=[],
27 | sample_creation="human-generated",
28 | bibtex_citation="""
29 | @inproceedings{zhu-etal-2021-tat,
30 | title = "{TAT}-{QA}: A Question Answering Benchmark on a Hybrid of Tabular and Textual Content in Finance",
31 | author = "Zhu, Fengbin and
32 | Lei, Wenqiang and
33 | Huang, Youcheng and
34 | Wang, Chao and
35 | Zhang, Shuo and
36 | Lv, Jiancheng and
37 | Feng, Fuli and
38 | Chua, Tat-Seng",
39 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
40 | month = aug,
41 | year = "2021",
42 | address = "Online",
43 | publisher = "Association for Computational Linguistics",
44 | url = "https://aclanthology.org/2021.acl-long.254",
45 | doi = "10.18653/v1/2021.acl-long.254",
46 | pages = "3277--3287"
47 | }
48 | """,
49 | )
50 | super().__init__(self.metadata)
51 |
--------------------------------------------------------------------------------
/financerag/retrieval/sent_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Literal, Optional, Tuple, Union
2 |
3 | import numpy as np
4 | from sentence_transformers import SentenceTransformer
5 | from torch import Tensor
6 |
7 | from financerag.common import Encoder
8 |
9 |
10 | # Adapted from https://github.com/beir-cellar/beir/blob/main/beir/retrieval/models/sentence_bert.py
11 | class SentenceTransformerEncoder(Encoder):
12 |
13 | def __init__(
14 | self,
15 | model_name_or_path: Union[str, Tuple[str, str]],
16 | query_prompt: Optional[str] = None,
17 | doc_prompt: Optional[str] = None,
18 | **kwargs
19 | ):
20 | if isinstance(model_name_or_path, str):
21 | self.q_model = SentenceTransformer(model_name_or_path, **kwargs)
22 | self.doc_model = self.q_model
23 | elif isinstance(model_name_or_path, Tuple):
24 | self.q_model = SentenceTransformer(model_name_or_path[0], **kwargs)
25 | self.doc_model = SentenceTransformer(model_name_or_path[1], **kwargs)
26 | else:
27 | raise TypeError
28 | self.query_prompt = query_prompt
29 | self.doc_prompt = doc_prompt
30 |
31 | def encode_queries(
32 | self, queries: List[str], batch_size: int = 16, **kwargs
33 | ) -> Union[np.ndarray, Tensor]:
34 | if self.query_prompt is not None:
35 | queries = [self.query_prompt + query for query in queries]
36 | return self.q_model.encode(queries, batch_size=batch_size, **kwargs)
37 |
38 | def encode_corpus(
39 | self,
40 | corpus: Union[
41 | List[Dict[Literal["title", "text"], str]],
42 | Dict[Literal["title", "text"], List],
43 | ],
44 | batch_size: int = 8,
45 | **kwargs
46 | ) -> Union[np.ndarray, Tensor]:
47 | if isinstance(corpus, dict):
48 | sentences = [
49 | (
50 | (corpus["title"][i] + " " + corpus["text"][i]).strip()
51 | if "title" in corpus
52 | else corpus["text"][i].strip()
53 | )
54 | for i in range(len(corpus["text"]))
55 | ]
56 | else:
57 | sentences = [
58 | (
59 | (doc["title"] + " " + doc["text"]).strip()
60 | if "title" in doc
61 | else doc["text"].strip()
62 | )
63 | for doc in corpus
64 | ]
65 | if self.doc_prompt is not None:
66 | sentences = [self.doc_prompt + s for s in sentences]
67 | return self.doc_model.encode(sentences, batch_size=batch_size, **kwargs)
68 |
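69 |
70 | if __name__ == "__main__":
71 |     # Minimal illustrative sketch (not exercised by the pipeline). The checkpoint below is an
72 |     # example model, not the one this project fine-tunes.
73 |     encoder = SentenceTransformerEncoder("sentence-transformers/all-MiniLM-L6-v2")
74 |     query_emb = encoder.encode_queries(["What total revenue did Apple report?"], batch_size=2)
75 |     doc_emb = encoder.encode_corpus(
76 |         [{"title": "AAPL 10-K", "text": "Apple reported total revenue of 383 billion dollars."}],
77 |         batch_size=2,
78 |     )
79 |     print(query_emb.shape, doc_emb.shape)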
--------------------------------------------------------------------------------
/GAR/Finetuned_DPO_agent.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import pandas as pd
3 | import numpy as np
4 | from tqdm import tqdm
5 | import json
6 | import random
7 | import dotenv
8 | import os
9 | dotenv.load_dotenv()
10 | def save_from_csv_to_jsonl(csv_file, jsonl_file):
11 | df = pd.read_csv(csv_file)
12 | with open(jsonl_file, 'w', encoding='utf-8') as f:
13 | for index, row in tqdm(df.iterrows(), total=df.shape[0]):
14 | json_obj = {
15 | "input": {
16 | "messages": [
17 | {"role": "user", "content": row['query']}
18 | ],
19 | "tools": [],
20 |                     "parallel_tool_calls": True
21 | },
22 | "preferred_output": [
23 | {"role": "assistant", "content": row['response1']}
24 | ],
25 | "non_preferred_output": [
26 | {"role": "assistant", "content": row['response2']}
27 | ]
28 | }
29 | f.write(json.dumps(json_obj, ensure_ascii=False) + '\n')
30 | path = "/content/drive/MyDrive/FinanceRAG-GAR/본선/data_files/dpo_data_to67.csv"
31 | saving_path = "/content/drive/MyDrive/FinanceRAG-GAR/본선/data_files/dpo_data_to67.jsonl"
32 | save_from_csv_to_jsonl(path, saving_path)
33 | """
34 | query | response1 | response2 |
35 | {
36 | "input": {
37 | "messages": [
38 | {
39 | "role": "user",
40 | "content": "Hello, can you tell me how cold San Francisco is today?"
41 | }
42 | ],
43 | "tools": [],
44 | "parallel_tool_calls": true
45 | },
46 | "preferred_output": [
47 | {
48 | "role": "assistant",
49 | "content": "Today in San Francisco, it is not quite cold as expected. Morning clouds will give away to sunshine, with a high near 68°F (20°C) and a low around 57°F (14°C)."
50 | }
51 | ],
52 | "non_preferred_output": [
53 | {
54 | "role": "assistant",
55 | "content": "It is not particularly cold in San Francisco today."
56 | }
57 | ]
58 | }
59 | ...
60 | """
61 | model_name = 'gpt-4o-mini' # 'gpt-4o'
62 | def load_jsonl(path):
63 | with open(path, 'r', encoding='utf-8') as file:
64 | data = [json.loads(line) for line in file]
65 | return data
66 | training_data = load_jsonl(saving_path)  # loaded for a quick sanity check of the JSONL records
67 | client = openai.OpenAI(
68 |     api_key = os.getenv('OPENAI_API_KEY')
69 | )
70 | # The fine-tuning API expects an uploaded file ID, not the raw records, so upload the JSONL first.
71 | training_file = client.files.create(file=open(saving_path, "rb"), purpose="fine-tune")
72 | response = client.fine_tuning.jobs.create(training_file=training_file.id,
72 | model=model_name,
73 | method={
74 | "type": "dpo",
75 | "dpo": {
76 | "hyperparameters": {"beta": 0.1}
77 | }
78 | }
79 | )
80 | print(response)
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/financerag/retrieval/bm25.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any, Callable, Dict, List, Literal, Optional
3 |
4 | import numpy as np
5 | from nltk.tokenize import word_tokenize
6 |
7 | from financerag.common import Lexical, Retrieval
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def tokenize_list(input_list: List[str]) -> List[List[str]]:
13 | """
14 | Tokenizes a list of strings using the `nltk.word_tokenize` function.
15 |
16 | Args:
17 | input_list (`List[str]`):
18 | A list of input strings to be tokenized.
19 |
20 | Returns:
21 | `List[List[str]]`:
22 | A list where each element is a list of tokens corresponding to an input string.
23 | """
24 | return list(map(word_tokenize, input_list))
25 |
26 |
27 | class BM25Retriever(Retrieval):
28 | """
29 | A retrieval class that utilizes a lexical model (e.g., BM25) to search for the most relevant documents
30 | from a given corpus based on the input queries. This retriever tokenizes the queries and uses the provided
31 | lexical model to compute relevance scores between the queries and documents in the corpus.
32 |
33 | Methods:
34 | - retrieve: Searches for relevant documents based on the given queries, returning the top-k results.
35 | """
36 |
37 | def __init__(self, model: Lexical, tokenizer: Callable[[List[str]], List[List[str]]] = tokenize_list):
38 | """
39 | Initializes the `BM25Retriever` class with a lexical model and a tokenizer function.
40 |
41 | Args:
42 | model (`Lexical`):
43 | A lexical model (e.g., BM25) implementing the `Lexical` protocol, responsible for calculating relevance scores.
44 | tokenizer (`Callable[[List[str]], List[List[str]]]`, *optional*):
45 | A function that tokenizes the input queries. Defaults to `tokenize_list`, which uses `nltk.word_tokenize`.
46 | """
47 | self.model: Lexical = model
48 | self.tokenizer: Callable[[List[str]], List[List[str]]] = tokenizer
49 | self.results: Optional[Dict[str, Any]] = {}
50 |
51 | def retrieve(
52 | self,
53 | corpus: Dict[str, Dict[Literal["title", "text"], str]],
54 | queries: Dict[str, str],
55 | top_k: Optional[int] = None,
56 | score_function: Optional[str] = None,
57 | return_sorted: bool = False,
58 | **kwargs
59 | ) -> Dict[str, Dict[str, float]]:
60 | """
61 | Searches the corpus for the most relevant documents based on the given queries. The retrieval process involves
62 | tokenizing the queries, calculating relevance scores using the lexical model, and returning the top-k results
63 | for each query.
64 |
65 | Args:
66 | corpus (`Dict[str, Dict[Literal["title", "text"], str]]`):
67 | A dictionary representing the corpus, where each key is a document ID, and each value is another dictionary
68 | containing document fields such as 'id', 'title', and 'text'.
69 | queries (`Dict[str, str]`):
70 | A dictionary containing query IDs and corresponding query texts.
71 | top_k (`Optional[int]`, *optional*):
72 | The number of top documents to return for each query. If not provided, all documents are returned. Defaults to `None`.
73 | return_sorted (`bool`, *optional*):
74 | Whether to return the results sorted by score. Defaults to `False`.
75 | **kwargs:
76 | Additional keyword arguments passed to the lexical model during scoring.
77 |
78 | Returns:
79 | `Dict[str, Dict[str, float]]`:
80 | A dictionary where each key is a query ID, and the value is another dictionary mapping document IDs to relevance scores.
81 | """
82 | query_ids = list(queries.keys())
83 | self.results = {qid: {} for qid in query_ids}
84 |
85 | logger.info("Tokenizing queries with lower cases")
86 | query_lower_tokens = self.tokenizer([queries[qid].lower() for qid in queries])
87 |
88 | corpus_ids = list(corpus.keys())
89 |
90 | for qid, query in zip(query_ids, query_lower_tokens):
91 | scores = self.model.get_scores(query)
92 | top_k_result = np.argsort(scores)[::-1][:top_k]
93 | for idx in top_k_result:
94 | self.results[qid][corpus_ids[idx]] = scores[idx]
95 |
96 | return self.results
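97 |
98 |
99 | if __name__ == "__main__":
100 |     # Minimal illustrative sketch (not exercised by the pipeline): build a BM25 model with
101 |     # `rank_bm25` over a toy corpus and retrieve the top documents for a single query.
102 |     import nltk
103 |     from rank_bm25 import BM25Okapi
104 |
105 |     nltk.download("punkt")      # tokenizer data used by nltk.word_tokenize
106 |     nltk.download("punkt_tab")  # required by newer NLTK releases
107 |
108 |     corpus = {
109 |         "d1": {"title": "AAPL 10-K", "text": "Apple reported total revenue of 383 billion dollars."},
110 |         "d2": {"title": "MSFT 10-K", "text": "Microsoft reported total revenue of 212 billion dollars."},
111 |     }
112 |     queries = {"q1": "What total revenue did Apple report?"}
113 |
114 |     tokenized_corpus = tokenize_list(
115 |         [(doc["title"] + " " + doc["text"]).lower() for doc in corpus.values()]
116 |     )
117 |     retriever = BM25Retriever(model=BM25Okapi(tokenized_corpus))
118 |     print(retriever.retrieve(corpus=corpus, queries=queries, top_k=2))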
--------------------------------------------------------------------------------
/GAR/DPO_agent.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | from openai import OpenAI
3 | import os
4 | import json
5 | import numpy as np
6 | import openai
7 | import pandas as pd
8 | from dotenv import load_dotenv
9 | load_dotenv()
10 | class DPOTrainer:
11 | def __init__(self, model_name):
12 | load_dotenv()
13 | self.model_name = model_name
14 | self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
15 | def generate_output(self, system_prompt, user_prompt, temp=0, returnRaw=False):
16 | response = self.client.chat.completions.create(
17 | model=self.model_name,
18 | messages=[
19 | {"role": "system", "content": system_prompt},
20 | {"role": "user", "content": user_prompt}
21 | ],
22 | temperature=temp
23 | )
24 | text = response.choices[0].message.content
25 | return response if returnRaw else text
26 | def evaluate_corpus_relevance(self, query_line, corpus_list):
27 | relevancy_instruction = """
28 | You are an expert financial advisor and evaluator focused on improving responses.
29 | Your task is to enhance answers based on detailed evaluation scores while:
30 | - Maintaining factual accuracy with the provided documents
31 | - Ensuring responses are clear and well-structured for financial contexts
32 | - Providing comprehensive answers that address all aspects of the query
33 | - Using professional financial terminology appropriately
34 | You are given the pair of Query, Corpus (same query)
35 | Out of the 10 documents, only provide the list of indices of those that are RELEVANT (e.g. the content is somehow needed to answer the question), from 0~9.
36 | Example : [0, 2, 8, 9]
37 | """
38 | user_prompt = f"""
39 | Query: {query_line}
40 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)])
41 | try:
42 | relevancy = eval(self.generate_output(relevancy_instruction, user_prompt, temp=0))
43 | except Exception as error:
44 | print(error)
45 | relevancy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
46 | if relevancy == []:
47 | relevancy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
48 | return relevancy
49 | def answer_query(self, query_line, corpus_list, temp=0, returnRaw=False):
50 | financial_prompt = """
51 | You are an expert financial advisor and evaluator focused on improving responses.
52 | Your task is to enhance answers based on detailed evaluation scores while:
53 | - Maintaining factual accuracy with the provided documents
54 | - Ensuring responses are clear and well-structured for financial contexts
55 | - Providing comprehensive answers that address all aspects of the query
56 | - Using professional financial terminology appropriately
57 | """
58 | query_prompt = f"""
59 | Query: {query_line}
60 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)])
61 | query_prompt += "\nDo not add any closing pleasantries or phrases like 'please feel free to ask!'"
62 | response = self.generate_output(financial_prompt, query_prompt, temp=temp, returnRaw=returnRaw)
63 | return response
64 | def process_data(self, input_csv_path, output_csv_path):
65 | df = pd.read_csv(input_csv_path)
66 | dpo_data = pd.DataFrame({'query': [], 'response1': [], 'response2': []})
67 | for i in tqdm(range(0, len(df), 10)):
68 | chunk = df[i:i+10]
69 | query_line = chunk['Query'].iloc[0]
70 | corpus_list = chunk['Corpus'].tolist()
71 | corpus_rel = self.evaluate_corpus_relevance(query_line, corpus_list)
72 | try:
73 | relevant_corpus = chunk.iloc[corpus_rel]
74 | except:
75 | relevant_corpus = chunk
76 | relevant_corpus = relevant_corpus['Corpus'].tolist()
77 | response1 = self.answer_query(query_line, relevant_corpus, temp=0)
78 | response2 = self.answer_query(query_line, relevant_corpus, temp=0.5)
79 | dpo_data.loc[i] = [query_line, response1, response2]
80 | dpo_data.to_csv(output_csv_path)
81 | return dpo_data
82 | if __name__ == "__main__":
83 | model_name = "gpt-4o-mini" # or "ft:gpt-4o-mini-2024-07-18:personal::AkFnSqzI"
84 | trainer = DPOTrainer(model_name)
85 |
86 | input_path = '/content/drive/MyDrive/FinanceRAG-GAR/본선/data_files/sampled_data.csv'
87 | output_path = '/content/drive/MyDrive/FinanceRAG-GAR/본선/data_files/sampled_answer_4o.csv'
88 |
89 | trainer.process_data(input_path, output_path)
--------------------------------------------------------------------------------
/financerag/generate/openai.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import multiprocessing
3 | import os
4 | from multiprocessing import Pool
5 | from typing import Any, Dict, List, Tuple, cast
6 |
7 | import openai
8 | from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
9 |
10 | from financerag.common.protocols import Generator
11 |
12 | openai.api_key = os.getenv("OPENAI_API_KEY")
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | class OpenAIGenerator(Generator):
18 | """
19 | A class that interfaces with the OpenAI API to generate responses using a specified model. It implements the
20 | `Generator` protocol and supports generating responses in parallel using multiple processes.
21 |
22 | Args:
23 | model_name (`str`):
24 | The name of the OpenAI model to use for generating responses (e.g., "gpt-4", "gpt-3.5-turbo").
25 | """
26 |
27 | def __init__(self, model_name: str):
28 | """
29 | Initializes the OpenAIGenerator with the specified model name.
30 |
31 | Args:
32 | model_name (`str`):
33 | The OpenAI model name used to generate responses.
34 | """
35 | self.model_name: str = model_name
36 | self.results: Dict = {}
37 |
38 | def _process_query(
39 | self, args: Tuple[str, List[ChatCompletionMessageParam], Dict[str, Any]]
40 | ) -> Tuple[str, str]:
41 | """
42 | Internal method to process a single query using the OpenAI model. It sends the query and messages to the
43 | OpenAI API and retrieves the response.
44 |
45 | Args:
46 | args (`Tuple[str, List[ChatCompletionMessageParam], Dict[str, Any]]`):
47 | Contains the query ID, a list of messages (query), and additional arguments for the model.
48 |
49 | Returns:
50 | `Tuple[str, str]`:
51 | A tuple containing the query ID and the generated response.
52 | """
53 | q_id, messages, kwargs = args
54 | temperature = kwargs.pop("temperature", 1.0)
55 | top_p = kwargs.pop("top_p", 1.0)
56 | stream = kwargs.pop("stream", False)
57 | max_tokens = kwargs.pop("max_tokens", 10000)
58 | presence_penalty = kwargs.pop("presence_penalty", 0.0)
59 | frequency_penalty = kwargs.pop("frequency_penalty", 0.0)
60 |
61 | client = openai.OpenAI()
62 | response = client.chat.completions.create(
63 | model=self.model_name,
64 | messages=messages,
65 | temperature=temperature,
66 | top_p=top_p,
67 | stream=stream,
68 | max_tokens=max_tokens,
69 | presence_penalty=presence_penalty,
70 | frequency_penalty=frequency_penalty,
71 | )
72 | return q_id, response.choices[0].message.content
73 |
74 | def generation(
75 | self,
76 | messages: Dict[str, List[Dict[str, str]]],
77 | num_processes: int = multiprocessing.cpu_count(), # Number of parallel processes
78 | **kwargs,
79 | ) -> Dict[str, str]:
80 | """
81 | Generate responses for the given messages using the OpenAI model. This method supports parallel processing
82 | using multiprocessing to speed up the generation process for multiple queries.
83 |
84 | Args:
85 | messages (`Dict[str, List[Dict[str, str]]]`):
86 | A dictionary where the keys are query IDs, and the values are lists of dictionaries representing the
87 | messages (queries).
88 | num_processes (`int`, *optional*, defaults to `multiprocessing.cpu_count()`):
89 | The number of processes to use for parallel generation.
90 | **kwargs:
91 | Additional keyword arguments for the OpenAI model (e.g., temperature, top_p, max_tokens).
92 |
93 | Returns:
94 | `Dict[str, str]`:
95 | A dictionary where each key is a query ID, and the value is the generated response.
96 | """
97 | logger.info(
98 | f"Starting generation for {len(messages)} queries using {num_processes} processes..."
99 | )
100 |
101 | # Prepare arguments for multiprocessing
102 | query_args = [
103 | (q_id, cast(list[ChatCompletionMessageParam], msg), kwargs.copy())
104 | for q_id, msg in messages.items()
105 | ]
106 |
107 | # Use multiprocessing Pool for parallel generation
108 | with Pool(processes=num_processes) as pool:
109 | results = pool.map(self._process_query, query_args)
110 |
111 | # Collect results
112 | self.results = {q_id: content for q_id, content in results}
113 |
114 | logger.info(
115 | f"Generation completed for all queries. Collected {len(self.results)} results."
116 | )
117 |
118 | return self.results
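119 |
120 |
121 | if __name__ == "__main__":
122 |     # Minimal illustrative sketch (not exercised by the pipeline): requires OPENAI_API_KEY to be
123 |     # set in the environment; the model name and prompt are placeholders.
124 |     generator = OpenAIGenerator(model_name="gpt-4o-mini")
125 |     messages = {
126 |         "q1": [{"role": "user", "content": "In one sentence, what information does a 10-K filing contain?"}]
127 |     }
128 |     print(generator.generation(messages, num_processes=1, temperature=0.2))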
--------------------------------------------------------------------------------
/financerag/rerank/cross_encoder.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Dict, Optional
3 |
4 | from financerag.common import CrossEncoder, Reranker
5 |
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | # Adapted from https://github.com/beir-cellar/beir/blob/main/beir/reranking/rerank.py
10 | class CrossEncoderReranker(Reranker):
11 | """
12 | A reranker class that utilizes a cross-encoder model from the `sentence-transformers` library
13 | to rerank search results based on query-document pairs. This class implements a reranking
14 | mechanism using cross-attention, where each query-document pair is passed through the
15 | cross-encoder model to compute relevance scores.
16 |
17 | The cross-encoder model expects two inputs (query and document) and directly computes a
18 | score indicating the relevance of the document to the query. The model follows the
19 | `CrossEncoder` protocol, ensuring it is compatible with `sentence-transformers` cross-encoder models.
20 |
21 | Methods:
22 | rerank:
23 | Takes in a corpus, queries, and initial retrieval results, and reranks
24 | the top-k documents using the cross-encoder model.
25 | """
26 |
27 | def __init__(self, model: CrossEncoder):
28 | """
29 | Initializes the `CrossEncoderReranker` class with a cross-encoder model.
30 |
31 | Args:
32 | model (`CrossEncoder`):
33 | A cross-encoder model implementing the `CrossEncoder` protocol from the `sentence-transformers` library.
34 | """
35 | self.model: CrossEncoder = model
36 | self.results: Dict[str, Dict[str, float]] = {}
37 |
38 | def rerank(
39 | self,
40 | corpus: Dict[str, Dict[str, str]],
41 | queries: Dict[str, str],
42 | results: Dict[str, Dict[str, float]],
43 | top_k: int,
44 | batch_size: Optional[int] = None,
45 | **kwargs
46 | ) -> Dict[str, Dict[str, float]]:
47 | """
48 | Reranks the top-k documents for each query based on cross-encoder model predictions.
49 |
50 | Args:
51 | corpus (`Dict[str, Dict[str, str]]`):
52 | A dictionary representing the corpus, where each key is a document ID and each value is a dictionary
53 | containing the title and text fields of the document.
54 | queries (`Dict[str, str]`):
55 | A dictionary containing query IDs as keys and the corresponding query texts as values.
56 | results (`Dict[str, Dict[str, float]]`):
57 | A dictionary containing query IDs and the initial retrieval results. Each query ID is mapped to another
58 | dictionary where document IDs are keys and initial retrieval scores are values.
59 | top_k (`int`):
60 | The number of top documents to rerank for each query.
61 | batch_size (`Optional[int]`, *optional*):
62 | The batch size used when passing the query-document pairs through the cross-encoder model.
63 | Defaults to None.
64 | **kwargs:
65 | Additional arguments passed to the cross-encoder model during prediction.
66 |
67 | Returns:
68 | `Dict[str, Dict[str, float]]`:
69 | A dictionary containing query IDs as keys and dictionaries of reranked document IDs and their scores as values.
70 | """
71 | sentence_pairs, pair_ids = [], []
72 |
73 | for query_id in results:
74 | if len(results[query_id]) > top_k:
75 | for doc_id, _ in sorted(
76 | results[query_id].items(), key=lambda item: item[1], reverse=True
77 | )[:top_k]:
78 | pair_ids.append([query_id, doc_id])
79 | corpus_text = (
80 | corpus[doc_id].get("title", "")
81 | + " "
82 | + corpus[doc_id].get("text", "")
83 | ).strip()
84 | sentence_pairs.append([queries[query_id], corpus_text])
85 |
86 | else:
87 | for doc_id in results[query_id]:
88 | pair_ids.append([query_id, doc_id])
89 | corpus_text = (
90 | corpus[doc_id].get("title", "")
91 | + " "
92 | + corpus[doc_id].get("text", "")
93 | ).strip()
94 | sentence_pairs.append([queries[query_id], corpus_text])
95 |
96 | #### Starting to Rerank using cross-attention
97 | logger.info(f"Starting To Rerank Top-{top_k}....")
98 | rerank_scores = [
99 | float(score)
100 | for score in self.model.predict(
101 | sentences=sentence_pairs, batch_size=batch_size, **kwargs
102 | )
103 | ]
104 |
105 | #### Reranker results
106 | self.results = {query_id: {} for query_id in results}
107 | for pair, score in zip(pair_ids, rerank_scores):
108 | query_id, doc_id = pair[0], pair[1]
109 | self.results[query_id][doc_id] = score
110 |
111 | return self.results
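112 |
113 |
114 | if __name__ == "__main__":
115 |     # Minimal illustrative sketch (not exercised by the pipeline): rerank two candidate documents
116 |     # for one query. The cross-encoder checkpoint is an example, not pinned by this project.
117 |     from sentence_transformers import CrossEncoder as STCrossEncoder
118 |
119 |     corpus = {
120 |         "d1": {"title": "AAPL 10-K", "text": "Apple reported total revenue of 383 billion dollars."},
121 |         "d2": {"title": "MSFT 10-K", "text": "Microsoft reported total revenue of 212 billion dollars."},
122 |     }
123 |     queries = {"q1": "What total revenue did Apple report?"}
124 |     first_stage = {"q1": {"d1": 1.2, "d2": 0.4}}  # e.g. BM25 or dense retrieval scores
125 |
126 |     reranker = CrossEncoderReranker(STCrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2"))
127 |     print(reranker.rerank(corpus=corpus, queries=queries, results=first_stage, top_k=2, batch_size=16))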
--------------------------------------------------------------------------------
/financerag/tasks/TaskMetadata.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/abstasks/TaskMetadata.py
2 | from __future__ import annotations
3 |
4 | import logging
5 | from datetime import date
6 | from typing import Any, Dict, Optional, Union
7 |
8 | from pydantic import AnyUrl, BaseModel, BeforeValidator, TypeAdapter, field_validator
9 | from typing_extensions import Annotated, Literal
10 |
11 | TASK_SUBTYPE = Literal[
12 | "Financial retrieval",
13 | "Question answering",
14 | ]
15 |
16 | TASK_DOMAIN = Literal["Report",]
17 |
18 | SAMPLE_CREATION_METHOD = Literal[
19 | "found",
20 | "human-generated",
21 | "LM-generated and verified",
22 | ]
23 |
24 | TASK_TYPE = Literal["RAG",]
25 |
26 | TASK_CATEGORY = Literal["s2p",] # Sentence-to-paragraph
27 |
28 | ANNOTATOR_TYPE = Literal[
29 | "expert-annotated",
30 | "human-annotated",
31 | "derived",
32 | "LM-generated",
33 | "LM-generated and reviewed", # reviewed by humans
34 | ]
35 |
36 | http_url_adapter = TypeAdapter(AnyUrl)
37 | STR_URL = Annotated[
38 | str, BeforeValidator(lambda value: str(http_url_adapter.validate_python(value)))
39 | ] # Allows the type to be a string, but ensures that the string is a URL
40 |
41 | pastdate_adapter = TypeAdapter(date)
42 | STR_DATE = Annotated[
43 | str, BeforeValidator(lambda value: str(pastdate_adapter.validate_python(value)))
44 | ] # Allows the type to be a string, but ensures that the string is a valid date
45 |
46 | SPLIT_NAME = str
47 | HFSubset = str
48 |
49 | LICENSES = ( # this list can be extended as needed
50 | Literal[ # we use lowercase for the licenses similar to the huggingface datasets
51 | "not specified", # or none found
52 | "mit",
53 | "cc-by-2.0",
54 | "cc-by-3.0",
55 | "cc-by-4.0",
56 | "cc-by-sa-3.0",
57 | "cc-by-sa-4.0",
58 | "cc-by-nc-4.0",
59 | "cc-by-nc-sa-3.0",
60 | "cc-by-nc-sa-4.0",
61 | "cc-by-nc-nd-4.0",
62 | "openrail",
63 | "openrail++",
64 | "odc-by",
65 | "afl-3.0",
66 | "apache-2.0",
67 | "cc-by-nd-2.1-jp",
68 | "cc0-1.0",
69 | "bsd-3-clause",
70 | "gpl-3.0",
71 | "cdla-sharing-1.0",
72 | "mpl-2.0",
73 | ]
74 | )
75 |
76 | METRIC_NAME = str
77 | METRIC_VALUE = Union[int, float, Dict[str, Any]]
78 |
79 | logger = logging.getLogger(__name__)
80 |
81 |
82 | class TaskMetadata(BaseModel):
83 | """
84 | Metadata for a task.
85 |
86 | Args:
87 |         dataset: A dictionary containing the arguments to pass to `datasets.load_dataset` to load the dataset for the task. Must include 'path' and 'subset'.
88 | Refer to https://huggingface.co/docs/datasets/v3.0.0/en/package_reference/loading_methods for more details.
89 | name: The name of the task.
90 | description: A description of the task.
91 |         type: (*Optional*) The type of the task, e.g., "RAG". Corresponds to the TASK_TYPE literal.
92 | modalities: The input modality of the dataset. In this case, it is set to ["text"], meaning the dataset deals with textual data.
93 | category: (*Optional*) The category of the task, e.g., "s2p" (sentence-to-paragraph). Corresponds to the TASK_CATEGORY literal.
94 | reference: (*Optional*) A URL to documentation or a published paper about the task. Must be a valid URL.
95 | date: (*Optional*) A tuple containing the start and end dates when the dataset was collected, ensuring the data reflects a certain time frame.
96 | domains: (*Optional*) The domain(s) of the data, e.g., "Report". Defined as TASK_DOMAIN literals.
97 | task_subtypes: (*Optional*) Subtypes of the task, providing more specific details (e.g., "Financial retrieval", "Question answering").
98 | license: (*Optional*) The license under which the dataset is released. Uses a predefined list of licenses (e.g., "cc-by-4.0"), but custom licenses can be provided via URLs.
99 | annotations_creators: (*Optional*) The type of annotators who created or verified the dataset annotations, such as "expert-annotated" or "LM-generated and reviewed".
100 | dialect: (*Optional*) The dialect of the data, if applicable. Ideally specified as a BCP-47 language tag. Empty if no dialects are present.
101 | sample_creation: (*Optional*) The method used to create the dataset samples, such as "found", "human-generated", or "LM-generated and verified".
102 | bibtex_citation: (*Optional*) The BibTeX citation for the dataset. Should be provided if available; otherwise, it is an empty string.
103 |
104 | Methods:
105 |         is_filled: Checks if all required metadata fields are filled in the TaskMetadata instance.
106 |         intext_citation: Generates an in-text citation based on the BibTeX entry provided. If no BibTeX is available, returns an empty string.
107 | 
108 |     Validators:
109 |         _check_dataset_path_is_specified: Ensures that the dataset dictionary contains the 'path' key.
110 |         _check_dataset_subset_is_specified: Ensures that the dataset dictionary contains the 'subset' key.
111 | 
112 | """
113 |
114 | dataset: dict
115 |
116 | name: str
117 | description: str
118 | type: Optional[TASK_TYPE] = None
119 | modalities: list[Literal["text"]] = ["text"]
120 | category: Optional[TASK_CATEGORY] = None
121 | reference: Optional[STR_URL] = None
122 |
123 | date: Optional[tuple[STR_DATE, STR_DATE]] = None
124 | domains: Optional[list[TASK_DOMAIN]] = None
125 | task_subtypes: Optional[list[TASK_SUBTYPE]] = None
126 | license: Optional[LICENSES | STR_URL] = None
127 |
128 | annotations_creators: Optional[ANNOTATOR_TYPE] = None
129 | dialect: Optional[list[str]] = None
130 |
131 | sample_creation: Optional[SAMPLE_CREATION_METHOD] = None
132 | bibtex_citation: Optional[str] = None
133 |
134 | @field_validator("dataset")
135 | def _check_dataset_path_is_specified(
136 | cls, dataset: dict[str, Any]
137 | ) -> dict[str, Any]:
138 | if "path" not in dataset:
139 | raise ValueError("Dataset path must be specified")
140 | return dataset
141 |
142 | @field_validator("dataset")
143 | def _check_dataset_subset_is_specified(
144 | cls, dataset: dict[str, Any]
145 | ) -> dict[str, Any]:
146 | if "subset" not in dataset:
147 | raise ValueError("Dataset subset must be specified")
148 | return dataset
149 |
150 | def is_filled(self) -> bool:
151 | """Check if all the metadata fields are filled."""
152 | return all(
153 | getattr(self, field_name) is not None for field_name in self.model_fields
154 | )
155 |
156 | @property
157 | def intext_citation(self, include_cite: bool = True) -> str:
158 | """Create an in-text citation for the dataset."""
159 | cite = ""
160 | if self.bibtex_citation:
161 | cite = f"{self.bibtex_citation.split(',')[0].split('{')[1]}"
162 | if include_cite and cite:
163 | # check for whitespace in the citation
164 | if " " in cite:
165 | logger.warning(
166 | "Citation contains whitespace. Please ensure that the citation is correctly formatted."
167 | )
168 | return f"\\cite{{{cite}}}"
169 | return cite
170 |
--------------------------------------------------------------------------------
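A small sketch of constructing this metadata model; the dataset path and field values are placeholders, but note that the field validators above require both a 'path' and a 'subset' key in `dataset`.

```python
from financerag.tasks.TaskMetadata import TaskMetadata

metadata = TaskMetadata(
    name="FinDER",
    description="Financial document retrieval over 10-K filings.",
    dataset={"path": "org/finance-rag-dataset", "subset": "FinDER"},  # both keys required by the validators
    type="RAG",
    category="s2p",
    task_subtypes=["Financial retrieval"],
    license="cc-by-4.0",
)
print(metadata.is_filled())  # False: optional fields such as `reference` and `date` are still None
```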
/GAR/Embedder_Finetuning.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import pandas as pd
4 | from sklearn.model_selection import train_test_split
5 | from sentence_transformers import SentenceTransformer, InputExample
6 | from sentence_transformers.evaluation import InformationRetrievalEvaluator
7 | from torch.utils.data import DataLoader
8 | from sentence_transformers.datasets import NoDuplicatesDataLoader
9 | from sentence_transformers import losses
10 | from huggingface_hub import HfApi, login
11 | import torch
12 | import gc
13 |
14 | class EmbedderFinetuner:
15 | def __init__(self):
16 | self.model = None
17 |
18 | @staticmethod
19 | def format_query(query: str) -> str:
20 | return f'Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: {query}'
21 |
22 | @staticmethod
23 | def format_text(title: str, text: str) -> str:
24 | return f"Title: {title}\nText: {text}"
25 |
26 | def load_datasets(self, names: list, markdown_dir: str = "Markdown"):
27 | corpus_df = pd.DataFrame()
28 | queries_df = pd.DataFrame()
29 | relevant_docs_data = pd.DataFrame()
30 | all_data = pd.DataFrame()
31 |
32 | for name in names:
33 | print(f"Loading data for {name} ...")
34 | qrels_path = f'{name}_qrels.tsv'
35 | queries_path = os.path.join(markdown_dir, name, 'queries.jsonl')
36 | corpus_path = os.path.join(markdown_dir, name, 'corpus.jsonl')
37 |
38 | qrels = pd.read_csv(qrels_path, sep='\t')
39 | queries_temp = pd.read_json(queries_path, lines=True)
40 | corpus_temp = pd.read_json(corpus_path, lines=True)
41 |
42 | corpus_df = pd.concat([corpus_df, corpus_temp], ignore_index=True)
43 | queries_df = pd.concat([queries_df, queries_temp], ignore_index=True)
44 | relevant_docs_data = pd.concat([relevant_docs_data, qrels], ignore_index=True)
45 |
46 | queries_temp.rename(columns={'_id': 'query_id', 'title': 'title_queries', 'text': 'text_queries'}, inplace=True)
47 | corpus_temp.rename(columns={'_id': 'corpus_id', 'title': 'title_corpus', 'text': 'text_corpus'}, inplace=True)
48 |
49 | data = qrels.merge(queries_temp, on='query_id').merge(corpus_temp, on='corpus_id')
50 | all_data = pd.concat([all_data, data], ignore_index=True)
51 |
52 | return corpus_df, queries_df, relevant_docs_data, all_data
53 |
54 | def split_train_val(self, all_data: pd.DataFrame, relevant_docs_data: pd.DataFrame,
55 | test_size: float = 0.2, random_state: int = 42):
56 | train_rel, val_rel = train_test_split(relevant_docs_data, test_size=test_size, random_state=random_state)
57 | train_data = all_data[all_data['query_id'].isin(train_rel['query_id'])]
58 | val_data = all_data[all_data['query_id'].isin(val_rel['query_id'])]
59 | return train_data, val_data, train_rel, val_rel
60 |
61 | def create_train_samples(self, train_data: pd.DataFrame) -> list:
62 | train_samples = []
63 | for _, row in train_data.iterrows():
64 | sample = InputExample(
65 | texts=[
66 | self.format_query(self.format_text(row['title_queries'], row['text_queries'])),
67 | self.format_text(row['title_corpus'], row['text_corpus'])
68 | ]
69 | )
70 | train_samples.append(sample)
71 | return train_samples
72 |
73 | def add_ir_doc(self, df: pd.DataFrame) -> pd.DataFrame:
74 | irrelevant_docs = []
75 | for _, row in df.iterrows():
76 | text = row['text_corpus']
77 | candidates = df[df['text_corpus'] != text]
78 | if len(candidates) > 0:
79 | irrelevant = candidates.sample(n=1)
80 | irrelevant_docs.append(self.format_text(irrelevant.iloc[0]['title_corpus'], irrelevant.iloc[0]['text_corpus']))
81 | else:
82 | irrelevant_docs.append("")
83 | df = df.copy()
84 | df['irrelevant_docs'] = irrelevant_docs
85 | return df
86 |
87 | def create_eval_examples(self, val_data: pd.DataFrame) -> list:
88 | examples = []
89 | val_data_ir = self.add_ir_doc(val_data)
90 | for _, row in val_data_ir.iterrows():
91 | examples.append(InputExample(
92 | texts=[
93 | self.format_query(self.format_text(row['title_queries'], row['text_queries'])),
94 | self.format_text(row['title_corpus'], row['text_corpus'])
95 | ],
96 | label=1.0
97 | ))
98 | examples.append(InputExample(
99 | texts=[
100 | self.format_query(self.format_text(row['title_queries'], row['text_queries'])),
101 | row['irrelevant_docs']
102 | ],
103 | label=0.0
104 | ))
105 | return examples
106 |
107 | def prepare_corpus_queries(self, corpus_df: pd.DataFrame, queries_df: pd.DataFrame,
108 | val_rel: pd.DataFrame, random_sample: int = 3000):
109 | corpus_df['text'] = corpus_df.apply(lambda row: self.format_text(row['title'], row['text']), axis=1)
110 | corpus_df = corpus_df.drop(columns=['title'])
111 |
112 | queries_df['text'] = queries_df.apply(lambda row: self.format_query(self.format_text(row['title'], row['text'])), axis=1)
113 | queries_df = queries_df.drop(columns=['title'])
114 |
115 | required_corpus_ids = set(map(str, val_rel["corpus_id"]))
116 | all_ids = corpus_df["_id"].tolist()
117 | additional_ids = set(random.sample(all_ids, k=random_sample)) if len(all_ids) >= random_sample else set(all_ids)
118 | required_corpus_ids |= additional_ids
119 |
120 | corpus_df = corpus_df.loc[corpus_df["_id"].astype(str).isin(required_corpus_ids)]
121 | corpus_dict = dict(zip(corpus_df["_id"].astype(str), corpus_df["text"]))
122 | queries_dict = dict(zip(queries_df["_id"].astype(str), queries_df["text"]))
123 |
124 | return corpus_dict, queries_dict
125 |
126 | def build_relevant_docs(self, val_rel: pd.DataFrame) -> dict:
127 | relevant_docs = {}
128 | for qid, corpus_id in zip(val_rel["query_id"], val_rel["corpus_id"]):
129 | qid_str = str(qid)
130 | corpus_str = str(corpus_id)
131 | if qid_str not in relevant_docs:
132 | relevant_docs[qid_str] = set()
133 | relevant_docs[qid_str].add(corpus_str)
134 | return relevant_docs
135 |
136 | def create_ir_evaluator(self, queries: dict, corpus: dict, relevant_docs: dict,
137 | batch_size: int = 4, name: str = "Evaluate") -> InformationRetrievalEvaluator:
138 | return InformationRetrievalEvaluator(
139 | queries=queries,
140 | corpus=corpus,
141 | relevant_docs=relevant_docs,
142 | name=name,
143 | batch_size=batch_size
144 | )
145 |
146 | def train_model(self, model: SentenceTransformer, train_samples: list, evaluator: InformationRetrievalEvaluator,
147 | output_path: str, epochs: int = 2, learning_rate: float = 2e-5,
148 | warmup_ratio: float = 0.1, batch_size: int = 4):
149 | self.model = model
150 | loader = NoDuplicatesDataLoader(train_samples, batch_size=batch_size)
151 | loss = losses.MultipleNegativesRankingLoss(model)
152 | warmup_steps = int(len(loader) * epochs * warmup_ratio)
153 |
154 | model.fit(
155 | train_objectives=[(loader, loss)],
156 | evaluator=evaluator,
157 | epochs=epochs,
158 | warmup_steps=warmup_steps,
159 | output_path=output_path,
160 | optimizer_params={'lr': learning_rate},
161 | show_progress_bar=True,
162 | use_amp=True,
163 | evaluation_steps=len(loader),
164 | save_best_model=True,
165 | )
166 |
167 | def upload_model_to_hub(self, save_path: str, repo_id: str, hf_token: str, repo_owner: str = None):
168 | login(token=hf_token)
169 | api = HfApi()
170 |         full_repo_id = repo_id if repo_owner is None else f"{repo_owner}/{repo_id}"
171 |         try:
172 |             api.create_repo(repo_id=full_repo_id, exist_ok=True)
173 |         except Exception as e:
174 |             print(f"Repo creation: {e}")
175 | api.upload_folder(
176 | folder_path=save_path,
177 | repo_id=full_repo_id,
178 | repo_type="model",
179 | )
180 |
181 | def clear_gpu_memory(self):
182 | torch.cuda.empty_cache()
183 | gc.collect()
184 |
--------------------------------------------------------------------------------
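A sketch of how the fine-tuning helpers above fit together. It assumes the `<name>_qrels.tsv` files and `Markdown/<name>/{queries,corpus}.jsonl` exist locally, the script is run from the GAR directory, and a SentenceTransformer checkpoint is available; the model name and output path are placeholders.

```python
from sentence_transformers import SentenceTransformer
from Embedder_Finetuning import EmbedderFinetuner  # assuming we run from the GAR directory

finetuner = EmbedderFinetuner()
corpus_df, queries_df, rel_docs, all_data = finetuner.load_datasets(["FinDER"], markdown_dir="Markdown")
train_data, val_data, train_rel, val_rel = finetuner.split_train_val(all_data, rel_docs)

train_samples = finetuner.create_train_samples(train_data)
corpus_dict, queries_dict = finetuner.prepare_corpus_queries(corpus_df, queries_df, val_rel)
relevant_docs = finetuner.build_relevant_docs(val_rel)
evaluator = finetuner.create_ir_evaluator(queries_dict, corpus_dict, relevant_docs)

model = SentenceTransformer("intfloat/e5-large-v2")  # placeholder checkpoint
finetuner.train_model(model, train_samples, evaluator, output_path="output/embedder", epochs=2)
finetuner.clear_gpu_memory()
```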
/GAR/hybrid_search.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import logging
4 | import random
5 | import numpy as np
6 | import pandas as pd
7 | import torch
8 | from tqdm import tqdm
9 | from torch.nn.functional import softmax
10 | from rank_bm25 import BM25Okapi
11 | from huggingface_hub import login
12 |
13 |
14 | from sentence_transformers import SentenceTransformer, CrossEncoder
15 | from financerag.rerank import CrossEncoderReranker
16 | from financerag.retrieval import DenseRetrieval, SentenceTransformerEncoder
17 | from financerag.tasks import FinDER, FinQABench, FinanceBench, TATQA, FinQA, ConvFinQA, MultiHiertt
18 | from financerag.common import Retrieval
19 |
20 | logging.basicConfig(level=logging.INFO)
21 |
22 |
23 |
24 | def tokenize_list(input_list):
25 | from nltk.tokenize import word_tokenize
26 | return list(map(word_tokenize, input_list))
27 |
28 |
29 | class BM25Retriever(Retrieval):
30 | def __init__(self, model, tokenizer=tokenize_list):
31 | self.model = model
32 | self.tokenizer = tokenizer
33 | self.results = {}
34 |
35 | def retrieve(self, corpus, queries, top_k=None, score_function=None, return_sorted=False, **kwargs):
36 | query_ids = list(queries.keys())
37 | self.results = {qid: {} for qid in query_ids}
38 | query_lower_tokens = self.tokenizer([queries[qid].lower() for qid in queries])
39 | corpus_ids = list(corpus.keys())
40 |
41 | for qid, query in tqdm(zip(query_ids, query_lower_tokens), total=len(query_ids), desc="BM25 Retrieval"):
42 | scores = self.model.get_scores(query)
43 | top_k_result = np.argsort(scores)[::-1][:top_k]
44 | for idx in top_k_result:
45 | self.results[qid][corpus_ids[idx]] = scores[idx]
46 | return self.results
47 |
48 |
49 | class HybridSearcher:
50 | def __init__(self):
51 | self.retrieval_results = {}
52 |
53 | @staticmethod
54 | def softmax_normalize(retrieval_results):
55 | for query_id, results in retrieval_results.items():
56 | scores = torch.tensor(list(results.values()), dtype=torch.float32)
57 | probs = softmax(scores, dim=0).tolist()
58 | doc_ids = list(results.keys())
59 | for i, doc_id in enumerate(doc_ids):
60 | results[doc_id] = probs[i]
61 | return retrieval_results
62 |
63 | def setup_task(self):
64 |         finder_task = FinDER()
65 |         finqabench_task = FinQABench()
66 |         finqa_task = FinQA()
67 |         financebench_task = FinanceBench()
68 |         tatqa_task = TATQA()
69 |         convfinqa_task = ConvFinQA()
70 |         multihiertt_task = MultiHiertt()
71 |         tasks = [finder_task, finqabench_task, finqa_task, financebench_task, tatqa_task, convfinqa_task, multihiertt_task]
72 |         names = ["Finder", "FinQABench", "FinQA", "FinanceBench", "TATQA", "ConvFinQA", "MultiHiertt"]
73 |
74 | for i, task in enumerate(tasks):
75 |             task.metadata.dataset['path'] = 'thomaskim1130/FinanceRAG-Lingua'
76 |             task.metadata.dataset['qrels'] = f'markdown-{names[i]}'
77 | task.load_data()
78 | return tasks, names
79 |
80 | def retrieval_model_setup(self):
81 | login(token=os.getenv('HF_TOKEN'))
82 |         stella = SentenceTransformer(
83 | model_name_or_path='thomaskim1130/stella_en_1.5B_v5-FinanceRAG-TAT-MH-v2',
84 | trust_remote_code=True,
85 | query_prompt='Instruct: Given a web search query, retrieve relevant passages that answer the query.\nQuery: ',
86 | doc_prompt='',
87 | config_kwargs={"use_memory_efficient_attention": False, "unpad_inputs": False}
88 | )
89 |         retrieval_stella = DenseRetrieval(
90 | model=stella,
91 | )
92 |
93 | return retrieval_stella
94 |
95 | def reranker_model_setup(self):
96 |         reranker = CrossEncoderReranker(
97 |             model=CrossEncoder('BAAI/bge-reranker-v2-m3',
98 | trust_remote_code=True,
99 | config_args={"torch_dtype": torch.bfloat16, "attn_implementation": "sdpa", "device_map":"auto", "offload_folder":"offload"}
100 | ),
101 | )
102 | return reranker
103 |
104 | def get_sparse_score(self, task):
105 | document_list = task.corpus
106 | query_list = task.queries
107 | tokenized_corpus = tokenize_list([doc["title"].lower() + ' ' + doc["text"].lower()
108 | for doc in document_list.values()])
109 | bm25 = BM25Okapi(tokenized_corpus)
110 | retriever = BM25Retriever(bm25)
111 | top_k = len(document_list)
112 | sparse_retrieval_results = retriever.retrieve(corpus=document_list, queries=query_list, top_k=top_k)
113 | return sparse_retrieval_results
114 |
115 | def get_dense_score(self, task, retrieval_model):
116 | document_list = task.corpus
117 | query_list = task.queries
118 | dense_retrieval_results = retrieval_model.retrieve(
119 | corpus=document_list,
120 | queries=query_list,
121 | top_k=100
122 | )
123 | return dense_retrieval_results
124 |
125 | def get_hybrid_score(self, task, alpha, retrieval_model):
126 | sparse_alpha = 1 - alpha
127 | hybrid_retrieval_result = {}
128 | dense_retrieval_result = self.get_dense_score(task, retrieval_model)
129 | dense_retrieval_result = self.softmax_normalize(dense_retrieval_result)
130 |
131 | for query_id, results in dense_retrieval_result.items():
132 | for doc_id, dense_score_val in results.items():
133 | if query_id not in hybrid_retrieval_result:
134 | hybrid_retrieval_result[query_id] = {}
135 | hybrid_retrieval_result[query_id][doc_id] = alpha * dense_score_val
136 |
137 | sparse_retrieval_result = self.get_sparse_score(task)
138 | sparse_retrieval_result = self.softmax_normalize(sparse_retrieval_result)
139 |
140 | for query_id, results in sparse_retrieval_result.items():
141 | for doc_id, sparse_score_val in results.items():
142 | if doc_id in hybrid_retrieval_result.get(query_id, {}):
143 | hybrid_retrieval_result[query_id][doc_id] += sparse_alpha * sparse_score_val
144 |
145 | return hybrid_retrieval_result
146 |
147 | def evaluate_hybrid(self, task, qrels_dict, hybrid_retrieval_result):
148 | result = task.evaluate(qrels_dict, hybrid_retrieval_result, [10])
149 | return result[0]['NDCG@10']
150 |
151 | def tune_alpha(self, task, retrieval_model):
152 | alpha_values = np.linspace(0, 1, 41)
153 | ndcg_values = []
154 | qrels_path = f'Dataset/{task}_qrels.tsv'
155 | qrels_df = pd.read_csv(qrels_path, sep='\t')
156 | qrels_dict = qrels_df.groupby('query_id').apply(lambda x: dict(zip(x['corpus_id'], x['score']))).to_dict()
157 |
158 | for alpha in alpha_values:
159 | hybrid_retrieval_result = self.get_hybrid_score(task, alpha=alpha, retrieval_model=retrieval_model)
160 | ndcg_value = self.evaluate_hybrid(task, qrels_dict, hybrid_retrieval_result)
161 | ndcg_values.append(ndcg_value)
162 | max_ndcg_index = np.argmax(ndcg_values)
163 | optimal_alpha = alpha_values[max_ndcg_index]
164 | return optimal_alpha
165 |
166 | def get_reranker_score(self, task, hybrid_retrieval_result, reranker_model):
167 | reranker_result = task.rerank(
168 | reranker=reranker_model,
169 | results=hybrid_retrieval_result,
170 | top_k=10,
171 | batch_size=32
172 | )
173 | return reranker_result
174 |
175 | def get_ndcg_score(self, task, name):
176 | qrels = pd.read_csv(f'Dataset/{name}_qrels.tsv', sep='\t').groupby('query_id').apply(lambda x: dict(zip(x['corpus_id'], x['score']))).to_dict()
177 | return task.evaluate(qrels, task.retrieve_results, [10])[0]['NDCG@10']
178 |
179 | def get_final_ndcg(self, tasks, names):
180 | result = 0
181 | task_lengths = []
182 |
183 | for n, task in enumerate(tasks):
184 | task_lengths.append(len(task.queries))
185 | print(f"{names[n]} : {len(task.queries)} Queries")
186 | result += self.get_ndcg_score(task, names[n])*task_lengths[-1]
187 |
188 | result /= sum(task_lengths)
189 | return result
190 |
191 | @staticmethod
192 | def merge_csv_results(output_dir):
193 | import glob
194 | csv_files = glob.glob(os.path.join(output_dir, "*", "*.csv"))
195 | dataframes = []
196 | for file in csv_files:
197 | df = pd.read_csv(file)
198 | dataframes.append(df)
199 | merged_df = pd.concat(dataframes, ignore_index=True)
200 | merged_csv_path = os.path.join(output_dir, 'merged_output.csv')
201 | merged_df.to_csv(merged_csv_path, index=False)
202 | logging.info(f"Merged CSV saved to {merged_csv_path}")
203 |
--------------------------------------------------------------------------------
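A rough driver for the hybrid pipeline above, illustrative only: it assumes `HF_TOKEN` is set, the task datasets are reachable, and the qrels TSVs live under `Dataset/` as `tune_alpha` expects.

```python
from hybrid_search import HybridSearcher  # assuming we run from the GAR directory

searcher = HybridSearcher()
tasks, names = searcher.setup_task()

dense_model = searcher.retrieval_model_setup()
reranker = searcher.reranker_model_setup()

for task, name in zip(tasks, names):
    alpha = searcher.tune_alpha(task, dense_model)                # pick alpha by NDCG@10 on the qrels
    hybrid = searcher.get_hybrid_score(task, alpha, dense_model)  # alpha * dense + (1 - alpha) * sparse
    searcher.get_reranker_score(task, hybrid, reranker)
    print(f"{name}: hybrid retrieval + rerank complete")
```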
/financerag/common/loader.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 | from typing import Optional, Tuple, cast
4 |
5 | from datasets import Dataset, Value, load_dataset
6 |
7 | logger = logging.getLogger(__name__)
8 |
9 | class HFDataLoader:
10 | """
11 | A Hugging Face Dataset loader for corpus and query data. Supports loading datasets from local files
12 | (in JSONL format) or directly from a Hugging Face repository.
13 |
14 | Args:
15 | hf_repo (`str`, *optional*):
16 | The Hugging Face repository containing the dataset. If provided, it overrides the
17 | data folder, prefix, and *_file arguments.
18 | data_folder (`str`, *optional*):
19 | Path to the folder containing the dataset files when loading from local files.
20 | subset (`str`, *optional*):
21 | The subset of the dataset to load (e.g., "FinDER", "FinQA"). Used in both local and HF repo loading.
22 | prefix (`str`, *optional*):
23 |             A prefix added to the query file name (e.g., "train", "test"); an underscore separator is inserted automatically.
24 | corpus_file (`str`, defaults to `"corpus.jsonl"`):
25 | The filename for the corpus when loading locally.
26 | query_file (`str`, defaults to `"queries.jsonl"`):
27 | The filename for the queries when loading locally.
28 | keep_in_memory (`bool`, defaults to `False`):
29 | Whether to keep the dataset in memory.
30 | """
31 |
32 | def __init__(
33 | self,
34 | hf_repo: Optional[str] = None,
35 | data_folder: Optional[str] = None,
36 | subset: Optional[str] = None,
37 | prefix: Optional[str] = None,
38 | corpus_file: str = "corpus.jsonl",
39 | query_file: str = "queries.jsonl",
40 | keep_in_memory: bool = False,
41 | ):
42 | """
43 | Initializes the HFDataLoader class.
44 |
45 | Args:
46 | hf_repo (`str`, *optional*):
47 | The Hugging Face repository containing the dataset.
48 | data_folder (`str`, *optional*):
49 | Path to the folder containing the dataset files when loading from local files.
50 | subset (`str`, *optional*):
51 | The subset of the dataset to load.
52 | prefix (`str`, *optional*):
53 | A prefix to add to the file names.
54 | corpus_file (`str`, defaults to `"corpus.jsonl"`):
55 | The filename for the corpus when loading locally.
56 | query_file (`str`, defaults to `"queries.jsonl"`):
57 | The filename for the queries when loading locally.
58 | keep_in_memory (`bool`, defaults to `False`):
59 | Whether to keep the dataset in memory.
60 | """
61 | self.corpus: Optional[Dataset] = None
62 | self.queries: Optional[Dataset] = None
63 | self.hf_repo = hf_repo
64 | self.subset = subset
65 | if hf_repo:
66 | logger.warning(
67 | "A Hugging Face repository is provided. This will override the data_folder, prefix, and *_file arguments."
68 | )
69 | else:
70 | if (data_folder is None) or (subset is None):
71 | raise ValueError(
72 | "A Hugging Face repository or local directory is required."
73 | )
74 |
75 | if prefix:
76 | query_file = prefix + "_" + query_file
77 |
78 | self.corpus_file = (
79 | (Path(data_folder) / subset / corpus_file).as_posix()
80 | if data_folder
81 | else corpus_file
82 | )
83 | self.query_file = (
84 | (Path(data_folder) / subset / query_file).as_posix()
85 | if data_folder
86 | else query_file
87 | )
88 | self.streaming = False
89 | self.keep_in_memory = keep_in_memory
90 |
91 | @staticmethod
92 | def check(file_in: str, ext: str):
93 | """
94 | Check if the given file exists and has the correct extension.
95 |
96 | Args:
97 | file_in (`str`): The path of the file to check.
98 | ext (`str`): The expected file extension.
99 |
100 | Raises:
101 | `ValueError`: If the file does not exist or if the extension does not match.
102 | """
103 | if not Path(file_in).exists():
104 | raise ValueError(
105 | "File {} not present! Please provide an accurate file.".format(file_in)
106 | )
107 |
108 | if not file_in.endswith(ext):
109 | raise ValueError(
110 | "File {} must have the extension {}".format(file_in, ext)
111 | )
112 |
113 | def load(self) -> Tuple[Dataset, Dataset]:
114 | """
115 | Loads both the corpus and query datasets. If the datasets are not already loaded,
116 | they are loaded from the specified source (either local files or Hugging Face repository).
117 |
118 | Returns:
119 | `Tuple[Dataset, Dataset]`: A tuple containing the loaded corpus and queries datasets.
120 | """
121 | if not self.hf_repo:
122 | self.check(file_in=self.corpus_file, ext="jsonl")
123 | self.check(file_in=self.query_file, ext="jsonl")
124 |
125 | if self.corpus is None:
126 | logger.info("Loading Corpus...")
127 | self._load_corpus()
128 | self.corpus = cast(Dataset, self.corpus)
129 | logger.info("Loaded %d Documents.", len(self.corpus))
130 | logger.info("Corpus Example: %s", self.corpus[0])
131 |
132 | if self.queries is None:
133 | logger.info("Loading Queries...")
134 | self._load_queries()
135 | self.queries = cast(Dataset, self.queries)
136 |
137 | logger.info("Loaded %d Queries.", len(self.queries))
138 | logger.info("Query Example: %s", self.queries[0])
139 |
140 | return self.corpus, self.queries
141 |
142 | def load_corpus(self) -> Dataset:
143 | """
144 | Loads the corpus dataset. If the corpus is already loaded, returns the existing dataset.
145 |
146 | Returns:
147 | `Dataset`: The loaded corpus dataset.
148 | """
149 | if not self.hf_repo:
150 | self.check(file_in=self.corpus_file, ext="jsonl")
151 |
152 | if (self.corpus is None) or (not len(self.corpus)):
153 | logger.info("Loading Corpus...")
154 | self._load_corpus()
155 | self.corpus = cast(Dataset, self.corpus)
156 | logger.info("Loaded %d Documents.", len(self.corpus))
157 | logger.info("Corpus Example: %s", self.corpus[0])
158 |
159 | return self.corpus
160 |
161 | def _load_corpus(self):
162 | """
163 | Internal method to load the corpus dataset from either local files or Hugging Face repository.
164 | The dataset is processed by renaming and removing unnecessary columns.
165 | """
166 | if self.hf_repo:
167 | corpus_ds = load_dataset(
168 | path=self.hf_repo,
169 | name=self.subset,
170 | split="corpus",
171 | keep_in_memory=self.keep_in_memory,
172 | streaming=self.streaming,
173 | )
174 | else:
175 | corpus_ds = load_dataset(
176 | "json",
177 | data_files=self.corpus_file,
178 | streaming=self.streaming,
179 | keep_in_memory=self.keep_in_memory,
180 | )
181 |
182 | corpus_ds = cast(Dataset, corpus_ds)
183 | corpus_ds = corpus_ds.cast_column("_id", Value("string"))
184 | corpus_ds = corpus_ds.rename_column("_id", "id")
185 | corpus_ds = corpus_ds.remove_columns(
186 | [
187 | col
188 | for col in corpus_ds.column_names
189 | if col not in ["id", "text", "title"]
190 | ]
191 | )
192 | self.corpus = corpus_ds
193 |
194 | def _load_queries(self):
195 | """
196 | Internal method to load the queries dataset from either local files or Hugging Face repository.
197 | The dataset is processed by renaming and removing unnecessary columns.
198 | """
199 | if self.hf_repo:
200 | queries_ds = load_dataset(
201 | path=self.hf_repo,
202 | name=self.subset,
203 | split="queries",
204 | keep_in_memory=self.keep_in_memory,
205 | streaming=self.streaming,
206 | )
207 | else:
208 | queries_ds = load_dataset(
209 | "json",
210 | data_files=self.query_file,
211 | streaming=self.streaming,
212 | keep_in_memory=self.keep_in_memory,
213 | )
214 | queries_ds = cast(Dataset, queries_ds)
215 | queries_ds = queries_ds.cast_column("_id", Value("string"))
216 | queries_ds = queries_ds.rename_column("_id", "id")
217 | queries_ds = queries_ds.remove_columns(
218 | [col for col in queries_ds.column_names if col not in ["id", "text"]]
219 | )
220 | self.queries = queries_ds
--------------------------------------------------------------------------------
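For reference, a sketch of both loading modes supported by `HFDataLoader`; the repository name and local folder are placeholders, and the Hugging Face repo is assumed to expose "corpus" and "queries" splits as `_load_corpus`/`_load_queries` require.

```python
from financerag.common import HFDataLoader

# From a Hugging Face repository (placeholder repo name).
corpus, queries = HFDataLoader(hf_repo="org/finance-rag-dataset", subset="FinDER").load()

# From local files: expects <data_folder>/<subset>/corpus.jsonl and queries.jsonl.
corpus, queries = HFDataLoader(data_folder="Dataset", subset="FinDER").load()

print(corpus[0]["id"], queries[0]["text"])
```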
/financerag/common/protocols.py:
--------------------------------------------------------------------------------
1 | import abc
2 | from typing import Dict, List, Literal, Optional, Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 |
7 | __all__ = [
8 | "Encoder",
9 | "Lexical",
10 | "Retrieval",
11 | "CrossEncoder",
12 | "Reranker",
13 | "Generator",
14 | ]
15 |
16 |
17 | class Lexical(abc.ABC):
18 | """
19 | Abstract class for lexical models that defines an interface for calculating relevance scores
20 | between a query and a set of documents. This abstract class is designed to be implemented by
21 | classes that calculate document-query relevance using lexical methods such as BM25 or
22 | other term-based approaches.
23 | """
24 |
25 | @abc.abstractmethod
26 | def get_scores(self, query: List[str], **kwargs) -> List[float]:
27 | """
28 | Calculates relevance scores for a given query against a set of documents.
29 |
30 | Args:
31 | query (`List[str]`):
32 | A tokenized query in the form of a list of words. This represents the query
33 | to be evaluated for relevance against the documents.
34 |
35 | Returns:
36 | `List[float]`:
37 | A list of relevance scores, where each score corresponds to the relevance of
38 | a document in the indexed corpus to the provided query.
39 | """
40 | raise NotImplementedError
41 |
42 |
43 | class Encoder(abc.ABC):
44 | """
45 | Abstract class for dense encoders, providing methods to encode texts, queries, and corpora into dense vectors.
46 | """
47 |
48 | @abc.abstractmethod
49 | def encode_queries(
50 | self, queries: List[str], **kwargs
51 | ) -> Union[torch.Tensor, np.ndarray]:
52 | """
53 | Encodes a list of queries into dense vector representations.
54 |
55 | Args:
56 | queries (`List[str]`):
57 | A list of query strings to encode.
58 | **kwargs:
59 | Additional arguments passed to the encoder.
60 |
61 | Returns:
62 | `Union[torch.Tensor, np.ndarray]`:
63 | Encoded queries as a tensor or numpy array.
64 | """
65 | raise NotImplementedError
66 |
67 | def encode_corpus(
68 | self,
69 | corpus: Union[
70 | List[Dict[Literal["title", "text"], str]],
71 | Dict[Literal["title", "text"], List],
72 | ],
73 | **kwargs
74 | ) -> Union[torch.Tensor, np.ndarray]:
75 | """
76 | Encodes a list of corpus documents into dense vector representations.
77 |
78 | Args:
79 | corpus (`Union[List[Dict[Literal["title", "text"], str]], Dict[Literal["title", "text"], List]]`):
80 | A list or dictionary of corpus documents to encode.
81 | **kwargs:
82 | Additional arguments passed to the encoder.
83 |
84 | Returns:
85 | `Union[torch.Tensor, np.ndarray]`:
86 | Encoded corpus documents as a tensor or numpy array.
87 | """
88 | raise NotImplementedError
89 |
90 |
91 | class Retrieval(abc.ABC):
92 | """
93 | Abstract class for retrieval modules, providing a method to search for the most relevant documents based on queries.
94 | """
95 |
96 | @abc.abstractmethod
97 | def retrieve(
98 | self,
99 | corpus: Dict[str, Dict[Literal["title", "text"], str]],
100 | queries: Dict[str, str],
101 | top_k: Optional[int] = None,
102 | score_function: Optional[str] = None,
103 | **kwargs
104 | ) -> Dict[str, Dict[str, float]]:
105 | """
106 | Searches the corpus for the most relevant documents to the given queries.
107 |
108 | Args:
109 | corpus (`Dict[str, Dict[Literal["title", "text"], str]]`):
110 | A dictionary where each key is a document ID and each value is another dictionary containing document fields
111 | (e.g., {'text': str, 'title': str}).
112 | queries (`Dict[str, str]`):
113 | A dictionary where each key is a query ID and each value is the query text.
114 | top_k (`Optional[int]`, *optional*):
115 | The number of top documents to return for each query. If None, return all documents. Defaults to None.
116 | score_function (`Optional[str]`, *optional*):
117 | The scoring function to use when ranking the documents (e.g., 'cosine', 'dot', etc.). Defaults to None.
118 | **kwargs:
119 | Additional arguments passed to the search method.
120 |
121 | Returns:
122 | `Dict[str, Dict[str, float]]`:
123 | A dictionary where each key is a query ID, and each value is another dictionary mapping document IDs to
124 | relevance scores (e.g., {'doc1': 0.9, 'doc2': 0.8}).
125 | """
126 | raise NotImplementedError
127 |
128 |
129 | class CrossEncoder(abc.ABC):
130 | """
131 | Abstract class for rerankers, providing methods to predict sentence similarity and rank documents based on queries.
132 | """
133 |
134 | @abc.abstractmethod
135 | def predict(
136 | self,
137 | sentences: Union[
138 | List[Tuple[str, str]], List[List[str]], Tuple[str, str], List[str]
139 | ],
140 | batch_size: Optional[int] = None,
141 | **kwargs
142 | ) -> Union[torch.Tensor, np.ndarray]:
143 | """
144 | Predicts similarity or relevance scores for pairs of sentences or lists of sentences.
145 |
146 | Args:
147 | sentences (`Union[List[Tuple[str, str]], List[List[str]], Tuple[str, str], List[str]]`):
148 | Sentences to predict similarity scores for. Can be a list of sentence pairs, list of sentence lists,
149 | a single sentence pair, or a list of sentences.
150 | batch_size (`Optional[int]`, *optional*):
151 | Batch size for prediction. Defaults to None.
152 |
153 | Returns:
154 | `Union[torch.Tensor, np.ndarray]`:
155 | Predicted similarity or relevance scores as a tensor or numpy array.
156 | """
157 | raise NotImplementedError
158 |
159 |
160 | class Reranker(abc.ABC):
161 | """
162 | Abstract class for reranking modules that defines methods to rerank search results based on queries.
163 | """
164 |
165 | @abc.abstractmethod
166 | def rerank(
167 | self,
168 | corpus: Dict[str, Dict[str, str]],
169 | queries: Dict[str, str],
170 | results: Dict[str, Dict[str, float]],
171 | top_k: int,
172 | batch_size: Optional[int] = None,
173 | **kwargs
174 | ) -> Dict[str, Dict[str, float]]:
175 | """
176 | Reranks the search results based on the given queries and the initial ranking scores.
177 |
178 | Args:
179 | corpus (`Dict[str, Dict[str, str]]`):
180 | A dictionary where keys are document IDs and values are dictionaries containing
181 | document metadata, such as content or other features.
182 | queries (`Dict[str, str]`):
183 | A dictionary where keys are query IDs and values are the corresponding query texts.
184 | results (`Dict[str, Dict[str, float]]`):
185 | A dictionary where keys are query IDs and values are dictionaries mapping document
186 | IDs to their initial relevance scores.
187 | top_k (`int`):
188 | The number of top documents to rerank.
189 | batch_size (`Optional[int]`, *optional*):
190 | The batch size to use during reranking. Useful for models that process data in
191 | batches. Defaults to None.
192 | **kwargs:
193 | Additional keyword arguments for custom configurations in the reranking process.
194 |
195 | Returns:
196 | `Dict[str, Dict[str, float]]`:
197 | The reranked relevance scores, returned as a dictionary mapping query IDs to dictionaries of document IDs and their scores.
198 | """
199 | raise NotImplementedError
200 |
201 |
202 | class Generator(abc.ABC):
203 | """
204 | Abstract class for text generators, providing methods for generating text completions in a chat-like interface.
205 | """
206 |
207 | @abc.abstractmethod
208 | def generation(
209 | self, messages: Dict[str, List[Dict[str, str]]], **kwargs
210 | ) -> Dict[str, str]:
211 | """
212 | Generates a chat completion based on a sequence of messages.
213 |
214 | Args:
215 | messages (`Dict[str, List[Dict[str, str]]]`):
216 |                 A dictionary mapping each `query_id` to a list of message dictionaries.
217 |                 Each dictionary in the list must contain:
218 | - 'role' (str): The role of the speaker (e.g., 'user' or 'system').
219 | - 'content' (str): The content of the message.
220 | **kwargs:
221 | Additional arguments passed to the generator.
222 |
223 | Returns:
224 | `Dict[str, str]`:
225 | A dictionary containing the generated response, where each key is the `query_id` and the value is the generated text.
226 | """
227 | raise NotImplementedError
--------------------------------------------------------------------------------
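To make the contracts concrete, a minimal hypothetical `Lexical` implementation backed by `rank_bm25` (mirroring how `BM25Okapi` is used elsewhere in this repository); only the protocol itself comes from this module.

```python
from typing import List

from rank_bm25 import BM25Okapi

from financerag.common import Lexical


class SimpleBM25(Lexical):
    """Thin wrapper that satisfies the Lexical protocol with rank_bm25."""

    def __init__(self, tokenized_docs: List[List[str]]):
        self._bm25 = BM25Okapi(tokenized_docs)

    def get_scores(self, query: List[str], **kwargs) -> List[float]:
        # rank_bm25 returns a numpy array; convert to plain floats per the protocol.
        return [float(s) for s in self._bm25.get_scores(query)]


docs = [["revenue", "grew", "12%"], ["operating", "margin", "declined"]]
print(SimpleBM25(docs).get_scores(["revenue", "growth"]))
```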
/financerag/retrieval/dense.py:
--------------------------------------------------------------------------------
1 | import heapq
2 | import logging
3 | from typing import Any, Callable, Dict, Literal, Optional
4 |
5 | import torch
6 |
7 | from financerag.common.protocols import Encoder, Retrieval
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | # Copied from https://github.com/beir-cellar/beir/blob/main/beir/retrieval/search/dense/util.py
13 | @torch.no_grad()
14 | def cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
15 | """
16 | Computes the cosine similarity between two tensors.
17 |
18 | Args:
19 | a (`torch.Tensor`):
20 | Tensor representing query embeddings.
21 | b (`torch.Tensor`):
22 | Tensor representing corpus embeddings.
23 |
24 | Returns:
25 | `torch.Tensor`:
26 | Cosine similarity scores for all pairs.
27 | """
28 | a = _ensure_tensor(a)
29 | b = _ensure_tensor(b)
30 | return torch.mm(
31 | torch.nn.functional.normalize(a, p=2, dim=1),
32 | torch.nn.functional.normalize(b, p=2, dim=1).transpose(0, 1),
33 | )
34 |
35 |
36 | @torch.no_grad()
37 | def dot_score(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
38 | """
39 | Computes the dot-product score between two tensors.
40 |
41 | Args:
42 | a (`torch.Tensor`):
43 | Tensor representing query embeddings.
44 | b (`torch.Tensor`):
45 | Tensor representing corpus embeddings.
46 |
47 | Returns:
48 | `torch.Tensor`:
49 | Dot-product scores for all pairs.
50 | """
51 | a = _ensure_tensor(a)
52 | b = _ensure_tensor(b)
53 | return torch.mm(a, b.transpose(0, 1))
54 |
55 |
56 | def _ensure_tensor(x: Any) -> torch.Tensor:
57 | """
58 | Ensures the input is a torch.Tensor, converting if necessary.
59 |
60 | Args:
61 | x (`Any`):
62 | Input to be checked.
63 |
64 | Returns:
65 | `torch.Tensor`:
66 | Converted tensor.
67 | """
68 | if not isinstance(x, torch.Tensor):
69 | x = torch.tensor(x)
70 | if len(x.shape) == 1:
71 | x = x.unsqueeze(0)
72 | return x
73 |
74 |
75 | # Adapted from https://github.com/beir-cellar/beir/blob/main/beir/retrieval/search/dense/exact_search.py
76 | class DenseRetrieval(Retrieval):
77 | """
78 | Encoder-based dense retrieval that performs similarity-based search over a corpus.
79 |
80 | This class uses dense embeddings from an encoder model to compute similarity scores (e.g., cosine similarity or
81 | dot product) between query embeddings and corpus embeddings. It retrieves the top-k most relevant documents
82 | based on these scores.
83 | """
84 |
85 | def __init__(
86 | self,
87 | model: Encoder,
88 | batch_size: int = 64,
89 | score_functions: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] | None = None,
90 | corpus_chunk_size: int = 50000
91 | ):
92 | """
93 | Initializes the DenseRetrieval class.
94 |
95 | Args:
96 | model (`Encoder`):
97 | An encoder model implementing the `Encoder` protocol, responsible for encoding queries and corpus documents.
98 | batch_size (`int`, *optional*, defaults to `64`):
99 | The batch size to use when encoding queries and corpus documents.
100 | score_functions (`Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]]`, *optional*):
101 | A dictionary mapping score function names (e.g., "cos_sim", "dot") to functions that compute similarity
102 | scores between query and corpus embeddings. Defaults to cosine similarity and dot product.
103 | corpus_chunk_size (`int`, *optional*, defaults to `50000`):
104 | The number of documents to process in each batch when encoding the corpus.
105 | """
106 | self.model: Encoder = model
107 | self.batch_size: int = batch_size
108 | if score_functions is None:
109 | score_functions = {"cos_sim": cos_sim, "dot": dot_score}
110 | self.score_functions: Dict[str, Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = score_functions
111 | self.corpus_chunk_size: int = corpus_chunk_size
112 | self.results: Dict = {}
113 |
114 | def retrieve(
115 | self,
116 | corpus: Dict[str, Dict[Literal["title", "text"], str]],
117 | queries: Dict[str, str],
118 | top_k: Optional[int] = None,
119 | score_function: Literal["cos_sim", "dot"] | None = "cos_sim",
120 | return_sorted: bool = False,
121 | **kwargs,
122 | ) -> Dict[str, Dict[str, float]]:
123 | """
124 | Retrieves the top-k most relevant documents from the corpus based on the given queries.
125 |
126 | This method encodes the queries and corpus documents, computes similarity scores using the specified scoring
127 | function, and retrieves the top-k most relevant documents for each query.
128 |
129 | Args:
130 | corpus (`Dict[str, Dict[Literal["title", "text"], str]]`):
131 | A dictionary where each key is a document ID, and each value contains document metadata
132 | such as 'title' and 'text'.
133 | queries (`Dict[str, str]`):
134 | A dictionary where each key is a query ID and each value is the query text.
135 | top_k (`Optional[int]`, *optional*):
136 | The number of top documents to return for each query. If `None`, returns all documents.
137 | return_sorted (`bool`, *optional*, defaults to `False`):
138 | Whether to return the results sorted by score.
139 | score_function (`Literal["cos_sim", "dot"]`, *optional*, defaults to `"cos_sim"`):
140 | The scoring function to use, either 'cos_sim' for cosine similarity or 'dot' for dot product.
141 | **kwargs:
142 | Additional arguments passed to the encoder model.
143 |
144 | Returns:
145 | `Dict[str, Dict[str, float]]`:
146 | A dictionary where each key is a query ID, and the value is another dictionary mapping document
147 | IDs to their similarity scores.
148 | """
149 | if score_function not in self.score_functions:
150 | raise ValueError(
151 | f"Score function: {score_function} must be either 'cos_sim' for cosine similarity or 'dot' for dot product."
152 | )
153 |
154 | logger.info("Encoding queries...")
155 | query_ids = list(queries.keys())
156 | self.results = {qid: {} for qid in query_ids}
157 | query_texts = [queries[qid] for qid in queries]
158 | query_embeddings = self.model.encode_queries(
159 | query_texts, batch_size=self.batch_size, **kwargs
160 | )
161 |
162 | logger.info("Sorting corpus by document length...")
163 | sorted_corpus_ids = sorted(
164 | corpus,
165 | key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")),
166 | reverse=True,
167 | )
168 |
169 | logger.info("Encoding corpus in batches... This may take a while.")
170 | result_heaps = {
171 | qid: [] for qid in query_ids
172 | } # Keep only the top-k docs for each query
173 |
174 | corpus_list = [corpus[cid] for cid in sorted_corpus_ids]
175 |
176 | for batch_num, start_idx in enumerate(
177 | range(0, len(corpus), self.corpus_chunk_size)
178 | ):
179 | logger.info(
180 | f"Encoding batch {batch_num + 1}/{len(range(0, len(corpus_list), self.corpus_chunk_size))}..."
181 | )
182 | end_idx = min(start_idx + self.corpus_chunk_size, len(corpus_list))
183 |
184 | # Encode chunk of corpus
185 | sub_corpus_embeddings = self.model.encode_corpus(
186 | corpus_list[start_idx:end_idx], batch_size=self.batch_size, **kwargs
187 | )
188 |
189 | # Compute similarities using either cosine similarity or dot product
190 | cos_scores = self.score_functions[score_function](
191 | query_embeddings, sub_corpus_embeddings
192 | )
193 | cos_scores[torch.isnan(cos_scores)] = -1
194 |
195 | # Get top-k values
196 | if top_k is None:
197 |                 top_k = cos_scores.shape[1]  # number of documents in this chunk; len(cos_scores[1]) fails with a single query
198 |
199 | cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
200 | cos_scores,
201 |                 min(top_k + 1, cos_scores.shape[1]),
202 | dim=1,
203 | largest=True,
204 | sorted=return_sorted,
205 | )
206 |
207 | cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
208 | cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()
209 |
210 | for query_itr in range(len(query_embeddings)):
211 | query_id = query_ids[query_itr]
212 | for sub_corpus_id, score in zip(
213 | cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]
214 | ):
215 | corpus_id = sorted_corpus_ids[start_idx + sub_corpus_id]
216 | if corpus_id != query_id:
217 | if len(result_heaps[query_id]) < top_k:
218 | heapq.heappush(result_heaps[query_id], (score, corpus_id))
219 | else:
220 | heapq.heappushpop(
221 | result_heaps[query_id], (score, corpus_id)
222 | )
223 |
224 | for qid in result_heaps:
225 | for score, corpus_id in result_heaps[qid]:
226 | self.results[qid][corpus_id] = score
227 |
228 | return self.results
--------------------------------------------------------------------------------
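A small end-to-end sketch of `DenseRetrieval` with a toy `Encoder` implementation. A real setup would plug in `SentenceTransformerEncoder` from `financerag.retrieval`; the bag-of-characters encoder below exists only to keep the example self-contained and deterministic.

```python
import torch

from financerag.common.protocols import Encoder
from financerag.retrieval.dense import DenseRetrieval


class ToyEncoder(Encoder):
    """Deterministic bag-of-characters embeddings, purely for illustration."""

    def _embed(self, texts):
        vecs = torch.zeros(len(texts), 64)
        for i, text in enumerate(texts):
            for ch in text.lower():
                vecs[i, ord(ch) % 64] += 1.0
        return vecs

    def encode_queries(self, queries, **kwargs):
        return self._embed(queries)

    def encode_corpus(self, corpus, **kwargs):
        return self._embed([doc.get("title", "") + " " + doc.get("text", "") for doc in corpus])


corpus = {
    "d1": {"title": "10-K", "text": "Revenue grew 12 percent."},
    "d2": {"title": "10-Q", "text": "Margins declined."},
    "d3": {"title": "8-K", "text": "New CFO appointed."},
}
queries = {"q1": "revenue growth", "q2": "margin trend"}

retriever = DenseRetrieval(model=ToyEncoder(), batch_size=8)
print(retriever.retrieve(corpus=corpus, queries=queries, top_k=2))
```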
/GAR/DPO_agent_Finetuning.py:
--------------------------------------------------------------------------------
1 | # dpo_agent_finetuning.py
2 |
3 | import json
4 | import numpy as np
5 | import openai
6 | from tqdm import tqdm
7 | import pandas as pd
8 | import random
9 |
10 | class dpo_agent_finetuning:
11 | def __init__(self, api_key, default_model="gpt-4o-mini", ft_model="ft:gpt-4o-mini-2024-07-18:personal::AkFnSqzI", top_probs=10):
12 |
13 | self.api_key = api_key
14 | self.default_model = default_model
15 | self.ft_model = ft_model
16 | self.top_probs = top_probs
17 | self.client = openai.OpenAI(api_key=self.api_key)
18 |
19 | def save_from_csv_to_jsonl(self, answer_csv, data_csv, jsonl_file):
20 | answer_df = pd.read_csv(answer_csv)
21 | data_df = pd.read_csv(data_csv)
22 |
23 | financial_prompt = """
24 | You are an expert financial advisor and evaluator focused on improving responses.
25 | Your task is to enhance answers based on detailed evaluation scores while:
26 | - Maintaining factual accuracy with the provided documents
27 | - Ensuring responses are clear and well-structured for financial contexts
28 | - Providing comprehensive answers that address all aspects of the query
29 | - Using professional financial terminology appropriately
30 | """
31 |
32 | with open(jsonl_file, 'w', encoding='utf-8') as f:
33 | for index, row in tqdm(answer_df.iterrows(), total=answer_df.shape[0]):
34 | chunk = data_df[index:index+10]
35 | corpus_list = chunk['Corpus'].tolist()
36 |
37 | query_prompt = f"""
38 | Query: {row['query']}
39 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)])
40 | query_prompt += "\nDo not add any closing pleasantries or phrases like 'please feel free to ask!'"
41 |
42 | choice = random.choice([1, 1, 1, 2])
43 | json_obj = {
44 | "messages": [
45 | {"role": "system", "content": financial_prompt},
46 | {"role": "user", "content": query_prompt},
47 | {"role": "assistant", "content": row[f'response{choice}']}
48 | ]
49 | }
50 | f.write(json.dumps(json_obj, ensure_ascii=False) + "\n")
51 |
52 | def split_jsonl(self, jsonl_file, train_path, eval_path, split_ratio=0.8):
53 |
54 | with open(jsonl_file, "r", encoding="utf-8") as f:
55 | data = [json.loads(line) for line in f]
56 | random.shuffle(data)
57 | split_index = int(len(data) * split_ratio)
58 | train_data = data[:split_index]
59 | eval_data = data[split_index:]
60 | with open(train_path, "w", encoding="utf-8") as f_train:
61 | for record in train_data:
62 | f_train.write(json.dumps(record, ensure_ascii=False) + "\n")
63 | with open(eval_path, "w", encoding="utf-8") as f_eval:
64 | for record in eval_data:
65 | f_eval.write(json.dumps(record, ensure_ascii=False) + "\n")
66 |
67 | def load_jsonl(self, path):
68 |
69 | with open(path, 'r', encoding='utf-8') as file:
70 | data = [json.loads(line) for line in file]
71 | return data
72 |
73 | def calculate_weighted_score(self, response, scores):
74 |
75 | token_probs = [
76 | np.exp(float(dict(response)['choices'][0].logprobs.content[0].top_logprobs[i].logprob))
77 | for i in range(self.top_probs)
78 | ]
79 | weighted_scores = sum([token_probs[i] * scores[i] for i in range(self.top_probs)])
80 | return weighted_scores
81 |
82 | def g_eval(self, query, response, document_list):
83 |
84 | prompt_final_score = f"""
85 | You are an evaluation assistant tasked with assessing the quality of a generated answer.
86 | You will be given one query, one generated answer, and a summarized document related to the query.
87 | Follow these steps to evaluate the answer based on the criteria below.
88 |
89 | Evaluation criteria:
90 | 1. Relevance (0~5): Does the answer directly address the query and is it contextually relevant?
91 | 2. Alignment (0~5): Does the answer align with the facts provided in the documents?
92 | 3. Clarity (0~5): Is the answer easy to understand and free of confusion or unnecessary complexity?
93 | 4. Completeness (0~5): Does the answer address all parts of the question thoroughly?
94 | 5. Coherence (0~5): Is the answer well-structured and internally consistent as a whole? (This dimension corresponds to DUC's quality criteria on structure and coherence.)
95 | 
96 | Evaluation steps:
97 | 1. Assign and record all five scores. Each score is on a scale of 0 to 5, where 0 is the lowest and 5 is the highest, based on each criterion above.
98 | 2. Get the final score by adding up all the recorded scores. The final score must be between 0 and 25.
99 | 3. YOU MUST return ONLY the final score. Do not add any explanations or words like 'score'. Return ONLY the integer form of the final score, like 11, 18, 25, etc.
100 | 
101 | Query: {query}
102 | 
103 | Documents:
104 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(document_list)]) + f"""
105 | 
106 | Generated answer: {dict(response)['choices'][0].message.content}
107 | 
108 | Note: Do NOT add any other text or string, like 'twenty'. Return ONLY the integer form, like 9, 14, 20, etc.
109 | """
110 | cot_response_final_score = self.client.chat.completions.create(
111 | model=self.default_model,
112 | messages=[
113 | {"role": "system", "content": "You are a detailed evaluator. Be sure to respond ONLY in int. (e.g., 11, 20, 19, etc.)"},
114 | {"role": "user", "content": prompt_final_score}
115 | ],
116 | temperature=0,
117 | logprobs=True,
118 | top_logprobs=self.top_probs,
119 | )
120 | measured_score = int(dict(cot_response_final_score)['choices'][0].message.content)
121 | token_scores = []
122 | detailed_scores = {}
123 |
124 | for i in range(self.top_probs):
125 | top_logprob_item = dict(cot_response_final_score)['choices'][0].logprobs.content[0].top_logprobs[i]
126 | temp_token = top_logprob_item.token
127 | temp_probs = np.exp(float(top_logprob_item.logprob)) * 100
128 | detailed_scores[temp_token] = f"{temp_probs}%"
129 | try:
130 | token_scores.append(int(top_logprob_item.token))
131 | except Exception as e:
132 | print(e)
133 | token_scores.append(20)
134 | weighted_score = self.calculate_weighted_score(cot_response_final_score, token_scores)
135 | return measured_score, weighted_score, str(detailed_scores)
136 |
137 | def improve_response(self, query, original_response, document_list, scores_list):
138 |
139 | improvement_prompt = f"""
140 | You need to improve the following answer based on the evaluation scores and criteria.
141 |
142 | Query: {query}
143 | 
144 | Documents:
145 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(document_list)]) + f"""
146 | 
147 | Original answer: {dict(original_response)['choices'][0].message.content}
148 | 
149 | Evaluation criteria:
150 | 1. Relevance (0~5): Does the answer directly address the query and is it contextually relevant?
151 | 2. Alignment (0~5): Does the answer align with the facts provided in the documents?
152 | 3. Clarity (0~5): Is the answer easy to understand and free of confusion or unnecessary complexity?
153 | 4. Completeness (0~5): Does the answer address all parts of the question thoroughly?
154 | 5. Coherence (0~5): Is the answer well-structured and internally consistent as a whole? (This dimension corresponds to DUC's quality criteria on structure and coherence.)
155 | 
156 | Please provide an improved answer that:
157 | 1. Aims to achieve the highest possible score (25/25)
158 | 2. Focuses on creating a response that would:
159 |    - Maintain the strengths of the original answer
160 |    - Ensure accuracy with the provided documents
161 |    - Be clear and well-structured
162 | 3. Increases the probability of achieving a perfect score of 25
163 | 
164 | 
165 | Below is the score information: the probabilities of the top 10 candidate final scores, in the format {{"score": "probability"}}.
166 | {scores_list}
167 |
168 | Provide only the improved answer without any explanations.
169 | """
170 | improved_response = self.client.chat.completions.create(
171 | model=self.default_model,
172 | messages=[
173 | {"role": "system", "content": """
174 | You are an expert financial advisor and evaluator focused on improving responses.
175 | Your task is to enhance answers based on detailed evaluation scores while:
176 | - Maintaining factual accuracy with the provided documents
177 | - Focusing especially on areas that received low scores
178 | - Ensuring responses are clear and well-structured for financial contexts
179 | - Providing comprehensive answers that address all aspects of the query
180 | - Using professional financial terminology appropriately
181 |
182 | You should maintain the strengths of the original response while addressing its weaknesses.
183 | """ },
184 | {"role": "user", "content": improvement_prompt}
185 | ],
186 | temperature=0
187 | )
188 | return improved_response
189 |
190 | def evaluate_corpus_relevance(self, query_line, corpus_list):
191 |
192 | relevancy_instruction = """
193 | You are an expert financial advisor and evaluator focused on improving responses.
194 | Your task is to enhance answers based on detailed evaluation scores while:
195 | - Maintaining factual accuracy with the provided documents
196 | - Ensuring responses are clear and well-structured for financial contexts
197 | - Providing comprehensive answers that address all aspects of the query
198 | - Using professional financial terminology appropriately
199 |
200 | You are given the pair of Query, Corpus (same query)
201 | Out of the 10 documents, only provide the list of indices of those that are RELEVANT (e.g. the content is somehow needed to answer the question), from 0~9.
202 | Example : [0, 2, 8, 9]
203 | """
204 | user_prompt = f"""
205 | Query: {query_line}
206 |
207 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)])
208 | try:
209 | relevancy = eval(self.generate_output(relevancy_instruction, user_prompt, temp=0))
210 | except Exception as error:
211 | print(error)
212 | relevancy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
213 | if relevancy == []:
214 | relevancy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
215 | return relevancy
216 |
217 | def answer_query(self, query_line, corpus_list, temp=0, return_raw=False, model_name=None):
218 |
219 | if model_name is None:
220 | model_name = self.default_model
221 | financial_prompt = """
222 | You are an expert financial advisor and evaluator focused on improving responses.
223 | Your task is to enhance answers based on detailed evaluation scores while:
224 | - Maintaining factual accuracy with the provided documents
225 | - Ensuring responses are clear and well-structured for financial contexts
226 | - Providing comprehensive answers that address all aspects of the query
227 | - Using professional financial terminology appropriately
228 | """
229 | query_prompt = f"""
230 | Query: {query_line}
231 |
232 | """ + "\n".join([f"###CORPUS_{i+1}\n{corpus}\n" for i, corpus in enumerate(corpus_list)])
233 | query_prompt += "\nDo not add any closing pleasantries or phrases like 'please feel free to ask!'"
234 | response = self.generate_output(financial_prompt, query_prompt, temp=temp, return_raw=return_raw)
235 | return response
236 |
237 |
238 | def process_finetuning(self, df, submission_output_path, start_index=39610):
239 |
240 | submission_df = pd.DataFrame({'query_id': [], 'response': []})
241 |
242 | for i in tqdm(range(start_index, len(df), 10)):
243 | chunk = df[i:i+10]
244 | query_id = chunk['total_query_id'].iloc[0]
245 | query_line = chunk['total_query'].iloc[0]
246 | corpus_list = chunk['total_corpus'].tolist()
247 |
248 | corpus_rel = self.evaluate_corpus_relevance(query_line, corpus_list)
249 | try:
250 | relevant_corpus = chunk.iloc[corpus_rel]
251 | except Exception as e:
252 | print(e)
253 | relevant_corpus = chunk
254 | relevant_corpus = relevant_corpus['total_corpus'].tolist()
255 |
256 | response = self.answer_query(query_line, relevant_corpus, temp=0, return_raw=True, model_name=self.ft_model)
257 |
258 | measured_score, weighted_score, detailed_scores = self.g_eval(query_line, response, relevant_corpus)
259 |
260 | attempt_count = 0
261 | best_score = weighted_score
262 | best_response = response
263 |
264 | while weighted_score < 22 and attempt_count < 5:
265 | attempt_count += 1
266 | improved_response = self.improve_response(query_line, best_response, relevant_corpus, detailed_scores)
267 | improved_score, improved_weighted, improved_detailed = self.g_eval(query_line, improved_response, relevant_corpus)
268 | if improved_weighted > best_score:
269 | best_score = improved_weighted
270 | best_response = improved_response
271 | weighted_score = improved_weighted
272 | detailed_scores = improved_detailed
273 | else:
274 | break
275 |
276 | submission_df.loc[int(i/10)] = [query_id, best_response.choices[0].message.content]
277 |
278 | submission_df.to_csv(submission_output_path, index=False)
279 | print(f"Submission file saved to {submission_output_path}")
280 |
281 | def upload_finetune_files(self, train_path, eval_path):
282 |         """Upload the training and evaluation files to OpenAI for fine-tuning."""
283 |         with open(train_path, "rb") as train_f:
284 |             training_file = self.client.files.create(
285 |                 file=train_f, purpose="fine-tune"
286 |             )
287 |         with open(eval_path, "rb") as eval_f:
288 |             eval_file = self.client.files.create(
289 |                 file=eval_f, purpose="fine-tune"
290 |             )
291 | return training_file, eval_file
292 |
293 | def create_finetune_job(self, training_file, eval_file, model="gpt-4o-mini-2024-07-18", method=None):
294 |         """Create an OpenAI fine-tuning job from the uploaded training and validation files."""
295 | job_response = self.client.fine_tuning.jobs.create(
296 | training_file=training_file.id,
297 | model=model,
298 | validation_file=eval_file.id,
299 |             # A fine-tuning `method` option can also be passed here if needed (see the sketch after this method)
300 | )
301 | return job_response
302 |
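    # Editor's sketch (not part of the original file): if DPO-style fine-tuning is wanted,
    # the `method` option mentioned in the comment above can be forwarded to
    # `client.fine_tuning.jobs.create`. The payload shape below is an assumption based on
    # the OpenAI fine-tuning docs and should be verified before use, e.g.:
    #
    #   method={"type": "dpo", "dpo": {"hyperparameters": {"beta": 0.1}}}
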
303 | def retrieve_finetune_job(self, job_id):
304 |         """Fetch the fine-tuning job object (including its status) by job ID."""
305 | return self.client.fine_tuning.jobs.retrieve(job_id)
306 |
--------------------------------------------------------------------------------
/financerag/tasks/BaseTask.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import logging
4 | import os
5 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
6 |
7 | import pytrec_eval
8 |
9 | from financerag.common import Generator, HFDataLoader, Reranker, Retrieval
10 | from financerag.tasks.TaskMetadata import TaskMetadata
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | # Adapted from https://github.com/embeddings-benchmark/mteb/blob/main/mteb/abstasks/AbsTask.py
16 | class BaseTask:
17 | """
18 | Base class for handling tasks related to document retrieval, reranking, and generation in the finance domain.
19 | The class loads data, handles retrieval and reranking operations, and can generate results using a language model.
20 |
21 | Attributes:
22 | metadata (`TaskMetadata`):
23 | Metadata containing task-specific information, such as dataset paths and subsets.
24 | queries (`Optional[Dict[str, str]]`, defaults to `None`):
25 | A dictionary mapping query IDs to query text.
26 | corpus (`Optional[Dict[str, Dict[str, str]]]`, defaults to `None`):
27 | A dictionary mapping document IDs to a dictionary containing the document title and text.
28 | retrieve_results (`Optional[Dict]`, defaults to `None`):
29 | The results of the retrieval process.
30 | rerank_results (`Optional[Dict]`, defaults to `None`):
31 | The results of the reranking process.
32 | generate_results (`Optional[Dict]`, defaults to `None`):
33 | The results generated by the model.
34 |
35 | Methods:
36 | load_data():
37 | Loads the dataset (queries and corpus) into memory from the provided metadata.
38 | retrieve(retriever: Retrieval, top_k: Optional[int] = 100, **kwargs):
39 | Performs document retrieval based on the given queries and corpus.
40 | rerank(reranker: Reranker, results: Optional[Dict] = None, top_k: Optional[int] = 100, batch_size: Optional[int] = None, **kwargs):
41 | Reranks the retrieved results using the given reranker model.
42 | generate(model: Generator, results: Optional[Dict] = None, prepare_messages: Optional[Callable] = None, **kwargs):
43 | Generates results based on the highest-scoring documents from the reranked or retrieved results.
44 | prepare_generation_inputs(results: Dict, prepare_messages: Callable) -> Dict[str, List[dict]]:
45 | Prepares the input format required for generating results by the model.
46 | save_results(top_k: int = 10, output_dir: Optional[str] = None) -> None:
47 | Saves the results (retrieval, reranking, and generated) to CSV and JSONL files.
48 | """
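    # Editor's sketch (not part of the original file): the typical call sequence on a
    # concrete subclass, with `task`, `retriever`, `reranker` and `generator` standing in
    # for objects that satisfy the Retrieval, Reranker and Generator protocols:
    #
    #   task.retrieve(retriever=retriever, top_k=100)
    #   task.rerank(reranker=reranker, top_k=50, batch_size=32)
    #   task.generate(model=generator)
    #   task.save_results(top_k=10, output_dir="./outputs")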
49 |
50 | def __init__(self, metadata: TaskMetadata):
51 | """
52 | Initializes the BaseTask class with metadata for loading and processing retrieval tasks.
53 |
54 | Args:
55 | metadata (`TaskMetadata`):
56 | Task-specific metadata that contains dataset information and configurations.
57 | """
58 | self.metadata: TaskMetadata = metadata
59 | self.queries: Optional[Dict[str, str]] = None
60 | self.corpus: Optional[Dict[str, Dict[Literal["title", "text"], str]]] = None
61 | self.retrieve_results: Optional[Dict] = None
62 | self.rerank_results: Optional[Dict] = None
63 | self.generate_results: Optional[Dict] = None
64 |
65 | self.load_data()
66 |
67 | @property
68 | def metadata_dict(self) -> Dict[str, Any]:
69 | """
70 | Converts the task metadata into a dictionary format.
71 |
72 | Returns:
73 | `Dict[str, Any]`:
74 | A dictionary representation of the task metadata.
75 | """
76 | return dict(self.metadata)
77 |
78 | def load_data(self):
79 | """
80 | Loads the corpus and queries from the specified dataset path and subset in the metadata.
81 |
82 | Raises:
83 | `ValueError`:
84 | If the dataset cannot be loaded from the specified path and subset.
85 | """
86 | if (self.corpus is None) or (self.queries is None):
87 | dataset_path = self.metadata_dict["dataset"]["path"]
88 | subset = self.metadata_dict["dataset"]["subset"]
89 |
90 | corpus, queries = HFDataLoader(
91 | hf_repo=dataset_path,
92 | subset=subset,
93 | keep_in_memory=False,
94 | ).load()
95 |
96 | self.queries = {query["id"]: query["text"] for query in queries}
97 | self.corpus = {
98 | doc["id"]: {"title": doc["title"], "text": doc["text"]}
99 | for doc in corpus
100 | }
101 |
102 | def retrieve(
103 | self, retriever: Retrieval, top_k: Optional[int] = 100, **kwargs
104 | ) -> Dict[str, Dict[str, float]]:
105 | """
106 | Performs document retrieval using the provided retriever model.
107 |
108 | Args:
109 | retriever (`Retrieval`):
110 | The retrieval model to use for retrieving documents.
111 | top_k (`Optional[int]`, defaults to `100`):
112 | The number of top results to return for each query.
113 | **kwargs:
114 | Additional keyword arguments for the retriever.
115 |
116 | Returns:
117 | `Dict[str, Dict[str, float]]`:
118 | A dictionary where the key is the query ID and the value is another dictionary
119 | mapping document IDs to their retrieval scores.
120 |
121 | Raises:
122 | `TypeError`:
123 | If the `retriever` is not a subclass of `Retrieval`.
124 | `ValueError`:
125 | If the data (corpus or queries) is not loaded before retrieval.
126 | """
127 | if not issubclass(type(retriever), Retrieval):
128 | raise TypeError(f"{type(retriever)} must be a subclass of the `Retrieval` class")
129 |
130 | if (self.corpus is None) or (self.queries is None):
131 | raise ValueError("Data has not been loaded.")
132 |
133 | self.retrieve_results = retriever.retrieve(
134 | queries=self.queries, corpus=self.corpus, top_k=top_k, **kwargs
135 | )
136 |
137 | return self.retrieve_results
138 |
139 | def rerank(
140 | self,
141 | reranker: Reranker,
142 | results: Optional[Dict[str, Dict[str, float]]] = None,
143 | top_k: int = 100,
144 | batch_size: Optional[int] = None,
145 | **kwargs,
146 | ) -> Dict[str, Dict[str, float]]:
147 | """
148 | Reranks the retrieved results using the provided reranker model.
149 |
150 | Args:
151 | reranker (`Reranker`):
152 | The reranker model to use for reranking the retrieved results.
153 | results (`Optional[Dict]`, defaults to `None`):
154 | The initial results to rerank. If not provided, the method uses the retrieval results.
155 | top_k (`Optional[int]`, defaults to `100`):
156 | The number of top results to return after reranking.
157 | batch_size (`Optional[int]`, defaults to `None`):
158 | The batch size to use for reranking.
159 | **kwargs:
160 | Additional keyword arguments for the reranker.
161 |
162 | Returns:
163 | `Dict[str, Dict[str, float]]`:
164 | A dictionary where the key is the query ID and the value is another dictionary
165 | mapping document IDs to reranked scores.
166 |
167 | Raises:
168 | `TypeError`:
169 | If the `reranker` is not a subclass of `Reranker`.
170 | `ValueError`:
171 | If the data (corpus or queries) is not loaded before reranking or both `results` and `retrieve_results` are None.
172 | """
173 | if not issubclass(type(reranker), Reranker):
174 | raise TypeError(f"{type(reranker)} must be a subclass of the `Reranker` class")
175 |
176 | if (self.corpus is None) or (self.queries is None):
177 | raise ValueError("Data has not been loaded.")
178 |
179 | if results is None:
180 | if self.retrieve_results is not None:
181 | results = self.retrieve_results
182 | else:
183 |                 raise ValueError("Both `results` and `retrieve_results` are None; one of them must be provided.")
184 |
185 | self.rerank_results = reranker.rerank(
186 | queries=self.queries,
187 | corpus=self.corpus,
188 | results=results,
189 | top_k=top_k,
190 | batch_size=batch_size,
191 | **kwargs,
192 | )
193 |
194 | return self.rerank_results
195 |
196 | def generate(
197 | self,
198 | model: Generator,
199 | results: Optional[Dict] = None,
200 | prepare_messages: Optional[Callable] = None,
201 | **kwargs,
202 | ) -> Dict[str, str]:
203 | """
204 | Generates responses based on the highest-scoring documents from either the reranked or retrieved results.
205 |
206 | Args:
207 | model (`Generator`):
208 | The model used to generate responses.
209 | results (`Optional[Dict]`, defaults to `None`):
210 | The results to generate responses from. If not provided, uses reranked or retrieved results.
211 | prepare_messages (`Optional[Callable]`, defaults to `None`):
212 | A function to prepare messages for the generation model. If not provided, a default message
213 | preparation function is used.
214 | **kwargs:
215 | Additional keyword arguments for the generation process.
216 |
217 | Returns:
218 | `Dict[str, str]`:
219 | A dictionary where the key is the query ID and the value is the generated response.
220 |
221 | Raises:
222 | `TypeError`:
223 | If the `model` is not a subclass of `Generator`.
224 | `AssertionError`:
225 | If neither rerank_results nor retrieve_results are available for generating responses.
226 | """
227 | if not issubclass(type(model), Generator):
228 | raise TypeError(f"{type(model)} must be a subclass of the `Generator` class")
229 |
230 | if prepare_messages is None:
231 | logger.info(
232 | "No prepare_messages function provided. "
233 | "Using default message preparation function, which selects the highest scored document for each query."
234 | )
235 |
236 | def default_messages(
237 | query: str, documents: List[Tuple[str, float]]
238 | ) -> List[Dict]:
239 | first_document = max(documents, key=lambda x: x[1])[0]
240 | messages = [
241 | {"role": "system", "content": "You are a helpful assistant."},
242 | {
243 | "role": "user",
244 | "content": f"Document: {first_document}"
245 | f"\nGenerate an answer to the question from the document."
246 | f"\nQuestion: {query}",
247 | },
248 | ]
249 | return messages
250 |
251 | prepare_messages = default_messages
252 |
253 | if results is None:
254 | results = (
255 | self.rerank_results
256 |                 if self.rerank_results is not None
257 | else self.retrieve_results
258 | )
259 | assert results is not None, (
260 | "Neither rerank_results nor retrieve_results are available. "
261 | "One of them must be provided."
262 | )
263 |
264 | messages_dict = self.prepare_generation_inputs(results, prepare_messages)
265 | self.generate_results = model.generation(messages_dict, **kwargs)
266 |
267 | return self.generate_results
268 |
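    # Editor's sketch (not part of the original file): a custom `prepare_messages`
    # callable that grounds the prompt in the top-3 documents rather than only the
    # highest-scoring one; its signature mirrors the default defined in `generate`.
    #
    #   def top3_messages(query, documents):
    #       top_docs = sorted(documents, key=lambda x: x[1], reverse=True)[:3]
    #       context = "\n\n".join(str(doc) for doc, _ in top_docs)
    #       return [
    #           {"role": "system", "content": "You are a helpful assistant."},
    #           {"role": "user", "content": f"Documents:\n{context}\n\nQuestion: {query}"},
    #       ]
    #
    #   task.generate(model=generator, prepare_messages=top3_messages)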
269 | def prepare_generation_inputs(
270 | self, results, prepare_messages
271 | ) -> Dict[str, List[dict]]:
272 | """
273 | Prepares the input messages required for the generation model.
274 |
275 | Args:
276 | results (`Dict`):
277 | The results from retrieval or reranking, which are used to generate responses.
278 | prepare_messages (`Callable`):
279 | A function that prepares the messages required for the generation model.
280 |
281 | Returns:
282 | `Dict[str, List[dict]]`:
283 | A dictionary where the key is the query ID and the value is a list of messages (dictionaries)
284 | that will be passed to the generation model.
285 |
286 | Raises:
287 | `ValueError`:
288 | If the data (corpus or queries) is not loaded.
289 | """
290 | if (self.corpus is None) or (self.queries is None):
291 | raise ValueError("Data has not been loaded.")
292 |
293 | messages_dict: Dict[str, List[Dict[str, str]]] = {}
294 | logger.info("Preparing generation inputs for %d queries.", len(results))
295 | for query_id, result in results.items():
296 | query = self.queries[query_id]
297 | documents = [
298 | (self.corpus[doc_id], score) for doc_id, score in result.items()
299 | ]
300 | messages = prepare_messages(query, documents)
301 | messages_dict[query_id] = messages
302 |
303 | logger.info("Successfully prepared generation inputs for all queries.")
304 | return messages_dict
305 |
306 | def save_results(self, top_k: int = 10, output_dir: Optional[str] = None) -> None:
307 | """
308 |         Saves the top retrieval or reranking results to a CSV file and the generated results to a JSONL file.
309 |
310 | Args:
311 | top_k (`int`, defaults to `10`):
312 | The number of top results to save for each query.
313 | output_dir (`Optional[str]`, defaults to `None`):
314 | The directory where the results should be saved. If not provided, results are not saved.
315 |
316 | Saves:
317 | - Top `top_k` retrieval or reranked results in CSV format.
318 | - Generated responses in JSONL format.
319 | """
320 | # If no output directory is provided, stop saving.
321 | if output_dir is None:
322 | return
323 | # Create the output directory if it does not exist
324 | output_dir = os.path.join(output_dir, self.metadata.name)
325 | os.makedirs(output_dir, exist_ok=True)
326 |
327 | logger.info(f"Output directory set to: {output_dir}")
328 |
329 | # Path to save the CSV file
330 | csv_file_path = os.path.join(output_dir, "results.csv")
331 | logger.info(f"Saving top {top_k} results to CSV file: {csv_file_path}")
332 |
333 | # Determine whether to use rerank results or retrieve results
334 | final_result = (
335 | self.rerank_results
336 | if self.rerank_results is not None
337 | else self.retrieve_results
338 | )
339 |
340 | # Process the final result if it's not None
341 | if final_result is not None:
342 | with open(csv_file_path, mode="w", newline="") as csv_file:
343 | writer = csv.writer(csv_file)
344 | # Write the header to the CSV file
345 | writer.writerow(["query_id", "corpus_id"])
346 | logger.info("Writing header ['query_id', 'corpus_id'] to CSV.")
347 |
348 | # For each query_id, save the top_k corpus_ids sorted by score
349 | for q_id, doc_scores in final_result.items():
350 | # Sort doc_scores by score and select top_k documents
351 | sorted_docs = sorted(
352 | doc_scores.items(), key=lambda item: item[1], reverse=True
353 | )[:top_k]
354 |
355 | # Write the query_id and corpus_id to the CSV
356 | for doc_id, _ in sorted_docs:
357 | writer.writerow([q_id, doc_id])
358 |
359 | logger.info(f"Top {top_k} results saved successfully to {csv_file_path}")
360 |
361 | # Save generate_results to JSON Lines format
362 | if self.generate_results is not None:
363 | jsonl_file_path = os.path.join(output_dir, "output.jsonl")
364 | logger.info(f"Saving generate_results to JSONL file: {jsonl_file_path}")
365 |
366 | with open(jsonl_file_path, "w") as f:
367 | f.writelines(
368 | json.dumps({"query_id": q_id, "answer": answer}) + "\n"
369 | for q_id, answer in self.generate_results.items()
370 | )
371 |
372 | logger.info(f"generate_results saved successfully to {jsonl_file_path}")
373 |
374 | # adapted from https://github.com/beir-cellar/beir/blob/main/beir/retrieval/evaluation.py
375 | @staticmethod
376 | def evaluate(
377 | qrels: Dict[str, Dict[str, int]],
378 | results: Dict[str, Dict[str, float]],
379 | k_values: List[int],
380 | ignore_identical_ids: bool = True
381 | ) -> Tuple[Dict[str, float], Dict[str, float], Dict[str, float], Dict[str, float]]:
382 |         """Compute NDCG@k, MAP@k, Recall@k and P@k with pytrec_eval, averaged over all evaluated queries."""
383 | if ignore_identical_ids:
384 | logger.info(
385 | 'For evaluation, we ignore identical query and document ids (default), '
386 |                 'please set ``ignore_identical_ids=False`` explicitly to keep such pairs.')
387 | popped = []
388 | for qid, rels in results.items():
389 | for pid in list(rels):
390 | if qid == pid:
391 | results[qid].pop(pid) # remove identical query-document pairs
392 | popped.append(pid)
393 |
394 | # Filter results to only keep queries that are present in qrels
395 | filtered_results = {qid: rels for qid, rels in results.items() if qid in qrels}
396 |
397 | # Initialize dictionaries for evaluation metrics
398 | ndcg = {}
399 | _map = {}
400 | recall = {}
401 | precision = {}
402 |
403 | # Initialize metric values for each k in k_values
404 | for k in k_values:
405 | ndcg[f"NDCG@{k}"] = 0.0
406 | _map[f"MAP@{k}"] = 0.0
407 | recall[f"Recall@{k}"] = 0.0
408 | precision[f"P@{k}"] = 0.0
409 |
410 | # Define strings for pytrec_eval evaluation
411 | map_string = "map_cut." + ",".join([str(k) for k in k_values])
412 | ndcg_string = "ndcg_cut." + ",".join([str(k) for k in k_values])
413 | recall_string = "recall." + ",".join([str(k) for k in k_values])
414 | precision_string = "P." + ",".join([str(k) for k in k_values])
415 |
416 | # Perform evaluation using pytrec_eval with filtered results
417 | evaluator = pytrec_eval.RelevanceEvaluator(qrels,
418 | {map_string, ndcg_string, recall_string, precision_string})
419 | scores = evaluator.evaluate(filtered_results)
420 |
421 | # Aggregate the scores for each query and each k
422 | for query_id in scores.keys():
423 | for k in k_values:
424 | ndcg[f"NDCG@{k}"] += scores[query_id]["ndcg_cut_" + str(k)]
425 | _map[f"MAP@{k}"] += scores[query_id]["map_cut_" + str(k)]
426 | recall[f"Recall@{k}"] += scores[query_id]["recall_" + str(k)]
427 | precision[f"P@{k}"] += scores[query_id]["P_" + str(k)]
428 |
429 | # Compute the average scores for each k
430 | for k in k_values:
431 | ndcg[f"NDCG@{k}"] = round(ndcg[f"NDCG@{k}"] / len(scores), 5)
432 | _map[f"MAP@{k}"] = round(_map[f"MAP@{k}"] / len(scores), 5)
433 | recall[f"Recall@{k}"] = round(recall[f"Recall@{k}"] / len(scores), 5)
434 | precision[f"P@{k}"] = round(precision[f"P@{k}"] / len(scores), 5)
435 |
436 | # Log the results for each metric
437 | for _eval in [ndcg, _map, recall, precision]:
438 | logger.info("\n")
439 | for k in _eval.keys():
440 | logger.info("{}: {:.4f}".format(k, _eval[k]))
441 |
442 | return ndcg, _map, recall, precision
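
# Editor's addition (not part of the original file): a minimal, self-contained check of
# `BaseTask.evaluate` with toy relevance judgements and scores. The query/document ids
# and scores are invented purely for illustration; running it assumes pytrec_eval is
# installed and the financerag package is importable.
if __name__ == "__main__":
    toy_qrels = {"q1": {"d1": 1, "d2": 0, "d3": 1}}           # graded relevance judgements
    toy_results = {"q1": {"d1": 0.9, "d2": 0.8, "d3": 0.1}}   # retrieval scores to evaluate
    ndcg, _map, recall, precision = BaseTask.evaluate(toy_qrels, toy_results, k_values=[1, 3])
    print(ndcg, _map, recall, precision)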
--------------------------------------------------------------------------------