├── factscore
│   ├── __init__.py
│   ├── abstain_detection.py
│   ├── lm.py
│   ├── clm.py
│   ├── openai_lm.py
│   ├── utils.py
│   ├── download_data.py
│   ├── npm.py
│   ├── retrieval.py
│   └── atomic_facts.py
├── massw
│   ├── models
│   │   ├── __init__.py
│   │   └── mixtral_azure.py
│   ├── __init__.py
│   ├── collection
│   │   ├── __init__.py
│   │   ├── neurips
│   │   │   ├── README.md
│   │   │   └── script.py
│   │   ├── cvpr
│   │   │   ├── README.md
│   │   │   └── script.py
│   │   ├── acl
│   │   │   ├── README.md
│   │   │   └── script.py
│   │   ├── aaai
│   │   │   ├── README.md
│   │   │   └── script.py
│   │   ├── README.md
│   │   └── base.py
│   ├── llm_sim_eval.txt
│   ├── download.py
│   ├── data.py
│   └── metrics.py
├── assets
│   └── context_topic_labels.png
├── requirements.txt
├── setup.py
├── .pylintrc
├── notebooks
│   ├── evaluation_results
│   │   ├── gpt35_quality.csv
│   │   ├── gpt4_quality.csv
│   │   ├── human_agreement.csv
│   │   └── mistral_quality.csv
│   ├── prompt.txt
│   └── generate_table.ipynb
├── .github
│   ├── workflows
│   │   └── pylint.yml
│   └── PULL_REQUEST_TEMPLATE.md
├── benchmark
│   └── aspect_prediction
│       ├── utils.py
│       ├── README.md
│       ├── eval.py
│       ├── prompts.py
│       └── task.py
├── .gitignore
├── prompts
│   ├── method.json
│   ├── future.json
│   ├── outcome.json
│   ├── key_idea.json
│   └── context.json
├── factscore.ipynb
├── LICENSE
└── data
    ├── few_shot_examples.json
    └── cot_examples.json

--------------------------------------------------------------------------------
/factscore/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/massw/models/__init__.py:
--------------------------------------------------------------------------------
"""Init."""

--------------------------------------------------------------------------------
/massw/__init__.py:
--------------------------------------------------------------------------------
"""Init file for massw package."""

--------------------------------------------------------------------------------
/assets/context_topic_labels.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xingjian-zhang/massw/HEAD/assets/context_topic_labels.png

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
bert-score
evaluate
jsonlines
numpy
openai
pandas
rouge_score
sentence_transformers
tqdm
wget

--------------------------------------------------------------------------------
/massw/collection/__init__.py:
--------------------------------------------------------------------------------
"""Collection of papers from academic conferences."""

from .base import BaseCollection, Paper

__all__ = ["BaseCollection", "Paper"]

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
"""Setup script for massw."""
from setuptools import setup, find_packages

setup(
    name='massw',
    version='0.1.0',
    packages=find_packages(),
)

--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
[MESSAGES CONTROL]
disable=wrong-import-position,import-error,wrong-import-order,protected-access,too-many-instance-attributes,too-many-arguments,broad-exception-caught,attribute-defined-outside-init,dangerous-default-value,invalid-name
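# These suppressions apply repo-wide; CI runs pylint against this file via
# .github/workflows/pylint.yml.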
--------------------------------------------------------------------------------
/notebooks/evaluation_results/gpt35_quality.csv:
--------------------------------------------------------------------------------
,context,key_idea,method,outcome,projected_impact
Cosine Embedding,0.934,0.936,0.895,0.928,0.876
BLEURT,0.597,0.575,0.510,0.608,0.498
BERTScore-f1,0.934,0.927,0.910,0.934,0.905
BLEU,0.524,0.439,0.197,0.452,0.170
ROUGE-1,0.635,0.582,0.445,0.626,0.371

--------------------------------------------------------------------------------
/notebooks/evaluation_results/gpt4_quality.csv:
--------------------------------------------------------------------------------
,context,key_idea,method,outcome,projected_impact
Cosine Embedding,0.940,0.944,0.894,0.931,0.916
BLEURT,0.607,0.582,0.510,0.603,0.611
BERTScore-f1,0.934,0.928,0.908,0.933,0.933
BLEU,0.384,0.375,0.197,0.355,0.282
ROUGE-1,0.604,0.572,0.450,0.596,0.563

--------------------------------------------------------------------------------
/notebooks/evaluation_results/human_agreement.csv:
--------------------------------------------------------------------------------
,context,key_idea,method,outcome,projected_impact
Cosine Embedding,0.935,0.944,0.900,0.936,0.941
BLEURT,0.656,0.618,0.559,0.671,0.742
BERTScore-f1,0.942,0.938,0.924,0.950,0.955
BLEU,0.594,0.464,0.357,0.608,0.642
ROUGE-1,0.703,0.637,0.540,0.737,0.748

--------------------------------------------------------------------------------
/notebooks/evaluation_results/mistral_quality.csv:
--------------------------------------------------------------------------------
,context,key_idea,method,outcome,projected_impact
Cosine Embedding,0.944,0.949,0.905,0.933,0.917
BLEURT,0.645,0.636,0.554,0.674,0.635
BERTScore-f1,0.946,0.943,0.920,0.948,0.936
BLEU,0.590,0.556,0.295,0.665,0.384
ROUGE-1,0.693,0.662,0.509,0.707,0.599

--------------------------------------------------------------------------------
/massw/collection/neurips/README.md:
--------------------------------------------------------------------------------
# NeurIPS Collection

1. Download the raw metadata from `https://neurips.cc/Downloads/{year}` in `tsv` format. Only `Poster` type papers are included.
2. Run the script to collect the paper PDF URLs. (~30 minutes)
3. Run the script to download the PDFs. (~30 minutes)

Statistics:

--------------------------------------------------------------------------------
/massw/collection/cvpr/README.md:
--------------------------------------------------------------------------------
# CVPR Collection

1. Metadata Crawling:
The script extracts raw metadata from the CVPR website `https://openaccess.thecvf.com/CVPR{year}?day=all` and saves the data in TSV format. The page lists about 2,716 papers.

2. Collecting PDF URLs:
Run the script to extract and collect the URLs of the paper PDFs. This process takes approximately 1 hour and 30 minutes.

3. Downloading PDFs:
After collecting the URLs, run the download_pdfs() function to fetch all the PDFs. This step is estimated to take around 10 minutes.

# Statistics

Our script retrieves metadata for a total of 2,716 papers from the CVPR 2024 conference.
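A minimal invocation sketch follows. The `CVPRCollection` class name is an assumption (not confirmed by this README); it presumes `script.py` follows the same `BaseCollection` pattern as the ACL and NeurIPS collections in this repository.

```python
# Sketch only: assumes script.py defines CVPRCollection(BaseCollection),
# mirroring ACLCollection and NeurIPSCollection elsewhere in this repo.
from massw.collection.cvpr.script import CVPRCollection

collection = CVPRCollection(year=2024)
collection.collect()        # crawl metadata and write data_2024/metadata.tsv
collection.download_pdfs()  # fetch the PDFs into data_2024/pdf/
```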

--------------------------------------------------------------------------------
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
name: Pylint

on: [pull_request, workflow_dispatch]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8"]
    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install numpy torch scipy
        pip install --force-reinstall pylint==2.17.5
    - name: Analysing the code with pylint
      run: |
        pylint massw/ --rcfile .pylintrc --recursive y

--------------------------------------------------------------------------------
/massw/collection/acl/README.md:
--------------------------------------------------------------------------------
# ACL Collection

1. Metadata Crawling:
The script extracts raw metadata from the ACL website `https://aclanthology.org/events/acl-{year}/` and saves the data in TSV format. We only collect metadata for the 941 long and short papers presented at the ACL conference.

2. Collecting PDF URLs:
Run the script to extract and collect the URLs of the paper PDFs. This process takes approximately 50 minutes.

3. Downloading PDFs:
After collecting the URLs, run the download_pdfs() function to fetch all the PDFs. This step is estimated to take around 10 minutes.

# Statistics

Our script retrieves metadata for a total of 941 papers from the ACL 2024 conference, including both long and short papers. This figure is consistent with the official records available on the ACL website, ensuring the accuracy of the collected data.

--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------

## Description


## Related Issue




## Motivation and Context


## How Has This Been Tested?



## Screenshots (if appropriate):

--------------------------------------------------------------------------------
/massw/llm_sim_eval.txt:
--------------------------------------------------------------------------------
You are an expert in Computer Science with a specialization in text analysis.
You are tasked with evaluating the semantic similarity between pairs of
abstractive summarizations from scientific papers. Each summarization pertains
to the same aspect (Context, Key Idea, Method, Outcome, or Projected Impact)
of the same paper.

For each pair of summarizations, classify the aspect, and assign a similarity
score on a scale from 1 to 10, where 1 indicates completely dissimilar and 10
indicates identical content. Before scoring, include a brief justification for
your score. You should output your results in JSON format as shown in the example.

Example Input:
Input 1: The experiments demonstrated a 20% increase in efficiency, confirming the proposed model's effectiveness.
Input 2: Results show that the new model outperforms existing ones by improving efficiency by approximately 20%.

Example JSON Output:
{
    "aspect": "Outcome",
    "score": 9,
    "justification": "Both texts describe similar measurable improvements in efficiency, closely aligning in their depiction of the model's effectiveness."
}

--------------------------------------------------------------------------------
/massw/collection/aaai/README.md:
--------------------------------------------------------------------------------
# AAAI Collection

1. Metadata Crawling:
The script crawls raw metadata from the AAAI archive website `https://ojs.aaai.org/index.php/AAAI/issue/archive` and saves the data in TSV format. This archive contains all the AAAI conference papers. For example, the archive currently lists 21 technical tracks in 2024. If the conference adds more papers or if a year’s papers are split across multiple pages, the script will adapt by checking URLs formatted like `https://ojs.aaai.org/index.php/AAAI/issue/archive/{index}` and gathering all relevant pages.

2. Collecting PDF URLs:
Run the script to extract and collect the URLs of the paper PDFs. This process takes approximately 1 hour and 20 minutes.

3. Downloading PDFs:
After collecting the URLs, run the download_pdfs() function to fetch all the PDFs. This step is estimated to take around 1 hour and 30 minutes.

# Statistics

Our script retrieves metadata for 2,865 papers from the AAAI-2024 conference. This number aligns closely with external sources, some of which report figures like 2,867 papers (`https://dblp.org/db/conf/aaai/aaai2024.html`), and it confirms that the conference accepted over 2,500 papers overall (`https://www.paperdigest.org/2024/02/aaai-2024-papers-highlights/`). This comprehensive metadata collection underscores the scale and breadth of contributions presented at AAAI-2024.

--------------------------------------------------------------------------------
/massw/download.py:
--------------------------------------------------------------------------------
"""Data loading and processing utilities."""
import os
import sys

import wget

# Setting the project directory relative to this script's location
PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


def download_dataset(version="v1"):
    """Download the dataset from remote storage."""
    # Implicit string concatenation keeps the long URLs intact; a backslash
    # line continuation inside the literal would embed stray whitespace.
    urls = {
        "v1": {
            "massw_metadata_v1.jsonl":
            "https://www.dropbox.com/scl/fi/r2jlil9lj0ypo2fpl3fxa/"
            "massw_metadata_v1.jsonl?rlkey=ohnriak63x4ekyli25naajp0q&dl=1",
            "massw_v1.tsv":
            "https://www.dropbox.com/scl/fi/ykkrpf269fikuchy429l7/"
            "massw_v1.tsv?rlkey=mssrbgz3k8adij1moxqtj34ie&dl=1",
        }
    }
    try:
        files = urls[version]
    except KeyError as e:
        raise ValueError(
            f"Invalid version: {version}. "
            f"Choose from {list(urls.keys())}") from e
    for filename, url in files.items():
        print(f"Downloading {filename}...")
        # Constructing the output path
        out_path = os.path.join(PROJECT_DIR, "data", filename)
        if os.path.exists(out_path):
            print(f"{filename} already exists. Skipping download.")
            continue
        wget.download(url, out=out_path, bar=bar_progress)


def bar_progress(current, total, width=80):
    """Display a progress bar for the download."""
    progress_message = (f"Downloading: {current / total * 100:.0f}% "
                        f"[{current} / {total}] bytes")
    # Don't use print() as it would print on a new line every time.
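    # "\r" in the write below returns the cursor to the start of the line,
    # so the progress bar redraws in place instead of scrolling.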
    width = min(width, 100)
    sys.stdout.write("\r" + progress_message)
    sys.stdout.flush()


if __name__ == "__main__":
    download_dataset()

--------------------------------------------------------------------------------
/factscore/abstain_detection.py:
--------------------------------------------------------------------------------
import re

import numpy as np

invalid_ppl_mentions = [
    "I could not find any information",
    "The search results do not provide",
    "There is no information",
    "There are no search results",
    "there are no provided search results",
    "not provided in the search results",
    "is not mentioned in the provided search results",
    "There seems to be a mistake in the question",
    "Not sources found",
    "No sources found",
    "Try a more general question"
]


def remove_citation(text):
    """Strip inline citation markers such as [1] and repair a known artifact."""
    text = re.sub(r"\s*\[\d+\]\s*", "", text)
    if text.startswith("According to , "):
        text = text.replace("According to , ",
                            "According to the search results, ")
    return text


def is_invalid_ppl(text):
    """Check whether the response opens with a known abstention phrase."""
    return np.any([text.lower().startswith(mention.lower())
                   for mention in invalid_ppl_mentions])


def is_invalid_paragraph_ppl(text):
    """Check whether a paragraph is empty or contains an abstention phrase."""
    return len(text.strip()) == 0 or np.any(
        [mention.lower() in text.lower() for mention in invalid_ppl_mentions])


def perplexity_ai_abstain_detect(generation):
    """Detect abstention in a Perplexity AI response."""
    output = remove_citation(generation)
    if is_invalid_ppl(output):
        return True
    valid_paras = []
    for para in output.split("\n\n"):
        if is_invalid_paragraph_ppl(para):
            break
        valid_paras.append(para.strip())
    return len(valid_paras) == 0


def generic_abstain_detect(generation):
    """Detect abstention using generic refusal markers."""
    return generation.startswith("I'm sorry") or "provide more" in generation


def is_response_abstained(generation, fn_type):
    """Dispatch abstention detection based on the detector type."""
    if fn_type == "perplexity_ai":
        return perplexity_ai_abstain_detect(generation)
    if fn_type == "generic":
        return generic_abstain_detect(generation)
    return False

--------------------------------------------------------------------------------
/massw/collection/README.md:
--------------------------------------------------------------------------------
# Conference Paper Collection

This folder contains the scripts and instructions for collecting conference paper metadata and full text.

## Supported conferences

- ✅ - Supported
- ❌ - Unsupported
- 🚧 - Work in Progress

| Conference | Metadata | PDF | Parsed Full Text |
|:----------:|:--------:|:---:|:----------------:|
| AAAI       | ✅       | ✅  | 🚧               |
| ACL        | ✅       | ✅  | 🚧               |
| CHI        | 🚧       | 🚧  | 🚧               |
| CVPR       | ✅       | ✅  | 🚧               |
| ECCV       | 🚧       | 🚧  | 🚧               |
| EMNLP      | 🚧       | 🚧  | 🚧               |
| ICCV       | 🚧       | 🚧  | 🚧               |
| ICLR       | 🚧       | 🚧  | 🚧               |
| ICML       | 🚧       | 🚧  | 🚧               |
| IJCAI      | 🚧       | 🚧  | 🚧               |
| KDD        | 🚧       | 🚧  | 🚧               |
| NAACL      | 🚧       | 🚧  | 🚧               |
| NeurIPS    | ✅       | ✅  | 🚧               |
| SIGIR      | 🚧       | 🚧  | 🚧               |
| SIGMOD     | 🚧       | 🚧  | 🚧               |
| VLDB       | 🚧       | 🚧  | 🚧               |
| WWW        | 🚧       | 🚧  | 🚧               |

## Folder Structure for each conference

Each conference has its own folder.
In the folder, there are:
- a `README.md` file for instructions
- a `script.py` file for scraping the metadata and full text
- multiple `data_{year}` folders (e.g., `data_2024`) for storing the metadata and full text
  - `metadata.tsv`: Metadata for the conference papers.
  - `pdf`: Folder for storing the PDF files.
    - `{pid}.pdf`: PDF file for the conference paper.
  - `parsed_full_text`: Folder for storing the parsed full text.
    - `{pid}.txt`: Parsed full text for the conference paper.

```
massw
├── collection
    ├── aaai
        ├── README.md
        ├── script.py
        ├── data_2024
        │   ├── metadata.tsv
        │   ├── pdf
        │   └── parsed_full_text
        ├── data_2025
        │   ├── metadata.tsv
        │   ├── pdf
        │   └── parsed_full_text
```

--------------------------------------------------------------------------------
/factscore/lm.py:
--------------------------------------------------------------------------------
import os
import pickle
import time


class LM(object):

    def __init__(self, cache_file):
        self.cache_file = cache_file
        self.cache_dict = self.load_cache()
        self.model = None
        self.add_n = 0

    def load_model(self):
        # Load the model and assign it to self.model.
        raise NotImplementedError()

    def generate(self, prompt, sample_idx=0, max_sequence_length=2048, max_output_length=128):
        prompt = prompt.strip()  # it's important not to end with a whitespace
        cache_key = f"{prompt}_{sample_idx}"

        if cache_key in self.cache_dict:
            return self.cache_dict[cache_key]

        if self.model is None:
            self.load_model()

        if prompt.endswith(" True or False?\nAnswer:"):
            generated = self._generate(prompt, max_sequence_length=max_sequence_length, max_output_length=1)
        else:
            generated = self._generate(prompt, max_sequence_length=max_sequence_length, max_output_length=max_output_length)

        self.cache_dict[cache_key] = generated
        self.add_n += 1
        return generated

    def save_cache(self):
        if self.add_n == 0:
            return

        # Reload the latest cache first: if other processes were running in
        # parallel, the cache file may have been updated since it was loaded.
        for k, v in self.load_cache().items():
            self.cache_dict[k] = v

        with open(self.cache_file, "wb") as f:
            pickle.dump(self.cache_dict, f)

    def load_cache(self, allow_retry=True):
        if os.path.exists(self.cache_file):
            while True:
                try:
                    with open(self.cache_file, "rb") as f:
                        cache = pickle.load(f)
                    break
                except Exception:
                    if not allow_retry:
                        raise
                    print("Pickle Error: Retry in 5sec...")
                    time.sleep(5)
        else:
            cache = {}
        return cache

--------------------------------------------------------------------------------
/notebooks/prompt.txt:
--------------------------------------------------------------------------------
## Instructions
You are an expert in computer science.
Your task is to summarize the following five aspects of the papers given the
definitions below.

## Definitions of Aspects
### Context
- The status quo of related literature or reality which motivated this study.
This could normally be a problem, a research question, or a research gap that
has not been successfully addressed by previous work.
- Anything that happened before this study.
### Key Idea
- The main intellectual merit of this paper, often in comparison to the
context.
This could normally be a novel idea or solution proposed in this paper
that distinguishes it from what’s already done in literature.
- Proposed in this study.
### Method (Validation Methodology)
- The specific experiment or proof that investigates and validates the key idea.
- CS papers often refer to "Method" as an algorithm or a model, but our definition here is **different**.
- Performed in this study.
### Outcome
- The factual statement about the study output. This could be the experiment
results and any other measurable outcome that has occurred. It marks whether
the key hypothesis is verified or not.
- Produced in this study.
### Future Impact
- The impact of the work on the field explicitly anticipated by the authors,
and potential further research explicitly identified by the authors that may
improve or extend this study.

## Notes
- If an aspect is NOT mentioned in the abstract, mark it as “N/A” (not
applicable). DO NOT come up with your own interpretation.
- Each aspect should be summarized in 1-2 sentences in most cases.
- Each aspect should be self-contained and should not contain references
to other aspects (cross-references).
- Including specific names of proposed models, datasets, etc., in the summary is acceptable.
- If the problem definition is novel (e.g., proposing a new task), classify it as a Key Idea.
- Non-measurable outcomes should be categorized as Future Impact.
- Impacts that have already occurred should be considered as Outcome.
- A new observation that motivates the proposal of a key idea should be classified under Key Idea.
- Future Impact should not account for realized impacts, such as the number of
citations a paper has received.

## Response Format
The response should be a JSON object in the following format:
```json
{
    "Context": "...",
    "Key Idea": "...",
    "Method": "...",
    "Outcome": "...",
    "Future Impact": "..."
}
```

--------------------------------------------------------------------------------
/benchmark/aspect_prediction/utils.py:
--------------------------------------------------------------------------------
"""This script includes utility functions for benchmarking scripts."""
import json
import os
import ssl

MODEL_CHOICES = ["gpt-35-turbo", "gpt-4", "mixtral-8x7b"]
PROMPT_CHOICES = ["zero-shot", "few-shot", "chain-of-thought", "few-shot-cot"]


def allow_self_signed_https(allowed):
    """Config SSL settings to allow self-signed certificates."""
    if allowed and not os.environ.get('PYTHONHTTPSVERIFY', '') and getattr(
            ssl, '_create_unverified_context', None):
        ssl._create_default_https_context = ssl._create_unverified_context


def load_examples():
    """Load few-shot and chain-of-thought examples from files."""
    with open("data/few_shot_examples.json", "r", encoding="utf-8") as f:
        few_shot_examples = json.load(f)
    with open("data/cot_examples.json", "r", encoding="utf-8") as f:
        cot_examples = json.load(f)
    return few_shot_examples, cot_examples


def save_results(chat_results, output_dir, task_name):
    """
    Save the chat results to a TSV file.

    Args:
        chat_results (DataFrame): DataFrame containing the chat results.
        output_dir (str): Directory path to save the result files.
        task_name (str): Name of the task, used to name the output file.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Serializing dictionary entries in 'output' column, if present
    if 'output' in chat_results.columns:
        chat_results['output'] = chat_results['output'].apply(
            lambda x: json.dumps(x) if isinstance(x, dict) else x
        )

    output_path = os.path.join(output_dir, f"{task_name}.tsv")
    chat_results.to_csv(output_path, sep="\t", index=False)


TASK_NAMES = [
    "idea_generation",
    "method_recommendation",
    "outcome_prediction",
    "future_work_recommendation",
    "title_prediction",
]

TASK2GT = {
    "idea_generation": "key_idea",
    "method_recommendation": "method",
    "outcome_prediction": "outcome",
    "future_work_recommendation": "future_impact",
    "title_prediction": "title",
}


def postprocess_cot(output: str):
    """
    Extract the actual prediction from the output string.

    Args:
        output (str): The output string containing the prediction.

    Returns:
        str: The extracted prediction or the original output
            if no marker is found.
    """
    marker_index = output.find("Prediction:")
    if marker_index != -1:
        actual_prediction = output[marker_index + len("Prediction:"):].strip()
        return actual_prediction

    return output

--------------------------------------------------------------------------------
/factscore/clm.py:
--------------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import math
import time
import json
import numpy as np
import torch
from tqdm import tqdm
from collections import defaultdict

from transformers import AutoModelForCausalLM
from transformers import LlamaTokenizer

from factscore.utils import convert_model_to_int8_on_gpu
from factscore.lm import LM


class CLM(LM):
    def __init__(self, model_name, model_dir, cache_file=None):
        self.model_name = model_name
        self.model_dir = model_dir
        if cache_file:
            super().__init__(cache_file)

    def load_model(self):
        self.model = AutoModelForCausalLM.from_pretrained(self.model_dir)
        self.model = convert_model_to_int8_on_gpu(self.model, device='cuda')
        self.tokenizer = LlamaTokenizer.from_pretrained(self.model_dir)

    def _generate(self, prompts, max_sequence_length=2048, max_output_length=128,
                  end_if_newline=False, end_if_second_newline=False, verbose=False):
        is_single = type(prompts) == str
        if is_single:
            prompts = [prompts]

        input_ids = self.tokenizer(prompts).input_ids
        if verbose:
            input_ids = tqdm(input_ids)

        generations = []
        scores = []
        for curr_input_ids in input_ids:
            if len(curr_input_ids) > max_sequence_length - max_output_length:
                curr_input_ids = curr_input_ids[-(max_sequence_length - max_output_length):]
            curr_input_ids = torch.LongTensor([curr_input_ids]).cuda()
            gen_outputs = self.model.generate(
                curr_input_ids,
                max_length=curr_input_ids.shape[1] + max_output_length,
                return_dict_in_generate=True,
                output_scores=True
            )
            gen_tokens = gen_outputs["sequences"]
            # saving the logits for the very first token
            gen_scores = gen_outputs["scores"][0][0].detach().cpu().numpy()
            gen = self.tokenizer.decode(gen_tokens[0, curr_input_ids.shape[-1]:])

            if end_if_newline:
                gen = gen.split("\n")[0].strip()
            elif end_if_second_newline:
                gen = "\n".join(gen.split("\n")[:2]).strip()

            if verbose and len(generations) == 0:
                print("Input:", prompts[0])
                print("Prediction:", gen)

            if self.model_name.startswith("llama-sni"):
                # Truncate at the end-of-sequence token.
                gen = gen.split("</s>")[0]

            generations.append(gen)
            scores.append(gen_scores)

        assert len(generations) == len(prompts) == len(scores)
        if is_single:
            return generations[0], scores[0]

        return generations, scores

--------------------------------------------------------------------------------
/massw/collection/neurips/script.py:
--------------------------------------------------------------------------------
import os
from joblib import Memory

import pandas as pd
import requests
from bs4 import BeautifulSoup
from tqdm import tqdm

from massw.collection import BaseCollection

memory = Memory(os.path.join(os.path.dirname(__file__), "cache"), verbose=0)


@memory.cache
def get_paper_pdf_url(base_url: str, year: int) -> str:
    """Get the PDF URL for a NeurIPS paper from its virtual site URL.

    Args:
        base_url (str): The URL of the paper's virtual site page on neurips.cc
        year (int): The conference year of the paper

    Returns:
        str: The direct URL to the paper's PDF file, or None if the PDF link cannot be found

    Example:
        >>> url = get_paper_pdf_url("https://neurips.cc/virtual/2024/poster/72031", 2024)
        >>> print(url)
        'https://proceedings.neurips.cc/paper_files/paper/2024/file/abc123-Paper-Conference.pdf'
    """
    # Request the HTML content of the base URL
    response = requests.get(base_url, timeout=10)
    if response.status_code != 200:
        return None

    # Parse the HTML content
    soup = BeautifulSoup(response.text, "html.parser")

    # Find the PDF link - typically a link with 'PDF' or 'Paper' text
    pdf_link = soup.find("a", {"class": "btn", "title": "Paper"})

    if pdf_link and "href" in pdf_link.attrs:
        # Extract the hash from the HTML link
        html_url = pdf_link["href"]
        # Convert HTML URL to PDF URL format
        if "hash" in html_url:
            # Extract the hash part and construct the PDF URL
            hash_part = html_url.split("/")[-1].split("-")[0]
            return f"https://proceedings.neurips.cc/paper_files/paper/{year}/file/{hash_part}-Paper-Conference.pdf"

    # If we couldn't find or parse the PDF link, return None
    return None


class NeurIPSCollection(BaseCollection):
    def __init__(self, year: int):
        super().__init__(year, "neurips")
        self.data_dir = os.path.join(os.path.dirname(__file__), f"data_{year}")

    def collect(self):
        raw_path = os.path.join(self.data_dir, "raw.tsv")
        if not os.path.exists(raw_path):
            raise FileNotFoundError(
                f"Raw data file not found at {raw_path}. Please see the README for instructions."
            )

        raw_df = pd.read_csv(raw_path, sep="\t")
        raw_df = raw_df[raw_df["type"] == "Poster"]
        bar = tqdm(raw_df.iterrows(), total=len(raw_df))
        success_count = 0
        for _, row in bar:
            pdf_url = get_paper_pdf_url(row["virtualsite_url"], self.year)
            self.add_paper(
                title=row["name"],
                authors=row["speakers/authors"].split(", "),
                abstract=row["abstract"],
                pdf_url=pdf_url,
            )
            if pdf_url:
                success_count += 1
            bar.set_description(
                f"Collected {success_count}/{len(raw_df)} paper PDF URLs."
            )
        self.save_metadata()


if __name__ == "__main__":
    collection = NeurIPSCollection(year=2024)
    if not os.path.exists(os.path.join(collection.data_dir, "metadata.tsv")):
        collection.collect()
    collection.download_pdfs()

--------------------------------------------------------------------------------
/massw/models/mixtral_azure.py:
--------------------------------------------------------------------------------
"""API module for interacting with the Mixtral model."""
import json
import os
import time
import urllib.error
import urllib.request
from typing import Dict

import pandas as pd


def prompts_to_raw_output(messages):
    """Process prompts using the specified Mixtral model endpoint."""
    final_results = pd.DataFrame(columns=['pid', 'output'])

    url = os.environ.get("MISTRAL_API_URL")
    api_key = os.environ.get("MISTRAL_API_KEY")
    print(f"Using Mixtral API at {url}")
    # Avoid echoing the raw secret to stdout; only report whether it is set.
    print(f"Mixtral API key configured: {bool(api_key)}")
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {api_key}',
        'azureml-model-deployment': 'mixtralai-mixtral-8x7b-instru-7'
    }

    for pid, msg in messages:
        response_df = handle_mixtral_model(url, headers, msg, {"pid": pid})
        final_results = pd.concat([final_results,
                                   response_df], ignore_index=True)

    return final_results


def raw_output_to_dict(output_path: str) -> Dict[str, str]:
    """
    Load and convert raw output from the Mixtral model into a dictionary.

    Args:
        output_path (str): Path to the model output CSV file.

    Returns:
        Dict[str, str]: A dictionary mapping pid to processed output.
    """
    output_dict = {}
    task_output = pd.read_csv(output_path, sep="\t",
                              converters={'result': lambda x: json.loads(x)
                                          if x else None})
    for _, row in task_output.iterrows():
        output = row["output"]
        output_dict[row['pid']] = output
    return output_dict


def handle_mixtral_model(url, headers, messages, entry):
    """Handle the Mixtral model API request."""
    output_df = pd.DataFrame(columns=['pid', 'output'])
    max_retries = 5
    retries = 0
    while retries < max_retries:
        try:
            data = {
                "input_data": {
                    "input_string": messages,
                    "parameters": {
                        "temperature": 0,
                        "top_p": 0.9,
                        "do_sample": True,
                        "max_new_tokens": 200,
                        "return_full_text": True
                    }
                }
            }
            body = str.encode(json.dumps(data))
            req = urllib.request.Request(url, body, headers)
            print(f"{req = }")
            with urllib.request.urlopen(req) as response:
                result_json = json.loads(response.read())
                # DataFrame.append was removed in pandas 2.0; build the new
                # row with pd.concat instead.
                new_row = pd.DataFrame([{"pid": entry["pid"],
                                         "output": result_json}])
                output_df = pd.concat([output_df, new_row],
                                      ignore_index=True)
            break
        except urllib.error.HTTPError as error:
            print(f"The request failed with status code: {error.code}")
            retries += 1
            time.sleep(2)
        # The Mixtral endpoint occasionally raises other, unpredictable
        # exceptions, so a broad catch is used below (see the
        # broad-exception-caught suppression in .pylintrc).
        except Exception as e:
            print(f"An unexpected error occurred: {str(e)}")
            retries += 1
            time.sleep(2)

    return output_df

--------------------------------------------------------------------------------
/benchmark/aspect_prediction/README.md:
--------------------------------------------------------------------------------
# Benchmark

## Reproduce Results in the Paper

To reproduce the benchmark results across different models and prompt types, run

```bash
python benchmark/aspect_prediction/task.py --model <model> --prompt <prompt> --num_samples 1020
```

where:

- `<model>` is chosen from `gpt-35-turbo`, `gpt-4`, `mixtral-8x7b`.
- `<prompt>` is chosen from `zero-shot`, `few-shot`, `chain-of-thought`, `few-shot-cot`.

> We provide the benchmark output through a Dropbox link
> [here](https://www.dropbox.com/scl/fi/nap87vh9s2mc7v3daql5u/results_v1.zip?rlkey=m1n5vck90quwhqygiq1otn2zp&dl=0).
> You could download the results and unzip them to the
> `benchmark/aspect_prediction/outputs` directory through:
>
> ```bash
> wget "https://www.dropbox.com/scl/fi/nap87vh9s2mc7v3daql5u/results_v1.zip?rlkey=m1n5vck90quwhqygiq1otn2zp&dl=1" -O results_v1.zip
> unzip results_v1.zip -d benchmark/aspect_prediction
> rm results_v1.zip
> mv benchmark/aspect_prediction/results benchmark/aspect_prediction/outputs
> ```

After running the tasks, evaluate the outcomes by running:

```bash
python benchmark/aspect_prediction/eval.py --model_output_dir benchmark/aspect_prediction/outputs/gpt-35-turbo_zero-shot
```

---

## Adding a Custom Model API to MASSW

To extend the functionality of MASSW by adding custom model scripts for additional models, follow these guidelines. This will allow your model to integrate seamlessly with the existing framework used for aspect prediction and evaluation.

#### 1. **Location for API Scripts**

Place your custom model scripts in the `massw/models` directory. They should be similar in structure and design to the existing scripts:

- `massw/models/gpt_azure.py`
- `massw/models/mixtral_azure.py`

#### 2. **Required Functions**

Each API script must include two essential functions (a minimal end-to-end sketch follows after the two signatures):

- **`prompts_to_raw_output`**: This function processes prompts and generates raw outputs.

```python
def prompts_to_raw_output(messages: List[Tuple[str, str]], **other_arguments) -> pd.DataFrame:
    """
    Process prompts to generate raw outputs.

    Parameters:
    - messages (List[Tuple[str, str]]): A list of (pid, message) tuples, where
      'pid' is the paper ID and 'message' is the text of the conversation or prompt.

    Returns:
    - pd.DataFrame: A DataFrame containing the processed outputs with paper IDs.
    """
    pass
```

- **`raw_output_to_dict`**: This function parses raw outputs into a dictionary format.

```python
def raw_output_to_dict(output_path: str) -> Dict[str, str]:
    """
    Convert raw outputs into a dictionary mapping from paper ID to output.

    Parameters:
    - output_path (str): The file path to the output directory where the results are stored.

    Returns:
    - Dict[str, str]: A dictionary mapping each paper ID to its corresponding output.
    """
    pass
```
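For orientation, here is a minimal sketch of such a module. All names are illustrative (a hypothetical `echo_model.py` that returns each prompt unchanged instead of calling a real endpoint); it only demonstrates the two required signatures.

```python
"""Hypothetical echo model: a stand-in for a real model API script."""
from typing import Dict, List, Tuple

import pandas as pd


def prompts_to_raw_output(messages: List[Tuple[str, str]]) -> pd.DataFrame:
    """Echo each prompt back as its own 'output' (no real API call)."""
    rows = [{"pid": pid, "output": message} for pid, message in messages]
    return pd.DataFrame(rows, columns=["pid", "output"])


def raw_output_to_dict(output_path: str) -> Dict[str, str]:
    """Read a TSV produced by save_results() and map pid -> output."""
    task_output = pd.read_csv(output_path, sep="\t")
    return {row["pid"]: row["output"] for _, row in task_output.iterrows()}
```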
#### 3. **Modify the Task Processing Function**

Update the `process_task` function in `benchmark/aspect_prediction/task.py` to
handle your custom model. Additionally, adapt the `postprocess_output` function
in `benchmark/aspect_prediction/eval.py` to support the evaluation of your
model's outputs.

--------------------------------------------------------------------------------
/notebooks/generate_table.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "def load_data(file_path):\n",
    "    # Load CSV file and set the first column as the index\n",
    "    return pd.read_csv(file_path, index_col=0)\n",
    "\n",
    "# Map the display names to the column names used in the evaluation CSVs.\n",
    "ASPECT_COLUMNS = {\n",
    "    'Context': 'context',\n",
    "    'Key Idea': 'key_idea',\n",
    "    'Method': 'method',\n",
    "    'Outcome': 'outcome',\n",
    "    'Projected Impact': 'projected_impact',\n",
    "}\n",
    "\n",
    "def create_latex_table(human_agreement, gpt35_quality, gpt4_quality, mixtral_quality):\n",
    "    aspects = list(ASPECT_COLUMNS)\n",
    "\n",
    "    # Start of the LaTeX table\n",
    "    latex_table = \"\"\"\n",
    "\\\\begin{tabular}{lcccccc}\n",
    "\\\\toprule\n",
    " & Aspects & CS & BS & BLEURT & BLEU & ROUGE-1 \\\\\\\\\n",
    "\\\\midrule\n",
    "\"\"\"\n",
    "\n",
    "    # Add data rows for Human Agreement\n",
    "    latex_table += \"\\\\multirow{5}{*}{\\\\shortstack[l]{Human\\\\\\\\Agreement}} \"\n",
    "    for aspect in aspects:\n",
    "        scores = human_agreement[ASPECT_COLUMNS[aspect]].values\n",
    "        latex_table += f\" & {aspect} & {' & '.join(map(str, scores))} \\\\\\\\\\n\"\n",
    "        if aspect != aspects[-1]:\n",
    "            latex_table += \" \"\n",
    "    latex_table += \"\\\\hdashline\\n\"\n",
    "\n",
    "    # Function to add model data rows\n",
    "    def add_model_data(model_name, model_data):\n",
    "        latex_table = f\"\\\\multirow{{5}}{{*}}{{\\\\shortstack[l]{{{model_name}\\\\\\\\Alignment}}}} \"\n",
    "        for aspect in aspects:\n",
    "            scores = model_data[ASPECT_COLUMNS[aspect]].values\n",
    "            latex_table += f\" & {aspect} & {' & '.join(map(str, scores))} \\\\\\\\\\n\"\n",
    "            if aspect != aspects[-1]:\n",
    "                latex_table += \" \"\n",
    "        latex_table += \"\\\\hdashline\\n\"\n",
    "        return latex_table\n",
    "\n",
    "    # Add data rows for each model\n",
    "    latex_table += add_model_data(\"GPT-3.5-Human\", gpt35_quality)\n",
    "    latex_table += add_model_data(\"GPT-4-Human\", gpt4_quality)\n",
    "    latex_table += add_model_data(\"Mixtral-8x7B-Human\", mixtral_quality)\n",
    "\n",
    "    # End of the LaTeX table\n",
    "    latex_table += \"\\\\bottomrule\\n\\\\end{tabular}\\n\"\n",
    "\n",
    "    return latex_table\n",
    "\n",
    "# Load data from CSV files (paths relative to the notebooks/ directory)\n",
    "human_agreement = load_data('evaluation_results/human_agreement.csv')\n",
    "gpt35_quality = load_data('evaluation_results/gpt35_quality.csv')\n",
    "gpt4_quality = load_data('evaluation_results/gpt4_quality.csv')\n",
    "mixtral_quality = load_data('evaluation_results/mistral_quality.csv')\n",
    "\n",
    "# Generate the updated LaTeX table\n",
    "new_latex_table = create_latex_table(human_agreement, gpt35_quality, gpt4_quality, mixtral_quality)\n",
    "print(new_latex_table)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
--------------------------------------------------------------------------------
/benchmark/aspect_prediction/eval.py:
--------------------------------------------------------------------------------
"""Evaluate experiment results based on the model generated output (file)."""
import argparse
import json
import sys

import nest_asyncio
import pandas as pd
from utils import TASK2GT, TASK_NAMES, postprocess_cot

from massw.metrics import compute_metrics, flatten_metrics
from massw.models import gpt_azure, mixtral_azure

sys.path.append("..")
nest_asyncio.apply()


def postprocess_output(model_output_dir,
                       reference_path,
                       used_cot=False,
                       model_type="gpt"):
    """
    Process model output files to match predictions with references.

    Args:
        model_output_dir (str): Directory containing the output files.
        reference_path (str): Path to the file containing reference data.
        used_cot (bool): Flag to determine if COT processing is needed.
        model_type (str): Type of model used to adjust processing logic.

    Returns:
        dict: A dictionary containing predictions and references by task.
    """
    results = {}
    with open(reference_path, "r", encoding="utf-8") as f:
        references = [json.loads(line) for line in f]
    id2ref = {r["pid"]: r for r in references}

    for task_name in TASK_NAMES:
        gt_name = TASK2GT[task_name]
        model_path = f"{model_output_dir}/{task_name}.tsv"

        if model_type == "gpt":
            id2predictions = gpt_azure.raw_output_to_dict(model_path)
        elif model_type == "mixtral":
            id2predictions = mixtral_azure.raw_output_to_dict(model_path)
        else:
            raise ValueError(f"Model type {model_type} not supported.")

        if used_cot:
            for pid in id2predictions:
                id2predictions[pid] = postprocess_cot(id2predictions[pid])

        # Iterate over the same pid list for both fields so predictions and
        # references stay aligned index-by-index.
        common_pids = [pid for pid in id2predictions if pid in id2ref]
        results[task_name] = {
            "predictions": [id2predictions[pid] for pid in common_pids],
            "references": [id2ref[pid][gt_name] for pid in common_pids],
        }

    return results


def main():
    """Run main function to process and compute evaluation metrics."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_output_dir",
        type=str,
        help="Path to the model output dir.",
        default="benchmark/aspect_prediction/outputs/gpt-35-turbo_zero-shot",
    )
    parser.add_argument(
        "--reference_path",
        type=str,
        help="Path to the references file.",
        default="data/benchmark_0531.jsonl",
    )
    parser.add_argument(
        "--used_cot",
        action="store_true",
        help="Used COT.",
    )
    args = parser.parse_args()
    model_type = "gpt" if "gpt" in args.model_output_dir else "mixtral"

    results = postprocess_output(
        args.model_output_dir,
        args.reference_path,
        args.used_cot,
        model_type=model_type,
    )
    metrics_output_path = f"{args.model_output_dir}/metrics.tsv"

    metrics = {}

    for task_name, task_results in results.items():
        print(f"Processing task: {task_name}")
        predictions = task_results["predictions"]
        references = task_results["references"]
        metrics[task_name] = flatten_metrics(
            compute_metrics(
                predictions,
                references,
                metric_names=[
                    "bleu", "rouge", "cosine", "bertscore", "bleurt"
                ],
            ))
        print(f"Processed task: {task_name}")
        print(metrics[task_name])

    df = pd.DataFrame(metrics)
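    # Columns are task names; the index holds the flattened metric names.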
    df.to_csv(metrics_output_path, index=True, sep="\t")


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# db
**.db

# API key
api.key

# Cache
**/cache/

# Raw data
raw_data/
.DS_store
outputs/
data/massw_*

# Dev notebooks
dev.ipynb

# Logging
**/oaib.txt

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
roberta_stopwords.txt

--------------------------------------------------------------------------------
/massw/collection/acl/script.py:
--------------------------------------------------------------------------------
import os
import time
import re
import requests
import pandas as pd
from bs4 import BeautifulSoup
from tqdm import tqdm
from urllib.parse import urljoin

from massw.collection import BaseCollection


class ACLCollection(BaseCollection):
    def __init__(self, year: int, base_url: str = "https://aclanthology.org"):
        super().__init__(year, "acl")
        self.data_dir = os.path.join(os.path.dirname(__file__), f"data_{year}")
        self.base_url = base_url

    def collect(self):
        url = f"https://aclanthology.org/events/acl-{self.year}/"
        response = requests.get(url, timeout=30)
        if response.status_code != 200:
            print(f"Failed to retrieve the page: {response.status_code}")
            return None
        soup = BeautifulSoup(response.content, "html.parser")
        paper_links = soup.select("span.d-block > strong > a.align-middle")
        for link in tqdm(paper_links, desc="Processing Papers"):
            time.sleep(1)
            rel_url = link.get("href")
            paper_url = urljoin(self.base_url, rel_url)
            pdf_url = f"{paper_url[:-1]}.pdf"
            title = link.get_text().strip()
            if (
                ("long" in rel_url or "short" in rel_url)
                and ("acl-long.0" not in rel_url)
                and ("acl-short.0" not in rel_url)
            ):
                paper_response = requests.get(paper_url, timeout=30)
                if paper_response.status_code == 200:
                    paper_soup = BeautifulSoup(paper_response.content, "html.parser")
                    bibtex_element = paper_soup.select_one("pre#citeBibtexContent")
                    if bibtex_element is None:
                        # Skip entries without an embedded BibTeX block.
                        continue
                    bibtex_text = bibtex_element.get_text().strip()
                    author_match = re.search(
                        r'author\s*=\s*"\s*(.*?)\s*"', bibtex_text, re.DOTALL
                    )
                    if author_match:
                        raw_authors = author_match.group(1)
                        author_list = [
                            author.replace(",", "").strip()
                            for author in raw_authors.replace("\n", " ").split(" and ")
                        ]
                        authors = ", ".join(author_list)
                    else:
                        authors = None
                    abstract_match = re.search(
                        r'abstract\s*=\s*"\s*(.*?)\s*"', bibtex_text, re.DOTALL
                    )
                    abstract = (
                        abstract_match.group(1).strip() if abstract_match else None
                    )
                    self.add_paper(
                        title=title,
                        authors=authors,
                        abstract=abstract,
                        pdf_url=pdf_url,
                        url=paper_url,
                    )
        self.save_metadata()

    def download_pdfs(self):
        """
        Download PDF files for all papers in the collection.

        The PDFs are saved to {data_dir}/pdf/{pid}.pdf.
        """
        pdf_dir = os.path.join(self.data_dir, "pdf")
        if not os.path.exists(pdf_dir):
            os.makedirs(pdf_dir)
        if os.path.exists(os.path.join(self.data_dir, "metadata.tsv")):
            path = os.path.join(self.data_dir, "metadata.tsv")
            df = pd.read_csv(path, sep="\t")
            for index, row in tqdm(
                df.iterrows(), total=len(df), desc="Downloading PDFs"
            ):
                pdf_url = row["pdf_url"]
                pid = row["pid"]
                pdf_path = os.path.join(pdf_dir, f"{pid}.pdf")
                response = requests.get(pdf_url, timeout=30)
                with open(pdf_path, "wb") as f:
                    f.write(response.content)


if __name__ == "__main__":
    collection = ACLCollection(year=2024)
    if not os.path.exists(os.path.join(collection.data_dir, "metadata.tsv")):
        collection.collect()
    collection.download_pdfs()

--------------------------------------------------------------------------------
/factscore/openai_lm.py:
--------------------------------------------------------------------------------
from factscore.lm import LM
import openai
import sys
import time
import os
import numpy as np
import logging


class OpenAIModel(LM):

    def __init__(self, model_name, cache_file=None, key_path="api.key"):
        self.model_name = model_name
        self.key_path = key_path
        self.temp = 0.7
        self.save_interval = 100
        super().__init__(cache_file)

    def load_model(self):
        # Try to load the API key from the file
        if os.path.exists(self.key_path):
            with open(self.key_path, 'r') as f:
                api_key = f.readline().strip()
        else:
            # Fall back to using the environment variable
            api_key = os.getenv('OPENAI_API_KEY')
        assert api_key is not None, "API key not found. Please set it in the environment variable 'OPENAI_API_KEY'."
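        # strip() guards against trailing whitespace or newlines regardless of
        # whether the key came from api.key or the environment variable.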
        openai.api_key = api_key.strip()
        self.model = self.model_name

    def _generate(self, prompt, max_sequence_length=2048, max_output_length=128):
        if self.add_n % self.save_interval == 0:
            self.save_cache()
        # Return a tuple of string (generated text) and metadata (any format).
        # This should be about generating a response from the prompt, no matter what the application is.
        if self.model_name == "ChatGPT":
            # Construct the message sent to ChatGPT
            message = [{"role": "user", "content": prompt}]
            # Call API
            response = call_ChatGPT(message, temp=self.temp, max_len=max_sequence_length)
            # Get the output from the response
            output = response["choices"][0]["message"]["content"]
            return output, response
        elif self.model_name == "InstructGPT":
            # Call API
            response = call_GPT3(prompt, temp=self.temp)
            # Get the output from the response
            output = response["choices"][0]["text"]
            return output, response
        else:
            raise NotImplementedError()


def call_ChatGPT(message, model_name="gpt-4o-mini", max_len=1024, temp=0.7, verbose=False):
    # Call the chat API until a result is provided, then return it.
    response = None
    received = False
    num_rate_errors = 0
    while not received:
        try:
            response = openai.ChatCompletion.create(model=model_name,
                                                    messages=message,
                                                    max_tokens=max_len,
                                                    temperature=temp)
            received = True
        except Exception:
            num_rate_errors += 1
            error = sys.exc_info()[0]
            if error == openai.error.InvalidRequestError:
                # something is wrong: e.g. prompt too long
                logging.critical(f"InvalidRequestError\nPrompt passed in:\n\n{message}\n\n")
                assert False

            print(f"API error: {error} ({num_rate_errors}). Waiting 1sec")
            time.sleep(1)
    return response


def call_GPT3(prompt, model_name="text-davinci-003", max_len=512, temp=0.7, num_log_probs=0, echo=False, verbose=False):
    # Call the GPT-3 completion API until a result is provided, then return it.
    response = None
    received = False
    num_rate_errors = 0
    while not received:
        try:
            response = openai.Completion.create(model=model_name,
                                                prompt=prompt,
                                                max_tokens=max_len,
                                                temperature=temp,
                                                logprobs=num_log_probs,
                                                echo=echo)
            received = True
        except Exception:
            error = sys.exc_info()[0]
            num_rate_errors += 1
            if error == openai.error.InvalidRequestError:
                # something is wrong: e.g. prompt too long
                logging.critical(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n")
                assert False
            logging.error("API error: %s (%d)" % (error, num_rate_errors))
            # Exponential backoff between retries.
            time.sleep(np.power(2, num_rate_errors))
    return response

--------------------------------------------------------------------------------
/prompts/method.json:
--------------------------------------------------------------------------------
[
  {
    "role": "system",
5 | }, 6 | { 7 | "role": "user", 8 | "content": "Input 1:The authors use a case study with one psychologist and three patients, across a period of two weeks to show the effectiveness of new web platform.\nInput 2:The case study indicated that the platform successfully enhanced the therapy process by improving data collection and maintaining therapeutic awareness." 9 | }, 10 | { 11 | "role": "assistant", 12 | "content": "{\n \"justification\": \"While both inputs refer to a case study involving a web platform in a therapeutic context, Input 1 provides details about the participants and duration of the case study, whereas Input 2 focuses on the outcomes of the case study. Therefore, the specific details covered in the two inputs are different.\", \n \"score\": 1\n}" 13 | }, 14 | { 15 | "role": "user", 16 | "content": "Input 1:The authors use a case study with one psychologist and three patients, across a period of two weeks to show the effectiveness of new web platform.\nInput 2:The authors use the platform for three patients for two weeks and improve the performance of the therapy." 17 | }, 18 | { 19 | "role": "assistant", 20 | "content": "{\n \"justification\": \"Both inputs involve a web platform being used for a case study with three patients over a period of two weeks. However, Input 1 details the inclusion of a psychologist and explicitly mentions showing the effectiveness of the platform, whereas Input 2 focuses on improving therapy performance and follows a different logic, implying that it is the authors who improve the patients' therapy. The many differences in detail and focus lead to a score of 3.\", \n \"score\": 3\n}" 21 | }, 22 | { 23 | "role": "user", 24 | "content": "Input 1:The authors use a case study with one psychologist and three patients, across a period of two weeks to show the effectiveness of new web platform.\nInput 2:The authors collected three patients' data using the web platform, and the impact of the therapy was measured." 25 | }, 26 | { 27 | "role": "assistant", 28 | "content": "{\n \"justification\": \"Both inputs describe a methodology involving data collection from three patients using a web platform and evaluating its impact on therapy. However, Input 1 provides two additional details: the duration of the case study (two weeks) and the effectiveness of the new web platform, which are missing in Input 2. Therefore, the descriptions are similar but not identical.\", \n \"score\": 5\n}" 29 | }, 30 | { 31 | "role": "user", 32 | "content": "Input 1:The authors use a case study with one psychologist and three patients, across a period of two weeks to show the effectiveness of new web platform.\nInput 2:The authors conduct a case study with one psychologist and three patients, across a period of two weeks." 33 | }, 34 | { 35 | "role": "assistant", 36 | "content": "{\n \"justification\": \"Both inputs describe a similar methodology involving a case study with one psychologist and three patients over two weeks. The only difference is that Input 2 misses one key point: the effectiveness of the new web platform. Therefore, they are nearly identical in terms of methodology, but Input 2 lacks one important detail.\", \n \"score\": 7\n}"
37 | }, 38 | { 39 | "role": "user", 40 | "content": "Input 1:The authors use a case study with one psychologist and three patients, across a period of two weeks to show the effectiveness of new web platform.\nInput 2:The design of the web platform is introduced, and its effectiveness is evaluated through a case study involving one psychologist and three patients over a period of two weeks." 41 | }, 42 | { 43 | "role": "assistant", 44 | "content": "{\n \"justification\": \"Both inputs discuss a method involving a case study with one psychologist and three patients over two weeks to evaluate the effectiveness of a web platform. The core methodology is nearly identical.\", \n \"score\": 10\n}" 45 | } 46 | ] -------------------------------------------------------------------------------- /massw/collection/base.py: -------------------------------------------------------------------------------- 1 | """Base classes for paper collection from academic conferences.""" 2 | 3 | import os 4 | from abc import ABC, abstractmethod 5 | from dataclasses import dataclass, field 6 | from typing import List, Optional 7 | 8 | import pandas as pd 9 | import requests 10 | from tqdm import tqdm 11 | 12 | 13 | @dataclass 14 | class Paper: 15 | """ 16 | Represents a research paper with its metadata. 17 | 18 | Attributes: 19 | pid: Unique paper ID, formatted as `{venue}_{year}_{pid}` 20 | year: Year of publication 21 | venue: Publication venue 22 | title: Title of the paper 23 | authors: Authors of the paper 24 | abstract: Abstract of the paper 25 | pdf_url: URL to the PDF of the paper that can be directly downloaded 26 | url: URL to the paper page on the venue website 27 | """ 28 | 29 | # Unique paper ID, formatted as `{venue}_{year}_{pid}` 30 | pid: str 31 | # Year of publication 32 | year: int 33 | # Publication venue 34 | venue: str 35 | # Title of the paper 36 | title: Optional[str] = None 37 | # Authors of the paper 38 | authors: Optional[List[str]] = field(default_factory=list) 39 | # Abstract of the paper 40 | abstract: Optional[str] = None 41 | # URL to the PDF of the paper that can be directly downloaded 42 | pdf_url: Optional[str] = None 43 | # URL to the paper page on the venue website 44 | url: Optional[str] = None 45 | 46 | 47 | class BaseCollection(ABC): 48 | """ 49 | Abstract base class for collecting papers from academic conferences. 50 | 51 | This class provides common functionality for collecting, storing, and 52 | downloading papers from various academic venues. 53 | 54 | Attributes: 55 | year: The year of the conference 56 | venue: The venue code (e.g., 'neurips', 'acl') 57 | papers: List of collected papers 58 | data_dir: Directory to store the collected data 59 | """ 60 | 61 | def __init__(self, year: int, venue: str): 62 | self.year: int = year 63 | self.venue: str = venue 64 | self.papers: List[Paper] = [] 65 | self.data_dir: Optional[str] = None 66 | 67 | @abstractmethod 68 | def collect(self): 69 | """ 70 | Abstract method that must be implemented by subclasses. 71 | This method should collect papers from the specific venue. 72 | """ 73 | raise NotImplementedError("Subclasses must implement this method.") 74 | 75 | def add_paper( 76 | self, 77 | title: Optional[str] = None, 78 | authors: Optional[List[str]] = None, 79 | abstract: Optional[str] = None, 80 | pdf_url: Optional[str] = None, 81 | url: Optional[str] = None, 82 | ): 83 | """ 84 | Add a paper to the collection.
85 | 86 | Args: 87 | title: Title of the paper 88 | authors: List of authors 89 | abstract: Abstract of the paper 90 | pdf_url: URL to the PDF file 91 | url: URL to the paper page on the venue website 92 | """ 93 | self.papers.append( 94 | Paper( 95 | pid=f"{self.venue}_{self.year}_{len(self.papers)}", 96 | year=self.year, 97 | venue=self.venue, 98 | title=title, 99 | authors=authors, 100 | abstract=abstract, 101 | pdf_url=pdf_url, 102 | url=url, 103 | ) 104 | ) 105 | 106 | def save_metadata(self): 107 | """ 108 | Save the metadata of collected papers to a TSV file. 109 | 110 | The metadata is saved to {data_dir}/metadata.tsv. 111 | """ 112 | if not os.path.exists(self.data_dir): 113 | os.makedirs(self.data_dir) 114 | metadata_path = os.path.join(self.data_dir, "metadata.tsv") 115 | metadata_df = pd.DataFrame(self.papers) 116 | metadata_df.to_csv(metadata_path, sep="\t", index=False) 117 | 118 | def download_pdfs(self): 119 | """ 120 | Download PDF files for all papers in the collection. 121 | 122 | The PDFs are saved to {data_dir}/pdf/{pid}.pdf. 123 | """ 124 | pdf_dir = os.path.join(self.data_dir, "pdf") 125 | if not os.path.exists(pdf_dir): 126 | os.makedirs(pdf_dir) 127 | for paper in tqdm(self.papers, desc="Downloading PDFs"): 128 | if paper.pdf_url: 129 | pdf_path = os.path.join(pdf_dir, f"{paper.pid}.pdf") 130 | if not os.path.exists(pdf_path): 131 | response = requests.get(paper.pdf_url, timeout=30) 132 | with open(pdf_path, "wb") as f: 133 | f.write(response.content) 134 | -------------------------------------------------------------------------------- /prompts/future.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "role": "system", 4 | "content": "You are an expert in scientific research and can evaluate the semantic similarity between pairs of future impact aspect summarizations from scientific papers. Future impact refers to the author-anticipated impact of the work on the field, and potential further research identified by the author that may improve or extend this study. You will be provided with two inputs, and your task is to assign a similarity score on a scale from 1 to 10, where 1 indicates completely dissimilar content and 10 indicates identical content. Before scoring, provide a brief justification for your score. Structure your response in the following JSON format: \n{\n \"score\": score,\n \"justification\": \"Brief explanation for the score\"\n}\n" 5 | }, 6 | { 7 | "role": "user", 8 | "content": "Input 1:The proposed adversarial detection algorithm provides an opportunity to build an interpretable and robust model.\nInput 2:The paper proposes advancements in dynamic network pruning through interpretable layerwise channel selection, aiming to enhance runtime efficiency and model interpretability." 9 | }, 10 | { 11 | "role": "assistant", 12 | "content": "{\n \"justification\": \"While both inputs focus on the interpretability and robustness of models, their context differs. Input 1 discusses an adversarial detection algorithm and its potential impact, whereas Input 2 talks about dynamic network pruning and its goal to enhance runtime efficiency and interpretability. The overlap in 'interpretability' is the primary commonality, but the technological focus and specific impacts are different.\", \n \"score\": 1\n}"
\", \n \"score\": 1\n}" 13 | }, 14 | { 15 | "role": "user", 16 | "content": "Input 1:The proposed adversarial detection algorithm provides an opportunity to build an interpretable and robust model.\nInput 2:The proposed adversarial detection algorithm offers enhanced security by effectively identifying malicious inputs, potentially raising the bar for adversarial defense in AI systems." 17 | }, 18 | { 19 | "role": "assistant", 20 | "content": "{\n \"justification\":Both inputs discuss the proposed adversarial detection algorithm and the benefits it brings. Input 1 focuses on building an interpretable and robust model, whereas Input 2 emphasizes enhanced security and the potential for improved adversarial defense. Although there is some overlap in the general context of adversarial detection, the specific impacts and focus areas differ. \", \n \"score\": 3\n}" 21 | }, 22 | { 23 | "role": "user", 24 | "content": "Input 1:The proposed adversarial detection algorithm provides an opportunity to build an interpretable and robust model.\nInput 2: This proposed adversarial detection algorithm promises robustness against adversarial attacks, significantly improving detection rates compared to existing methods." 25 | }, 26 | { 27 | "role": "assistant", 28 | "content": "{\n \"justification\": Both inputs discuss the proposed adversarial detection algorithm with a focus on its robustness. Input 1 mentions building an interpretable and robust model, while Input 2 highlights improved robustness against adversarial attacks and better detection rates compared to existing methods. The core theme of robustness is shared, but Input 2 includes additional specifics about performance improvements.\", \n \"score\": 5\n}" 29 | }, 30 | { 31 | "role": "user", 32 | "content": "Input 1:The proposed adversarial detection algorithm provides an opportunity to build an interpretable and robust model.\nInput 2:The proposed method provides an opportunity to build an interpretable and robust model with higher accuracy and reduced computing costs, which can be applied in various fields involving large-scale neural networks and adversarial attacks." 33 | }, 34 | { 35 | "role": "assistant", 36 | "content": "{\n \"justification\":Both inputs highlight the creation of an interpretable and robust model as a future impact of the proposed method. However, Input 2 includes additional details about higher accuracy, reduced computing costs, and applicability to large-scale neural networks and adversarial attacks, adding more context and specific benefits that are not covered in Input 1. \", \n \"score\": 8\n}" 37 | }, 38 | { 39 | "role": "user", 40 | "content": "Input 1:The proposed adversarial detection algorithm provides an opportunity to build an interpretable and robust model.\nInput 2:The proposed method can help to build an interpretable and robust model for adversarial example detection." 41 | }, 42 | { 43 | "role": "assistant", 44 | "content": "{\n \"justification\":Both inputs describe the potential future impact of the proposed method in terms of building an interpretable and robust model for adversarial detection. The slight difference in phrasing does not significantly alter the semantic similarity between them. 
\", \n \"score\": 10\n}" 45 | } 46 | ] -------------------------------------------------------------------------------- /prompts/outcome.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "role": "system", 4 | "content": "You are an expert in scientific research and can evaluate the semantic similarity between pairs of outcome aspect summarizations from scientific papers. Outcome refers to the factual statement about the study output. It could be the experiment results and any other measurable outcome that has occurred. It marks whether the key hypothesis is testified or not . You will be provided with two inputs, and your task is to assign a similarity score on a scale from 1 to 10, where 1 indicates completely dissimilar content and 10 indicates identical content. Before scoring, provide a brief justification for your score. Structure your response in the following JSON format: \n{\n \"score\": score,\n \"justification\": \"Brief explanation for the score\"\n}\n" 5 | }, 6 | { 7 | "role": "user", 8 | "content": "Input 1:The case study show that web platform make improvements in the collection of meaningful data about a person, and on maintaining awareness of the therapy as a whole.\nInput 2:The authors propose a web platform to collect and manage biographical materials in a streamlined manner, enabling reminiscence and other biographical cognitive stimulation practices for people with dementia." 9 | }, 10 | { 11 | "role": "assistant", 12 | "content": "{\n \"justification\":Two inputs only have one key-point in common: web platform. Input 1 emphasizes improvements in data collection and maintaining therapy awareness accomplished by the web platform, whereas Input 2 highlights motivation and advantages of the web platform. The overlap is in the concept of data collection for therapeutic purposes, but the specifics and context differ significantly. \", \n \"score\": 1\n}" 13 | }, 14 | { 15 | "role": "user", 16 | "content": "Input 1:The case study show that web platform make improvements in the collection of meaningful data about a person, and on maintaining awareness of the therapy as a whole.\nInput 2:Experienments by conducted by one psychologist and three patients, across a period of two weeks show change of collection of data about a person made by the web platform." 17 | }, 18 | { 19 | "role": "assistant", 20 | "content": "{\n \"justification\": Both inputs refer to a web platform and its role in the collection of data about a person. However, Input 1 specifically mentions ‘case study’, 'improvements' and 'maintaining awareness of the therapy,' while Input 2 only indicates a 'change' in data collection instead of improvement and does not address therapy awareness. And it has more details about experienments.\", \n \"score\": 3\n}" 21 | }, 22 | { 23 | "role": "user", 24 | "content": "Input 1:The case study show that web platform make improvements in the collection of meaningful data about a person, and on maintaining awareness of the therapy as a whole.\nInput 2:Results from a case study with one psychologist and three patients, across a period of two weeks show improvements in the collection of meaningful data about a person." 25 | }, 26 | { 27 | "role": "assistant", 28 | "content": "{\n \"justification\":Both inputs discuss the case stady shows improvements in the collection of meaningful data about a person. Input 1 also mentions two other main points: maintaining awareness of therapy and the platform. 
29 | }, 30 | { 31 | "role": "user", 32 | "content": "Input 1:The case study show that web platform make improvements in the collection of meaningful data about a person, and on maintaining awareness of the therapy as a whole.\nInput 2:The study showed improvements in the collection and analyzing of patient data and maintenance of the therapy's overall awareness." 33 | }, 34 | { 35 | "role": "assistant", 36 | "content": "{\n \"justification\": \"Both inputs emphasize three key points: the case study, improvements in data collection, and maintaining awareness of therapy. The primary differences are the addition of 'analyzing' in Input 2 and the fact that Input 2 misses one key point (the platform), leading to a score of 7.\", \n \"score\": 7\n}" 37 | }, 38 | { 39 | "role": "user", 40 | "content": "Input 1:The case study show that web platform make improvements in the collection of meaningful data about a person, and on maintaining awareness of the therapy as a whole.\nInput 2:The platform showed improvements in the collection of meaningful data about people with dementia and in maintaining awareness of the therapy as a whole." 41 | }, 42 | { 43 | "role": "assistant", 44 | "content": "{\n \"justification\": \"Both inputs describe improvements made by a web platform in the collection of meaningful data about persons and in maintaining awareness of therapy. The details are nearly identical, with Input 2 missing one point (the case study), leading to a score of 9.\", \n \"score\": 9\n}" 45 | } 46 | ] -------------------------------------------------------------------------------- /factscore/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | import torch 7 | 8 | def assert_all_approx_close(a, b, rtol, atol, count): 9 | 10 | idx = torch.isclose(a.float(), b.float(), rtol, atol) 11 | sumval = (idx==0).sum().item() 12 | if sumval > count: 13 | print(f'Too many values not close: assert {sumval} < {count}') 14 | try: 15 | torch.testing.assert_allclose(a, b, rtol, atol) 16 | except Exception as e: 17 | print(e) 18 | 19 | 20 | def get_memory_footprint(model, return_buffers=True): 21 | """ 22 | Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. 23 | Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the 24 | PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 25 | Arguments: 26 | return_buffers (`bool`, *optional*, defaults to `True`): 27 | Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers 28 | are tensors that do not require gradients and are not registered as parameters. E.g. mean and std in batch 29 | norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
30 | """ 31 | mem = sum([param.nelement() * param.element_size() for param in model.parameters()]) 32 | if return_buffers: 33 | mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()]) 34 | mem = mem + mem_bufs 35 | return mem 36 | 37 | 38 | def _replace_linear_with_int8linear(model, modules_to_not_convert="lm_head"): 39 | for name, module in model.named_children(): 40 | _replace_linear_with_int8linear(module, modules_to_not_convert) 41 | 42 | if isinstance(module, torch.nn.Linear) and name != modules_to_not_convert: 43 | model._modules[name] = QuantizedLinearInt8(linear_layer=module) 44 | return 45 | 46 | 47 | class QuantizedLinearInt8(torch.nn.Module): 48 | ''' 49 | A simple but effective implementation of Int8 quantization for linear layers. 50 | The weights are quantized and stored as Int8, which saves ~50% of the gpu memory. 51 | During the forward pass, the weights are de-quantized back to fp16 to do the multiplication. 52 | Pros: 53 | - saves ~50% of the gpu memory 54 | - accurate quantization because only the weights are quantized, and the weights don't suffer 55 | from the "outliers" issue mentioned in the LLM.int8 paper; only the activations do. 56 | - high precision results because the multiplication is done in fp16 57 | - much faster than LLM.int8 58 | Cons: 59 | - a bit slower because of the added computation of dequantization in each forward pass. In practice, the slowdown 60 | is not large because in the generation application, gpu utilization is not very high. 61 | ''' 62 | def __init__(self, linear_layer): 63 | super().__init__() 64 | self.bias = linear_layer.bias 65 | 66 | weight_bit_width = 8 67 | weight = linear_layer.weight 68 | 69 | self.weight_scale = torch.nn.Parameter( 70 | (weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1)).half(), 71 | ) 72 | # print(self.weight_scale.max().item(), self.weight_scale.min().item(), self.weight_scale.mean().item()) 73 | # if self.weight_scale.max().item() > 0.002: 74 | # print(self.weight_scale.max().item()) 75 | self.weight = torch.nn.Parameter( 76 | torch.round(weight.float() / self.weight_scale[:, None]).char(), 77 | requires_grad=False 78 | ) 79 | 80 | def forward(self, x): 81 | weight = self.weight.half() * self.weight_scale[:, None] 82 | return torch.nn.functional.linear(x, weight, self.bias) 83 | 84 | 85 | def convert_model_to_int8_on_gpu(model, device): 86 | """ 87 | Quantize a model to int8 and move it to GPU using a simple method. 88 | """ 89 | if 'cuda' not in device: 90 | raise ValueError(f"Target device should be a gpu. Device {device} is not supported")
Device {device} is not supported") 91 | 92 | model.half() 93 | 94 | memory_before_quantization = get_memory_footprint(model) # without lm_head 95 | 96 | ـreplace_linear_with_int8linear(model) # replace `Linear` with `QuantizedLinearInt8` 97 | 98 | model.to(device=device) 99 | memory_after_quantization = get_memory_footprint(model) # without lm_head 100 | 101 | saving = round(100 * memory_after_quantization/memory_before_quantization) 102 | memory_before_quantization = round(memory_before_quantization / 2**30, 2) # rounding for printing 103 | memory_after_quantization = round(memory_after_quantization / 2**30, 2) # rounding for printing 104 | 105 | print(f'Quantization memory - before: {memory_before_quantization} GB, after: {memory_after_quantization} GB ({saving}% of the size before)') 106 | return model 107 | -------------------------------------------------------------------------------- /massw/collection/cvpr/script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | import requests 5 | import pandas as pd 6 | from bs4 import BeautifulSoup 7 | from tqdm import tqdm 8 | from massw.collection import BaseCollection 9 | from urllib.parse import urljoin 10 | 11 | 12 | class CVPRCollection(BaseCollection): 13 | def __init__(self, year: int, base_url: str = "https://openaccess.thecvf.com/"): 14 | super().__init__(year, "cvpr") 15 | self.data_dir = os.path.join(os.path.dirname(__file__), f"data_{year}") 16 | self.base_url = base_url 17 | self.headers = { 18 | "User-Agent": ( 19 | "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " 20 | "AppleWebKit/537.36 (KHTML, like Gecko) " 21 | "Chrome/117.0.0.0 Safari/537.36" 22 | ) 23 | } 24 | 25 | def collect(self): 26 | url = f"{self.base_url}CVPR{self.year}?day=all" 27 | response = requests.get(url, headers=self.headers, timeout=30) 28 | if response.status_code != 200: 29 | print(f"Failed to fetch data: Status code {response.status_code}") 30 | 31 | soup = BeautifulSoup(response.text, "html.parser") 32 | paper_titles = soup.find_all("dt", class_="ptitle") 33 | for i, title_element in enumerate(tqdm(paper_titles, desc="Processing Papers")): 34 | time.sleep(1) 35 | title_link = title_element.find("a") 36 | # get title 37 | title = title_link.text.strip() 38 | # get paper_url 39 | paper_url = urljoin(self.base_url, title_link["href"]) 40 | info_element = title_element.find_next_sibling("dd") 41 | info_element = info_element.find_next_sibling("dd") 42 | pdf_link = info_element.find("a", text="pdf") 43 | # get pdf_url 44 | pdf_url = urljoin(self.base_url, pdf_link["href"]) if pdf_link else None 45 | # get authors 46 | bibtex_div = info_element.find("div", class_="bibref") 47 | authors = None 48 | if bibtex_div: 49 | bibtex_text = bibtex_div.text.strip() 50 | author_match = re.search(r"author\s*=\s*{([^}]+)}", bibtex_text) 51 | if author_match: 52 | authors_original = author_match.group(1).strip() 53 | author_parts = authors_original.split(" and ") 54 | transformed_authors = [] 55 | for author in author_parts: 56 | parts = [part.strip() for part in author.split(",")] 57 | if len(parts) >= 2: 58 | transformed_authors.append(f"{parts[0]} {parts[1]}") 59 | else: 60 | transformed_authors.append(author) 61 | 62 | authors = ", ".join(transformed_authors) 63 | # get abstract 64 | try: 65 | response = requests.get( 66 | paper_url, headers=self.headers, timeout=30, verify=False 67 | ) 68 | response.raise_for_status() 69 | soup = BeautifulSoup(response.text, "html.parser") 70 | abstract_div = 
soup.find("div", id="abstract") 71 | if abstract_div: 72 | abstract_text = abstract_div.text.strip() 73 | abstract_text = re.sub(r"\s+", " ", abstract_text) 74 | except: 75 | print(f"Failed to fetch {paper_url}") 76 | continue 77 | 78 | self.add_paper( 79 | title=title, 80 | authors=authors, 81 | abstract=abstract_text, 82 | pdf_url=pdf_url, 83 | url=paper_url, 84 | ) 85 | self.save_metadata() 86 | 87 | def download_pdfs(self): 88 | """ 89 | Download PDF files for all papers in the collection. 90 | 91 | The PDFs are saved to {data_dir}/pdf/{pid}.pdf. 92 | """ 93 | pdf_dir = os.path.join(self.data_dir, "pdf") 94 | if not os.path.exists(pdf_dir): 95 | os.makedirs(pdf_dir) 96 | if os.path.exists(os.path.join(self.data_dir, "metadata.tsv")): 97 | path = os.path.join(self.data_dir, "metadata.tsv") 98 | df = pd.read_csv(path, sep="\t") 99 | for index, row in tqdm( 100 | df.iterrows(), total=len(df), desc="Downloading PDFs" 101 | ): 102 | pdf_url = row["pdf_url"] 103 | pid = row["pid"] 104 | pdf_path = os.path.join(pdf_dir, f"{pid}.pdf") 105 | try: 106 | response = requests.get(pdf_url, verify=False, timeout=30) 107 | with open(pdf_path, "wb") as f: 108 | f.write(response.content) 109 | except: 110 | print(f"Failed to download {pdf_url}") 111 | continue 112 | 113 | 114 | if __name__ == "__main__": 115 | collection = CVPRCollection(year=2024) 116 | if not os.path.exists(os.path.join(collection.data_dir, "metadata.tsv")): 117 | collection.collect() 118 | collection.download_pdfs() 119 | -------------------------------------------------------------------------------- /prompts/key_idea.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "role": "system", 4 | "content": "You are an expert in scientific research and can evaluate the semantic similarity between pairs of key-idea aspect summarizations from scientific papers. Key idea refers to the main intellectual merit of this paper, often in comparison to the context. This could normally be a novel idea or solution proposed in this paper that distincts it from what’s already done in literature. You will be provided with two inputs, and your task is to assign a similarity score on a scale from 1 to 10, where 1 indicates completely dissimilar content and 10 indicates identical content. Before scoring, provide a brief justification for your score. Structure your response in the following JSON format: \n{\n \"score\": score,\n \"justification\": \"Brief explanation for the score\"\n}\n" 5 | }, 6 | { 7 | "role": "user", 8 | "content": "Input 1:The authors create a web platform that supports the work of psychologists, streamlining the collection of relevant information about people with dementia.\nInput 2:Non-pharmacological interventions are the most common and the most effective for people with dementia. Using biographical or personalized materials has been proven to benefit these interventions." 9 | }, 10 | { 11 | "role": "assistant", 12 | "content": "{\n \"justification\": Both inputs are related to people with dementia. However, the first input talks about a web platform to help psychologists collect information, while the second input focuses on the effectiveness of non-pharmacological interventions in dementia care, particularly using biographical or personalized materials. 
13 | }, 14 | { 15 | "role": "user", 16 | "content": "Input 1:The authors create a web platform that supports the work of psychologists, streamlining the collection of relevant information about people with dementia.\nInput 2:The paper focuses on non-pharmacological interventions for people with dementia, emphasizing the effectiveness of biographical or personalized materials." 17 | }, 18 | { 19 | "role": "assistant", 20 | "content": "{\n \"justification\": \"Both inputs share only one key point: people with dementia. Input 1 discusses a web platform to assist psychologists in collecting information about people with dementia. Input 2 focuses on non-pharmacological interventions for dementia, particularly biographical or personalized materials. While both relate to dementia, they address different aspects—one is about a technological tool for data collection, and the other about therapeutic methods. Their key ideas share some thematic overlap but are fundamentally different in focus and application.\", \n \"score\": 2\n}" 21 | }, 22 | { 23 | "role": "user", 24 | "content": "Input 1:The authors create a web platform that supports the work of psychologists, streamlining the collection of relevant information about people with dementia.\nInput 2:The web platform make improvements in the collection of meaningful data about a person, and on maintaining awareness of the therapy as a whole." 25 | }, 26 | { 27 | "role": "assistant", 28 | "content": "{\n \"justification\": \"Both inputs describe two main points: creating a web platform and improving the collection of data for specific professional use. Input 1 specifies two more key points: that the platform supports psychologists and that it targets people with dementia, while Input 2 emphasizes improvements in data collection and therapy awareness. The core idea of a web platform facilitating data collection is similar, but the specific focus areas differ slightly.\", \n \"score\": 5\n}" 29 | }, 30 | { 31 | "role": "user", 32 | "content": "Input 1:The authors create a web platform that supports the work of psychologists, streamlining the collection of relevant information about people with dementia.\nInput 2:The authors propose a web platform to collect and manage biographical materials in a streamlined manner, enabling reminiscence and other biographical cognitive stimulation practices for people with dementia." 33 | }, 34 | { 35 | "role": "assistant", 36 | "content": "{\n \"justification\": \"Both inputs describe a web platform aimed at supporting the care of people with dementia by streamlining the collection of relevant information. However, Input 2 misses one key point: that the platform supports psychologists generally, leading to a score of 8.\", \n \"score\": 8\n}" 37 | }, 38 | { 39 | "role": "user", 40 | "content": "Input 1:The authors create a web platform that supports the work of psychologists, streamlining the collection of relevant information about people with dementia.\nInput 2:The authors created a web platform to support psychologists by streamlining the collection of relevant biographical information about people with dementia, facilitating reminiscence and other biographical cognitive stimulation practices." 41 | }, 42 | { 43 | "role": "assistant", 44 | "content": "{\n \"justification\": \"Both inputs describe a web platform designed to assist psychologists by streamlining the collection of information about people with dementia. Input 2 provides a little additional detail by mentioning biographical information, but the core idea remains nearly identical.\", \n \"score\": 9\n}"
45 | } 46 | ] -------------------------------------------------------------------------------- /benchmark/aspect_prediction/prompts.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines a series of functions that automate prompt generation. 3 | 4 | Each function corresponds to a specific aspect of research paper generation, \ 5 | such as idea generation, method \ 6 | recommendation, outcome prediction, future work recommendation, \ 7 | and title prediction. 8 | """ 9 | 10 | SYSTEM_PROMPT = """ 11 | You are an expert in research tasked with generating detailed prompts for \ 12 | various aspects of academic research papers. 13 | Each task involves creating a specific type of prompt \ 14 | based on the provided information. Here are the definitions of \ 15 | each part you will work with: 16 | 17 | - Concept 18 | - Definition 19 | - Relative Time 20 | 21 | - Context: The status quo of related literature or reality \ 22 | which motivated this study. 23 | This could normally be a problem, a research question, or a research gap \ 24 | that has not been successfully addressed by previous work. \ 25 | This is anything that happened before this study. 26 | 27 | - Key Idea: The main intellectual merit of this paper, \ 28 | often in comparison to the context. 29 | This could normally be a novel idea or solution proposed in this paper \ 30 | that distinguishes it from what's already done in literature. 31 | This is proposed in this study. 32 | 33 | - Method: The specific research method that investigates \ 34 | and validates the key idea. 35 | This could be an experimental setup, a theoretical framework, or \ 36 | other necessary methodology to implement and/or evaluate the key idea. 37 | This is performed in this study. 38 | 39 | - Outcome: The factual statement about the study output. 40 | This could be the experiment results and any other measurable \ 41 | outcome that has occurred. 42 | It marks whether the key hypothesis is verified or not. \ 43 | This is anything being produced in this study. 44 | 45 | - Projected Impact: The author-anticipated impact of the work on the field, \ 46 | and potential further research identified by the author \ 47 | that may improve or extend this study. 48 | This is anything being anticipated but has not happened yet. 49 | """ 50 | 51 | 52 | def idea_generation(data): 53 | """ 54 | Generate a prompt for idea generation based on the provided context. 55 | 56 | Args: 57 | data (dict): Contains 'context' and 'key_idea' from the research data. 58 | 59 | Returns: 60 | tuple: A tuple containing the prompt and the ground truth for 61 | idea generation. 62 | """ 63 | context = data['context'] 64 | prompt = f"Given the context: '{context}', generate key ideas \ 65 | that could advance this area of study. " 66 | ground_truth = data['key_idea'] 67 | return prompt, ground_truth 68 | 69 | 70 | def method_recommendation(data): 71 | """ 72 | Recommend a method to validate a key idea. 73 | 74 | Args: 75 | data (dict): Contains 'context', 'key_idea', and 'method' 76 | from the research data. 77 | 78 | Returns: 79 | tuple: A tuple containing the prompt and the ground truth 80 | for method recommendation.
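Example (illustrative; the field values below are placeholders, not real data):
    >>> data = {"context": "...", "key_idea": "...", "method": "..."}
    >>> prompt, ground_truth = method_recommendation(data)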
81 | """ 82 | context = data['context'] 83 | key_idea = data['key_idea'] 84 | prompt = f"Given the context: '{context}' and the key idea: '{key_idea}', \ 85 | recommend the most suitable method to validate this idea. " 86 | ground_truth = data['method'] 87 | return prompt, ground_truth 88 | 89 | 90 | def outcome_prediction(data): 91 | """ 92 | Predict the potential outcome of a research. 93 | 94 | Args: 95 | data (dict): Contains 'context', 'key_idea', 'method', and 'outcome'. 96 | 97 | Returns: 98 | tuple: A tuple containing the prompt and the ground truth 99 | for outcome prediction. 100 | """ 101 | context = data['context'] 102 | key_idea = data['key_idea'] 103 | method = data['method'] 104 | prompt = f"Based on the context: '{context}', the key idea: '{key_idea}', \ 105 | and the recommended method: '{method}', \ 106 | predict the potential outcome of this research. " 107 | ground_truth = data['outcome'] 108 | return prompt, ground_truth 109 | 110 | 111 | def future_work_recommendation(data): 112 | """ 113 | Suggest projected impact for the research. 114 | 115 | Args: 116 | data (dict): Contains 'context', 'key_idea', 'method', 'outcome', \ 117 | and 'future_impact' from the research data. 118 | 119 | Returns: 120 | tuple: A tuple containing the prompt and the ground truth 121 | for future work. 122 | """ 123 | context = data['context'] 124 | key_idea = data['key_idea'] 125 | method = data['method'] 126 | outcome = data['outcome'] 127 | prompt = f"Based on the context: '{context}', the key idea: '{key_idea}', \ 128 | the method: '{method}', and the outcome: '{outcome}', \ 129 | suggest projected impact for this research." 130 | ground_truth = data.get('future_impact', '') 131 | return prompt, ground_truth 132 | 133 | 134 | def predict_title(data): 135 | """ 136 | Predict the title of a research paper. 137 | 138 | Args: 139 | data (dict): Contains all necessary information from the research data. 140 | 141 | Returns: 142 | tuple: A tuple containing the prompt and the ground trut 143 | for title prediction. 144 | """ 145 | context = data['context'] 146 | key_idea = data['key_idea'] 147 | method = data['method'] 148 | outcome = data['outcome'] 149 | future_impact = data['future_impact'] 150 | prompt = f"Given the context: '{context}', the key idea: '{key_idea}', \ 151 | the method: '{method}', the outcome: '{outcome}', \ 152 | and the future impact: '{future_impact}', \ 153 | predict the title of this research paper. \ 154 | The title should be concise and reflective of the core aspects." 155 | ground_truth = data.get('title', '') 156 | return prompt, ground_truth 157 | -------------------------------------------------------------------------------- /prompts/context.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "role": "system", 4 | "content": "You are an expert in scientific research and can evaluate the semantic similarity between pairs of context aspect summarizations from scientific papers. The context refers to the status quo of related literature or reality which motivated the study. It typically includes a problem, a research question, or a research gap that has not been successfully addressed by previous work. You will be provided with two inputs, and your task is to assign a similarity score on a scale from 1 to 10, where 1 indicates completely dissimilar content and 10 indicates identical content. Before scoring, provide a brief justification for your score. 
5 | }, 6 | { 7 | "role": "user", 8 | "content": "Input 1:Non-pharmacological interventions are the most common and the most effective for people with dementia. Using biographical or personalized materials has been proven to benefit these interventions, but it can be challenging to obtain and maintain such information.\nInput 2:The authors create a web platform that supports the work of psychologists, streamlining the collection of relevant information about people with dementia." 9 | }, 10 | { 11 | "role": "assistant", 12 | "content": "{\n \"justification\": \"The similarity between the two inputs is minimal. Both inputs only have one key point in common: people with dementia. Input 1 focuses on the effectiveness of non-pharmacological interventions, particularly those involving biographical or personalized materials, and mentions the challenges of maintaining such information. Input 2, on the other hand, talks about the creation of a web platform to streamline the collection of relevant information about people with dementia, which may indirectly relate to the challenges mentioned in Input 1 but does not directly address non-pharmacological interventions or their effectiveness, leading to a score of 1.\", \n \"score\": 1\n}" 13 | }, 14 | { 15 | "role": "user", 16 | "content": "Input 1:Non-pharmacological interventions are the most common and the most effective for people with dementia. Using biographical or personalized materials has been proven to benefit these interventions, but it can be challenging to obtain and maintain such information.\nInput 2:Non-pharmacological interventions are the most common and arguably most effective for people with dementia." 17 | }, 18 | { 19 | "role": "assistant", 20 | "content": "{\n \"justification\": \"Both inputs discuss two key points: non-pharmacological interventions and their effectiveness for people with dementia. However, Input 1 goes into more detail by mentioning the use of biographical or personalized materials and the challenges associated with obtaining and maintaining them. Input 2 does not include this additional information, leading to a score of 4.\", \n \"score\": 4\n}" 21 | }, 22 | { 23 | "role": "user", 24 | "content": "Input 1:Non-pharmacological interventions are the most common and the most effective for people with dementia. Using biographical or personalized materials has been proven to benefit these interventions, but it can be challenging to obtain and maintain such information.\nInput 2:Non-pharmacological interventions, particularly those utilizing biographical or personalized materials, are effective for people with dementia." 25 | }, 26 | { 27 | "role": "assistant", 28 | "content": "{\n \"justification\": \"Both inputs discuss three key points: non-pharmacological interventions, their effectiveness for people with dementia, and the use of biographical or personalized materials. However, Input 1 goes further to highlight two more key points: the challenges of obtaining and of maintaining such information, leading to a score of 6.\", \n \"score\": 6\n}" 29 | }, 30 | { 31 | "role": "user", 32 | "content": "Input 1:Non-pharmacological interventions are the most common and the most effective for people with dementia. Using biographical or personalized materials has been proven to benefit these interventions, but it can be challenging to obtain and maintain such information.\nInput 2:Some interventions effective for people with dementia have been proven to benefit from the use of biographical or personalized materials. However, these materials are not always easy to obtain and can be challenging to maintain."
33 | }, 34 | { 35 | "role": "assistant", 36 | "content": "{\n \"justification\": \"Both inputs discuss the effectiveness of interventions for people with dementia and highlight the benefits of using biographical or personalized materials. They also both note the challenges associated with obtaining and maintaining such information. The only difference is that Input 2 misses one key point: non-pharmacological interventions, leading to a score of 8.\", \n \"score\": 8\n}" 37 | }, 38 | { 39 | "role": "user", 40 | "content": "Input 1:Non-pharmacological interventions are the most common and the most effective for people with dementia. Using biographical or personalized materials has been proven to benefit these interventions, but it can be challenging to obtain and maintain such information.\nInput 2:Non-pharmacological interventions, such as reminiscence and biographical cognitive stimulation practices, are common and effective for people with dementia. However, obtaining and maintaining biographical or personalized materials can be challenging." 41 | }, 42 | { 43 | "role": "assistant", 44 | "content": "{\n \"justification\": \"Both inputs discuss the effectiveness of non-pharmacological interventions for people with dementia, specifically mentioning the use of biographical or personalized materials. They also both highlight the challenge of obtaining and maintaining such information. The semantic content is nearly identical, leading to a score of 10.\", \n \"score\": 10\n}" 45 | } 46 | ] -------------------------------------------------------------------------------- /benchmark/aspect_prediction/task.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script facilitates the automated benchmarking of research prompts. 3 | 4 | The script supports multiple models and prompt types 5 | and is designed to work with large sets 6 | of test data asynchronously. 7 | """ 8 | 9 | import argparse 10 | import os 11 | import sys 12 | 13 | import jsonlines as jl 14 | from prompts import (SYSTEM_PROMPT, future_work_recommendation, 15 | idea_generation, method_recommendation, 16 | outcome_prediction, predict_title) 17 | from utils import (MODEL_CHOICES, PROMPT_CHOICES, allow_self_signed_https, 18 | load_examples, save_results) 19 | 20 | from massw.models import gpt_azure, mixtral_azure 21 | 22 | sys.path.append("../..") 23 | 24 | allow_self_signed_https(True) 25 | 26 | few_shot_examples, cot_examples = load_examples() 27 | 28 | 29 | def prepare_messages(model, task_name, prompt_type, main_prompt): 30 | """Prepare the messages based on the task and prompt type.""" 31 | messages = [{"role": "system", "content": SYSTEM_PROMPT}] 32 | if model == "mixtral-8x7b": 33 | format_instruction = "\nNever use double quotes in your output. \ 34 | Use single quotes instead.\n" 35 | 36 | messages = [{ 37 | "role": "user", 38 | "content": SYSTEM_PROMPT + format_instruction 39 | }, { 40 | "role": 41 | "assistant", 42 | "content": 43 | "I got it. \ 44 | Please give me further instructions!"
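# The scripted acknowledgement above primes the model: the Mixtral deployment
# targeted here is given no system role, so the system prompt is folded into
# an initial user/assistant exchange instead.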
45 | }] 46 | 47 | if prompt_type in ["few-shot", "few-shot-cot"]: 48 | examples = few_shot_examples if prompt_type == "few-shot"\ 49 | else cot_examples 50 | for example in examples.get(task_name, []): 51 | messages.extend([{ 52 | "role": "user", 53 | "content": example["user"] 54 | }, { 55 | "role": "assistant", 56 | "content": example["assistant"] 57 | }]) 58 | 59 | if prompt_type == "chain-of-thought": 60 | main_prompt += "Let's think step by step. \ 61 | You should first present your reasoning. \ 62 | After that, the final prediction should start after \ 63 | the marker 'Prediction:'." 64 | 65 | messages.append({"role": "user", "content": main_prompt}) 66 | 67 | return messages 68 | 69 | 70 | def process_task(generate_prompt_fn, test_cases, task_name, **kwargs): 71 | """Process the tasks and retrieve chat completions.""" 72 | messages = [] 73 | for entry in test_cases: 74 | main_prompt, _ = generate_prompt_fn(entry) 75 | message = prepare_messages(kwargs['model'], task_name, 76 | kwargs['prompt_type'], main_prompt) 77 | messages.append((entry['pid'], message)) 78 | 79 | model = kwargs['model'] 80 | if model == "mixtral-8x7b": 81 | chat_results = mixtral_azure.prompts_to_raw_output(messages) 82 | elif model in ["gpt-35-turbo", "gpt-4"]: 83 | chat_results = gpt_azure.prompts_to_raw_output(messages, model, 84 | kwargs.get('tpm')) 85 | else: 86 | raise ValueError(f"Model {model} not supported. \ 87 | You can modify the code here \ 88 | to support custom models.") 89 | 90 | return chat_results 91 | 92 | 93 | def main(): 94 | """Execute main function to process tasks.""" 95 | parser = argparse.ArgumentParser(description="Process benchmarking \ 96 | of academic paper prompts.") 97 | parser.add_argument("--test_data", 98 | type=str, 99 | default="data/benchmark_0531.jsonl") 100 | parser.add_argument("--output_dir", type=str, default=None) 101 | parser.add_argument("--model", type=str, default="gpt-35-turbo") 102 | parser.add_argument("--prompt", type=str, default="zero-shot") 103 | parser.add_argument("--num_samples", type=int, default=5) 104 | args = parser.parse_args() 105 | 106 | if args.model not in MODEL_CHOICES: 107 | raise ValueError(f"Model {args.model} not supported. \ 108 | Choose from {MODEL_CHOICES}") 109 | 110 | if args.prompt not in PROMPT_CHOICES: 111 | raise ValueError(f"Prompt type {args.prompt} not supported. 
\ 112 | Choose from {PROMPT_CHOICES}") 113 | 114 | if not args.output_dir: 115 | args.output_dir = os.path.join("benchmark", "aspect_prediction", 116 | "outputs", 117 | f"{args.model}_{args.prompt}") 118 | 119 | # Load test data 120 | with jl.open(args.test_data) as file: 121 | test_data = [ 122 | record for record, _ in zip(file, range(args.num_samples)) 123 | ] 124 | 125 | tasks = [("idea_generation", idea_generation), 126 | ("method_recommendation", method_recommendation), 127 | ("outcome_prediction", outcome_prediction), 128 | ("future_work_recommendation", future_work_recommendation), 129 | ("title_prediction", predict_title)] 130 | 131 | tokens_per_minute = { 132 | "gpt-35-turbo": 40000, 133 | "gpt-4": 10000, 134 | "mixtral-8x7b": None 135 | } 136 | 137 | for task_name, generate_prompt_fn in tasks: 138 | print(f"Processing task: {task_name}") 139 | chat_results = process_task(generate_prompt_fn, 140 | test_data, 141 | task_name, 142 | model=args.model, 143 | prompt_type=args.prompt, 144 | tpm=tokens_per_minute[args.model]) 145 | print(f"{chat_results = }") 146 | save_results(chat_results, args.output_dir, task_name) 147 | 148 | 149 | if __name__ == "__main__": 150 | main() 151 | -------------------------------------------------------------------------------- /massw/data.py: -------------------------------------------------------------------------------- 1 | """Data loading and processing functions.""" 2 | 3 | import os 4 | from dataclasses import dataclass 5 | from typing import List, Union 6 | 7 | import jsonlines as jl 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from massw.download import download_dataset 12 | 13 | PROJECT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | DATA_DIR = os.path.join(PROJECT_DIR, "data") 15 | 16 | 17 | @dataclass 18 | class Workflow: 19 | """Workflow dataclass.""" 20 | 21 | id: str # Unique ID 22 | title: str # Title of the publication 23 | abstract: str # Abstract of the publication 24 | year: int # Year of publication 25 | venue: str # Publication venue 26 | context: str # Context of the workflow 27 | key_idea: str # Key idea of the workflow 28 | method: str # Method used in the workflow 29 | outcome: str # Outcome of the workflow 30 | projected_impact: str # Projected impact of the workflow 31 | data: dict = None # Additional data 32 | 33 | 34 | class MASSWDataset: 35 | """MASSW dataset class.""" 36 | 37 | def __init__(self, data: pd.DataFrame, metadata: pd.DataFrame): 38 | """Initialize the dataset. 39 | 40 | Args: 41 | data (pd.DataFrame): The data containing the workflows. 42 | metadata (pd.DataFrame): The metadata containing the workflow 43 | information. 44 | """ 45 | self.merged_data = data.join(metadata.set_index("id"), on="id") 46 | self.merged_data = self.merged_data.set_index("id") 47 | self.merged_data = self.merged_data.fillna(np.nan).replace([np.nan], 48 | [None]) 49 | 50 | def __len__(self): 51 | """Return the number of workflows in the dataset.""" 52 | return len(self.merged_data) 53 | 54 | def _get_by_position(self, pos: int) -> Workflow: 55 | """Return a workflow by its position. 56 | 57 | Args: 58 | pos (int): The position of the workflow. 59 | 60 | Returns: 61 | Workflow: The workflow object. 
62 | """ 63 | row = self.merged_data.iloc[pos] 64 | return Workflow(id=self.merged_data.index[pos], 65 | title=row["title"], 66 | abstract=row["abstract"], 67 | year=row["year"], 68 | venue=row["venue"], 69 | context=row["context"], 70 | key_idea=row["key_idea"], 71 | method=row["method"], 72 | outcome=row["outcome"], 73 | projected_impact=row["projected_impact"], 74 | data=row["data"]) 75 | 76 | def _get_by_unique_id(self, id_my: str) -> Workflow: 77 | """Return a workflow by its unique ID. 78 | 79 | Args: 80 | id_my (str): The unique ID of the workflow. 81 | 82 | Returns: 83 | Workflow: The workflow object. 84 | """ 85 | row = self.merged_data.loc[id_my] 86 | return Workflow(id=id_my, 87 | title=row["title"], 88 | abstract=row["abstract"], 89 | year=row["year"], 90 | venue=row["venue"], 91 | context=row["context"], 92 | key_idea=row["key_idea"], 93 | method=row["method"], 94 | outcome=row["outcome"], 95 | projected_impact=row["projected_impact"], 96 | data=row["data"]) 97 | 98 | def __getitem__(self, key): 99 | """Return a workflow by its unique ID or position.""" 100 | if isinstance(key, int): 101 | return self._get_by_position(key) 102 | if isinstance(key, str): 103 | return self._get_by_unique_id(key) 104 | raise TypeError("Invalid key type.") 105 | 106 | def __iter__(self): 107 | """Return an iterator over the workflows.""" 108 | for idx in range(len(self)): 109 | yield self[idx] 110 | 111 | def __repr__(self): 112 | """Return a string representation of the dataset.""" 113 | return f"MASSWDataset({len(self)} workflows)" 114 | 115 | def search(self, 116 | query: str, 117 | return_ids=False) -> Union[List[Workflow], List[str]]: 118 | """Search for workflows containing a query string by title. 119 | 120 | Args: 121 | query (str): The query string to search for. 122 | return_ids (bool): Whether to return the IDs of the workflows. 123 | If true, returns a list of IDs. 124 | Otherwise, returns a list of Workflows. 125 | 126 | Returns: 127 | list: A list of workflows or IDs containing the query string. 128 | """ 129 | mask = self.merged_data["title"].str.contains(query, 130 | case=False, 131 | na=False) 132 | ids = mask[mask].index 133 | if return_ids: 134 | return ids.tolist() 135 | return [self._get_by_unique_id(id) for id in ids] 136 | 137 | 138 | def load_massw(version: str = "v1") -> MASSWDataset: 139 | """Load the massw dataset. 140 | 141 | Args: 142 | version (str): The version of the dataset to load. 143 | 144 | Returns: 145 | MASSWDataset: The MASSW dataset object. 
146 | """ 147 | data_path = os.path.join(DATA_DIR, f"massw_{version}.tsv") 148 | metadata_path = os.path.join(DATA_DIR, f"massw_metadata_{version}.jsonl") 149 | if not os.path.exists(data_path) or not os.path.exists(metadata_path): 150 | download_dataset(version) 151 | data = pd.read_csv(data_path, sep="\t") 152 | metadata = [] 153 | with jl.open(metadata_path) as f: 154 | for line in f: 155 | metadata.append(line) 156 | metadata = pd.DataFrame(metadata) 157 | return MASSWDataset(data, metadata) 158 | -------------------------------------------------------------------------------- /factscore/download_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | import torch 5 | import tqdm 6 | import transformers 7 | 8 | 9 | def download_file(_id, dest, cache_dir): 10 | if os.path.exists(dest) or os.path.exists(os.path.join(cache_dir, dest)): 11 | print ("[Already exists] Skipping", dest) 12 | print ("If you want to download the file in another location, please specify a different path") 13 | return 14 | 15 | if os.path.exists(dest.replace(".zip", "")) or os.path.exists(os.path.join(cache_dir, dest.replace(".zip", ""))): 16 | print ("[Already exists] Skipping", dest) 17 | print ("If you want to download the file in another location, please specify a different path") 18 | return 19 | 20 | if "/" in dest: 21 | dest_dir = "/".join(dest.split("/")[:-1]) 22 | if not os.path.isdir(dest_dir): 23 | os.makedirs(dest_dir) 24 | else: 25 | dest_dir = "." 26 | 27 | if _id.startswith("https://"): 28 | command = """wget -O %s %s""" % (dest, _id) 29 | else: 30 | command = """wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=%s' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id=%s" -O %s && rm -rf /tmp/cookies.txt""" % (_id, _id, dest) 31 | 32 | ret_code = subprocess.run([command], shell=True) 33 | if ret_code.returncode != 0: 34 | print("Download {} ... [Failed]".format(dest)) 35 | else: 36 | print("Download {} ... [Success]".format(dest)) 37 | 38 | if dest.endswith(".zip"): 39 | command = """unzip %s -d %s && rm %s""" % (dest, dest_dir, dest) 40 | 41 | ret_code = subprocess.run([command], shell=True) 42 | if ret_code.returncode != 0: 43 | print("Unzip {} ... [Failed]".format(dest)) 44 | else: 45 | print("Unzip {} ... [Success]".format(dest)) 46 | 47 | 48 | 49 | def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model): 50 | """Resize tokenizer and embedding. 51 | 52 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
53 | """ 54 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 55 | model.resize_token_embeddings(len(tokenizer)) 56 | 57 | if num_new_tokens > 0: 58 | input_embeddings = model.get_input_embeddings().weight.data 59 | output_embeddings = model.get_output_embeddings().weight.data 60 | 61 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 62 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 63 | 64 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 65 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 66 | 67 | 68 | def recover_instruct_llama(path_raw, output_path, device="cpu", test_recovered_model=False): 69 | """Heavily adapted from https://github.com/tatsu-lab/stanford_alpaca/blob/main/weight_diff.py.""" 70 | 71 | model_raw = transformers.AutoModelForCausalLM.from_pretrained( 72 | path_raw, 73 | device_map={"": torch.device(device)}, 74 | torch_dtype=torch.float32, 75 | low_cpu_mem_usage=True, 76 | ) 77 | model_recovered = transformers.AutoModelForCausalLM.from_pretrained( 78 | "kalpeshk2011/instruct-llama-7b-wdiff", 79 | device_map={"": torch.device(device)}, 80 | torch_dtype=torch.float32, 81 | low_cpu_mem_usage=True, 82 | ) 83 | 84 | tokenizer_raw = transformers.AutoTokenizer.from_pretrained(path_raw) 85 | if tokenizer_raw.pad_token is None: 86 | smart_tokenizer_and_embedding_resize( 87 | special_tokens_dict=dict(pad_token="[PAD]"), 88 | model=model_raw, 89 | tokenizer=tokenizer_raw, 90 | ) 91 | tokenizer_recovered = transformers.AutoTokenizer.from_pretrained("kalpeshk2011/instruct-llama-7b-wdiff") 92 | 93 | state_dict_recovered = model_recovered.state_dict() 94 | state_dict_raw = model_raw.state_dict() 95 | for key in tqdm.tqdm(state_dict_recovered): 96 | state_dict_recovered[key].add_(state_dict_raw[key]) 97 | 98 | if output_path is not None: 99 | model_recovered.save_pretrained(output_path) 100 | tokenizer_recovered.save_pretrained(output_path) 101 | 102 | if test_recovered_model: 103 | input_text = ( 104 | "Below is an instruction that describes a task. 
" 105 | "Write a response that appropriately completes the request.\r\n\r\n" 106 | "### Instruction:\r\nList three technologies that make life easier.\r\n\r\n### Response:" 107 | ) 108 | inputs = tokenizer_recovered(input_text, return_tensors="pt") 109 | out = model_recovered.generate(inputs=inputs.input_ids, max_new_tokens=100) 110 | output_text = tokenizer_recovered.batch_decode(out, skip_special_tokens=True)[0] 111 | output_text = output_text[len(input_text) :] 112 | print(f"Input: {input_text}\nCompletion: {output_text}") 113 | 114 | return model_recovered, tokenizer_recovered 115 | 116 | if __name__ == '__main__': 117 | 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument('--data_dir', 120 | type=str, 121 | default=".cache/factscore") 122 | parser.add_argument('--model_dir', 123 | type=str, 124 | default=".cache/factscore") 125 | parser.add_argument('--llama_7B_HF_path', 126 | type=str, 127 | default=None) 128 | 129 | args = parser.parse_args() 130 | 131 | if not os.path.exists(args.model_dir): 132 | os.makedirs(args.model_dir) 133 | 134 | if not os.path.exists(args.data_dir): 135 | os.makedirs(args.data_dir) 136 | 137 | download_file("1IseEAflk1qqV0z64eM60Fs3dTgnbgiyt", "demos.zip", args.data_dir) 138 | download_file("1enz1PxwxeMr4FRF9dtpCPXaZQCBejuVF", "data.zip", args.data_dir) 139 | download_file("1mekls6OGOKLmt7gYtHs0WGf5oTamTNat", "enwiki-20230401.db", args.data_dir) 140 | 141 | if args.llama_7B_HF_path: 142 | recover_instruct_llama(args.llama_7B_HF_path, os.path.join(args.model_dir, "inst-llama-7B")) 143 | 144 | # download the roberta_stopwords.txt file 145 | subprocess.run(["wget https://raw.githubusercontent.com/shmsw25/FActScore/main/roberta_stopwords.txt"], shell=True) 146 | 147 | # move the files to the data directory 148 | subprocess.run(["mv demos %s" % args.data_dir], shell=True) 149 | subprocess.run(["mv enwiki-20230401.db %s" % args.data_dir], shell=True) 150 | 151 | -------------------------------------------------------------------------------- /factscore.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import hashlib\n", 10 | "import json\n", 11 | "import os\n", 12 | "import sys\n", 13 | "\n", 14 | "import pandas as pd\n", 15 | "\n", 16 | "import factscore\n", 17 | "import factscore.factscorer\n", 18 | "\n", 19 | "sys.path.append(\"benchmark/aspect_prediction\")\n", 20 | "from utils import TASK2GT, TASK_NAMES, postprocess_cot\n", 21 | "from massw.models import gpt_azure, mixtral_azure\n", 22 | "\n", 23 | "\n", 24 | "def evaluate_factscore(generations, contexts):\n", 25 | " hash_key = hashlib.md5(\"\\n\".join(generations + contexts).encode()).hexdigest()\n", 26 | " cache_dir = f\"./cache/factscore/{hash_key}\"\n", 27 | " fs = factscore.factscorer.FactScorer(cache_dir=cache_dir)\n", 28 | " dummy_topics = [None] * len(generations)\n", 29 | " scores = fs.get_score(dummy_topics, generations, contexts, verbose=False, gamma=0)\n", 30 | " return scores\n", 31 | "\n", 32 | "\n", 33 | "DATA_FOLDER = \"../massw/data\"" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "\"\"\"Evaluate experiment results based on the model generated output (file).\"\"\"\n", 43 | "model_output_base_dir = \"../massw/benchmark/aspect_prediction/outputs\"\n", 44 | "reference_path = f\"{DATA_FOLDER}/benchmark_0531.jsonl\"\n", 
45 | "\n", 46 | "# Read reference data\n", 47 | "df_ref = pd.read_json(reference_path, lines=True)\n", 48 | "sample_ids = df_ref[\"pid\"][:100].tolist()\n", 49 | "\n", 50 | "\n", 51 | "def postprocess_output(\n", 52 | " model_output_dir, reference_path, used_cot=False, model_type=\"gpt\"\n", 53 | "):\n", 54 | " \"\"\"\n", 55 | " Process model output files to match predictions with references.\n", 56 | "\n", 57 | " Args:\n", 58 | " model_output_dir (str): Directory containing the output files.\n", 59 | " reference_path (str): Path to the file containing reference data.\n", 60 | " used_cot (bool): Flag to determine if COT processing is needed.\n", 61 | " model_type (str): Type of model used to adjust processing logic.\n", 62 | "\n", 63 | " Returns:\n", 64 | " dict: A dictionary containing predictions and references by task.\n", 65 | " \"\"\"\n", 66 | " results = {}\n", 67 | " with open(reference_path, \"r\", encoding=\"utf-8\") as f:\n", 68 | " references = [json.loads(line) for line in f]\n", 69 | " id2ref = {r[\"pid\"]: r for r in references}\n", 70 | "\n", 71 | " for task_name in TASK_NAMES:\n", 72 | " gt_name = TASK2GT[task_name]\n", 73 | " model_path = f\"{model_output_dir}/{task_name}.tsv\"\n", 74 | "\n", 75 | " if model_type == \"gpt\":\n", 76 | " id2predictions = gpt_azure.raw_output_to_dict(model_path)\n", 77 | " elif model_type == \"mixtral\":\n", 78 | " id2predictions = mixtral_azure.raw_output_to_dict(model_path)\n", 79 | " else:\n", 80 | " raise ValueError(f\"Model type {model_type} not supported.\")\n", 81 | "\n", 82 | " if used_cot:\n", 83 | " for pid in id2predictions:\n", 84 | " try:\n", 85 | " id2predictions[pid] = postprocess_cot(id2predictions[pid])\n", 86 | " except Exception as e:\n", 87 | " print(f\"Error processing COT for {pid}: {e}\")\n", 88 | " continue\n", 89 | "\n", 90 | " id2predictions = {\n", 91 | " pid: id2predictions[pid] for pid in id2predictions if pid in sample_ids\n", 92 | " }\n", 93 | "\n", 94 | " results[task_name] = {\n", 95 | " \"predictions\": list(id2predictions.values()),\n", 96 | " \"references\": [\n", 97 | " id2ref[pid][gt_name] for pid in id2ref.keys() if pid in id2predictions\n", 98 | " ],\n", 99 | " }\n", 100 | " assert len(results[task_name][\"predictions\"]) == len(\n", 101 | " results[task_name][\"references\"]\n", 102 | " )\n", 103 | "\n", 104 | " return results" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": null, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# List all directories in the model output base directory\n", 114 | "dump_folder = \"./results/\"\n", 115 | "model_output_dirs = os.listdir(model_output_base_dir)\n", 116 | "for d in model_output_dirs:\n", 117 | " model_output_path = f\"{model_output_base_dir}/{d}\"\n", 118 | " print(f\"Processing {model_output_path}\")\n", 119 | " used_cot = \"chain\" in d\n", 120 | " model_type = \"gpt\" if \"gpt\" in d else \"mixtral\"\n", 121 | " results = postprocess_output(model_output_path, reference_path, used_cot, model_type)\n", 122 | " print(f\"Model: {d}\")\n", 123 | " for task_name in results:\n", 124 | " predictions = results[task_name][\"predictions\"]\n", 125 | " references = results[task_name][\"references\"]\n", 126 | " print(f\"Task: {task_name}, predictions: {len(predictions)}, references: {len(references)}\")\n", 127 | " scores = evaluate_factscore(predictions, references)\n", 128 | " with open(f\"{dump_folder}/benchmark/{d}_{task_name}_fs.json\", \"w\") as f:\n", 129 | " json.dump(scores, f)\n" 130 | ] 131 | }, 132 | { 133 | 
"cell_type": "code", 134 | "execution_count": null, 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "benchmark_factscore_dir = \"results/benchmark\"\n", 139 | "\n", 140 | "benchmark_factscore = {}\n", 141 | "for name in model_output_dirs:\n", 142 | " benchmark_factscore[name] = {}\n", 143 | " for task_name in TASK_NAMES:\n", 144 | " dump_path = f\"{benchmark_factscore_dir}/{name}_{task_name}_fs.json\"\n", 145 | " with open(dump_path, \"r\") as f:\n", 146 | " scores = json.load(f)\n", 147 | " print(f\"{name} {task_name}: {scores['score']:0.3f}\")\n", 148 | " benchmark_factscore[name][task_name] = scores[\"score\"]" 149 | ] 150 | } 151 | ], 152 | "metadata": { 153 | "kernelspec": { 154 | "display_name": "llm", 155 | "language": "python", 156 | "name": "python3" 157 | }, 158 | "language_info": { 159 | "codemirror_mode": { 160 | "name": "ipython", 161 | "version": 3 162 | }, 163 | "file_extension": ".py", 164 | "mimetype": "text/x-python", 165 | "name": "python", 166 | "nbconvert_exporter": "python", 167 | "pygments_lexer": "ipython3", 168 | "version": "3.9.19" 169 | } 170 | }, 171 | "nbformat": 4, 172 | "nbformat_minor": 2 173 | } 174 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. 
Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. 
Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 
122 | -------------------------------------------------------------------------------- /factscore/npm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import time 4 | from collections import defaultdict 5 | from transformers import AutoModelForMaskedLM, AutoTokenizer 6 | 7 | from factscore.lm import LM 8 | from factscore.retrieval import Retrieval 9 | 10 | def softmax(x): 11 | return(np.exp(x - np.max(x)) / np.exp(x - np.max(x)).sum()) 12 | 13 | class NPM(LM): 14 | 15 | def __init__(self, bm25, model_name, cache_file): 16 | assert model_name.startswith("npm") 17 | self.bm25 = bm25 18 | self.model_name = model_name 19 | self.model = None 20 | 21 | self.tokenizer = AutoTokenizer.from_pretrained("facebook/" + self.model_name) 22 | self.mask_id = self.tokenizer.mask_token_id 23 | 24 | with open("roberta_stopwords.txt", "r") as f: 25 | self.stopwords = set() 26 | for line in f: 27 | self.stopwords.add(int(line.strip())) 28 | 29 | super().__init__(cache_file=cache_file) 30 | 31 | def load_model(self): 32 | self.model = AutoModelForMaskedLM.from_pretrained("facebook/" + self.model_name) 33 | self.model.cuda() 34 | self.model.eval() 35 | 36 | def save_cache(self): 37 | super().save_cache() 38 | self.bm25.save_cache() 39 | 40 | def tokenize(self, texts, skip_special_tokens=False, padding=True): 41 | assert type(texts)==list 42 | all_input_ids = self.tokenizer(texts)["input_ids"] 43 | if skip_special_tokens: 44 | for i, input_ids in enumerate(all_input_ids): 45 | assert input_ids[0]==0 and input_ids[-1]==2 46 | all_input_ids[i] = input_ids[1:-1] 47 | if not padding: 48 | return all_input_ids 49 | max_length = np.max([len(_ids) for _ids in all_input_ids]) 50 | _all_input_ids = [] 51 | _all_attention_mask = [] 52 | for i, input_ids in enumerate(all_input_ids): 53 | n_valid = len(input_ids) 54 | n_masks = max_length - n_valid 55 | _all_input_ids.append(input_ids + [0 for _ in range(n_masks)]) 56 | _all_attention_mask.append([1 for _ in range(n_valid)] + [0 for _ in range(n_masks)]) 57 | return torch.LongTensor(_all_input_ids), torch.LongTensor(_all_attention_mask) 58 | 59 | def decode(self, input_ids): 60 | return self.tokenizer.decode(input_ids) 61 | 62 | def encode(self, texts, skip_special_tokens=False, gt_input_ids=None): 63 | assert type(texts)==list 64 | if self.model is None: 65 | self.load_model() 66 | if gt_input_ids is not None: 67 | assert len(texts)==len(gt_input_ids) 68 | all_input_ids, all_attention_mask = self.tokenize(texts, skip_special_tokens=skip_special_tokens) 69 | 70 | with torch.no_grad(): 71 | outputs = self.model(all_input_ids.cuda(), 72 | all_attention_mask.cuda(), 73 | output_hidden_states=True, 74 | return_dict=True) 75 | all_logits = outputs["logits"].detach().cpu().numpy() 76 | all_hidden_states = outputs["hidden_states"][-1].detach().cpu().numpy() 77 | 78 | results = [] 79 | for i, (text, input_ids, logits, hidden_states) in enumerate(zip(texts, all_input_ids, all_logits, all_hidden_states)): 80 | input_ids = input_ids.numpy().tolist() 81 | if self.mask_id in input_ids: 82 | idx = input_ids.index(self.mask_id) 83 | assert gt_input_ids is not None 84 | prob = softmax(logits[idx])[gt_input_ids[i]] 85 | results.append((prob, hidden_states[idx])) 86 | else: 87 | _input_ids = [_id for _id in input_ids if _id not in [0, 2]] 88 | _hidden_states = [h for _id, h in zip(input_ids, hidden_states) if _id not in [0, 2]] 89 | results.append((_input_ids, _hidden_states)) 90 | 91 | return results 92 | 93 | 
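    # A sketch of the scoring logic implemented below, as read from this file:
    # the claim is rendered as "Fact: <question>", and every non-stopword token
    # after the prefix (token id 35 marks the end of the prefix) is masked in
    # turn. For each masked position, the model's hidden state is compared, via
    # an exponentiated scaled inner product, against all token embeddings from
    # the top-3 BM25-retrieved passages; the probability of the ground-truth
    # token is the share of that similarity mass falling on its occurrences in
    # the passages (0 if it never appears). The masked-LM softmax probability
    # is also computed but not used in the final aggregation. The cached score
    # is the mean of these probabilities over all masked positions.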
def get_probabilty(self, topic, question): 94 | passages = self.bm25.get_passages(topic, question, k=3) 95 | passages = [p["text"].strip() for p in passages] 96 | cache_key = question + "#" + "#".join(passages) 97 | 98 | if cache_key not in self.cache_dict: 99 | encoded = self.encode(passages, skip_special_tokens=True) 100 | stacked_passage_tokens, stacked_passage_vectors = [], [] 101 | for input_ids, vectors in encoded: 102 | stacked_passage_tokens += input_ids 103 | if len(vectors)>0: 104 | stacked_passage_vectors.append(vectors) 105 | stacked_passage_vectors = np.concatenate(stacked_passage_vectors, 0) 106 | 107 | question_input_ids = self.tokenize(["Fact: " + question], skip_special_tokens=False, padding=False)[0] 108 | if 2 in question_input_ids: 109 | question_input_ids = question_input_ids[:question_input_ids.index(2)] 110 | question_input_ids = question_input_ids[1:] 111 | 112 | ''' 113 | triples = [] 114 | prefix = True 115 | for i, input_id in enumerate(question_input_ids): 116 | if prefix: 117 | if input_id==35: # the end of prefix 118 | prefix = False 119 | continue 120 | if input_id in [0, 2] or input_id in self.stopwords: 121 | continue 122 | new_question = self.decode(question_input_ids[:i] + [self.mask_id] + question_input_ids[i+1:]) 123 | prob, vector = self.encode(new_question, gt_input_id=input_id) 124 | triples.append((prob, vector, input_id)) 125 | ''' 126 | triples = [] 127 | batch = [] 128 | gt_input_ids = [] 129 | prefix = True 130 | for i, input_id in enumerate(question_input_ids): 131 | if prefix: 132 | if input_id==35: # the end of prefix 133 | prefix = False 134 | continue 135 | if input_id in [0, 2] or input_id in self.stopwords: 136 | continue 137 | batch.append(self.decode(question_input_ids[:i] + [self.mask_id] + question_input_ids[i+1:])) 138 | gt_input_ids.append(input_id) 139 | for (prob, vector), gt_input_id in zip(self.encode(batch, gt_input_ids=gt_input_ids), gt_input_ids): 140 | triples.append((prob, vector, gt_input_id)) 141 | 142 | stacked_question_vectors = np.stack([v for _, v, _ in triples], 0) 143 | all_scores = np.exp(np.inner(stacked_question_vectors, stacked_passage_vectors) / np.sqrt(stacked_passage_vectors.shape[-1])) 144 | 145 | probs = [] 146 | for (softmax_prob, vector, input_id), scores in zip(triples, all_scores): 147 | assert len(stacked_passage_tokens)==len(scores) 148 | if input_id not in stacked_passage_tokens: 149 | probs.append(0) 150 | else: 151 | aggregated_scores = defaultdict(list) 152 | for token, score in zip(stacked_passage_tokens, scores): 153 | aggregated_scores[token].append(score) 154 | tot = np.sum([np.sum(v) for v in aggregated_scores.values()]) 155 | prob = np.sum(aggregated_scores[input_id]) / tot 156 | probs.append(prob) 157 | 158 | self.cache_dict[cache_key] = np.mean(probs) 159 | self.add_n += 1 160 | 161 | return self.cache_dict[cache_key] 162 | 163 | 164 | 165 | -------------------------------------------------------------------------------- /massw/collection/aaai/script.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import requests 4 | import pandas as pd 5 | from bs4 import BeautifulSoup 6 | from tqdm import tqdm 7 | 8 | from massw.collection import BaseCollection 9 | 10 | 11 | def get_aaai_track_urls(archive_url, year: int): 12 | """ 13 | Retrieves AAAI track URLs for a given year from the specified archive URL. 14 | 15 | Args: 16 | archive_url (str): The base URL of the AAAI archive where technical tracks are stored. 
17 |         year (int): The year of the conference for which track URLs need to be extracted.
18 | 
19 |     Returns:
20 |         list: A list of URLs corresponding to different tracks for the given year.
21 | 
22 |     Caveat:
23 |         - The technical tracks for a given AAAI year may span multiple archive pages, so pagination must be handled
24 |         to ensure all track URLs are retrieved. A simple fix is to loop over the archive pages at the start of the collection step.
25 |     """
26 |     headers = {
27 |         "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
28 |         "AppleWebKit/537.36 (KHTML, like Gecko) "
29 |         "Chrome/112.0.0.0 Safari/537.36"
30 |     }
31 | 
32 |     print(f"Fetching archive page: {archive_url}")
33 |     response = requests.get(archive_url, headers=headers)
34 |     if response.status_code != 200:
35 |         print(f"Failed to fetch the archive page. Status code: {response.status_code}")
36 |         return []
37 |     soup = BeautifulSoup(response.text, "html.parser")
38 | 
39 |     track_info = []
40 |     issue_summaries = soup.select(".obj_issue_summary")
41 | 
42 |     print(f"Found {len(issue_summaries)} issue summaries")
43 |     conference_tag = f"AAAI-{str(year)[-2:]}"
44 | 
45 |     for summary in issue_summaries:
46 |         title_element = summary.select_one("a.title")
47 |         if title_element:
48 |             title = title_element.text.strip()
49 |             url = title_element["href"]
50 |             if conference_tag in title:
51 |                 track_info.append(url)
52 | 
53 |     print(f"Extracted {len(track_info)} track URLs.")
54 |     return track_info
55 | 
56 | 
57 | class AAAICollection(BaseCollection):
58 |     def __init__(
59 |         self,
60 |         year: int,
61 |         base_url: str = "https://ojs.aaai.org/index.php/AAAI/issue/archive",
62 |     ):  # All AAAI 2024 technical tracks are listed at this URL; it may need updating if AAAI releases more papers.
63 |         super().__init__(year, "aaai")
64 |         self.data_dir = os.path.join(os.path.dirname(__file__), f"data_{year}")
65 |         self.base_url = base_url
66 | 
67 |     def collect(self):
68 |         track_urls = get_aaai_track_urls(self.base_url, self.year)
69 | 
70 |         if not track_urls:
71 |             print("No track URLs found. Checking if base_url itself contains papers...")
72 |             track_urls = [self.base_url]
73 |         for track_index, track_url in enumerate(
74 |             tqdm(track_urls, desc="Processing tracks", position=0, leave=True)
75 |         ):
76 |             headers = {
77 |                 "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
78 |                 "AppleWebKit/537.36 (KHTML, like Gecko) "
79 |                 "Chrome/112.0.0.0 Safari/537.36"
80 |             }
81 |             response = requests.get(track_url, headers=headers)
82 |             if response.status_code != 200:
83 |                 print(
84 |                     f"Failed to fetch the track page. Status code: {response.status_code}"
85 |                 )
86 |                 continue
87 |             soup = BeautifulSoup(response.text, "html.parser")
88 |             paper_links = []
89 |             for article_div in soup.select(".obj_article_summary"):
90 |                 title_element = article_div.select_one(".title")
91 |                 if title_element and title_element.a:
92 |                     paper_title = title_element.a.text.strip()
93 |                     paper_url = title_element.a["href"]
94 |                     paper_links.append((paper_title, paper_url))
95 |             print(
96 |                 f"Found {len(paper_links)} papers in track {track_index + 1}. Starting to fetch details..."
97 | ) 98 | for i, (title, url) in enumerate(paper_links): 99 | time.sleep(1) 100 | try: 101 | paper_response = requests.get(url, headers=headers) 102 | if paper_response.status_code == 200: 103 | paper_soup = BeautifulSoup(paper_response.text, "html.parser") 104 | 105 | authors_div = paper_soup.select_one(".authors") 106 | authors = [] 107 | if authors_div: 108 | for author in authors_div.select(".name"): 109 | authors.append(author.text.strip()) 110 | abstract_div = paper_soup.select_one(".item.abstract") 111 | abstract_text = "" 112 | if abstract_div: 113 | label = abstract_div.select_one(".label") 114 | if label: 115 | label.decompose() 116 | abstract_text = abstract_div.get_text(strip=True) 117 | pdf_url = "" 118 | pdf_link = paper_soup.select_one("a.obj_galley_link.pdf") 119 | if pdf_link and "href" in pdf_link.attrs: 120 | pdf_url = pdf_link["href"] 121 | self.add_paper( 122 | title=title, 123 | authors=authors, 124 | abstract=abstract_text, 125 | pdf_url=pdf_url, 126 | url=url, 127 | ) 128 | else: 129 | print( 130 | f"Failed to fetch paper at {url}. Status code: {paper_response.status_code}" 131 | ) 132 | 133 | except Exception as e: 134 | print(f"Error processing paper {url}: {str(e)}") 135 | 136 | print(f"Completed track {track_index + 1}/{len(track_urls)}") 137 | self.save_metadata() 138 | 139 | def download_pdfs(self): 140 | """ 141 | Download PDF files for all papers in the collection. 142 | 143 | The PDFs are saved to {data_dir}/pdf/{pid}.pdf. 144 | """ 145 | headers = { 146 | "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " 147 | "AppleWebKit/537.36 (KHTML, like Gecko) " 148 | "Chrome/112.0.0.0 Safari/537.36" 149 | } 150 | pdf_dir = os.path.join(self.data_dir, "pdf") 151 | if not os.path.exists(pdf_dir): 152 | os.makedirs(pdf_dir) 153 | if os.path.exists(os.path.join(self.data_dir, "metadata.tsv")): 154 | path = os.path.join(self.data_dir, "metadata.tsv") 155 | df = pd.read_csv(path, sep="\t") 156 | for index, row in tqdm( 157 | df.iterrows(), total=len(df), desc="Downloading PDFs" 158 | ): 159 | pdf_url = row["pdf_url"] 160 | pid = row["pid"] 161 | pdf_path = os.path.join(pdf_dir, f"{pid}.pdf") 162 | response = requests.get(pdf_url, headers=headers, timeout=30) 163 | with open(pdf_path, "wb") as f: 164 | f.write(response.content) 165 | 166 | 167 | if __name__ == "__main__": 168 | collection = AAAICollection(year=2024) 169 | if not os.path.exists(os.path.join(collection.data_dir, "metadata.tsv")): 170 | collection.collect() 171 | collection.download_pdfs() 172 | -------------------------------------------------------------------------------- /factscore/retrieval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import os 4 | 5 | import sqlite3 6 | import numpy as np 7 | import pickle as pkl 8 | 9 | from rank_bm25 import BM25Okapi 10 | 11 | SPECIAL_SEPARATOR = "####SPECIAL####SEPARATOR####" 12 | MAX_LENGTH = 256 13 | 14 | class DocDB(object): 15 | """Sqlite backed document storage. 16 | 17 | Implements get_doc_text(doc_id). 18 | """ 19 | 20 | def __init__(self, db_path=None, data_path=None): 21 | self.db_path = db_path 22 | self.connection = sqlite3.connect(self.db_path, check_same_thread=False) 23 | 24 | cursor = self.connection.cursor() 25 | cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") 26 | 27 | if len(cursor.fetchall())==0: 28 | assert data_path is not None, f"{self.db_path} is empty. Specify `data_path` in order to create a DB." 29 | print (f"{self.db_path} is empty. 
Start building DB from {data_path}...")
 30 |             self.build_db(self.db_path, data_path)
 31 | 
 32 |     def __enter__(self):
 33 |         return self
 34 | 
 35 |     def __exit__(self, *args):
 36 |         self.close()
 37 | 
 38 |     def path(self):
 39 |         """Return the path to the file that backs this database."""
 40 |         return self.db_path
 41 | 
 42 |     def close(self):
 43 |         """Close the connection to the database."""
 44 |         self.connection.close()
 45 | 
 46 |     def build_db(self, db_path, data_path):
 47 |         from transformers import RobertaTokenizer
 48 |         tokenizer = RobertaTokenizer.from_pretrained("roberta-large")
 49 | 
 50 |         titles = set()
 51 |         output_lines = []
 52 |         tot = 0
 53 |         start_time = time.time()
 54 |         c = self.connection.cursor()
 55 |         c.execute("CREATE TABLE documents (title PRIMARY KEY, text);")
 56 | 
 57 |         with open(data_path, "r") as f:
 58 |             for line in f:
 59 |                 dp = json.loads(line)
 60 |                 title = dp["title"]
 61 |                 text = dp["text"]
 62 |                 if title in titles:
 63 |                     continue
 64 |                 titles.add(title)
 65 |                 if type(text)==str:
 66 |                     text = [text]
 67 |                 passages = [[]]
 68 |                 for sent_idx, sent in enumerate(text):
 69 |                     assert len(sent.strip())>0
 70 |                     tokens = tokenizer(sent)["input_ids"]
 71 |                     max_length = MAX_LENGTH - len(passages[-1])
 72 |                     if len(tokens) <= max_length:
 73 |                         passages[-1].extend(tokens)
 74 |                     else:
 75 |                         passages[-1].extend(tokens[:max_length])
 76 |                         offset = max_length
 77 |                         while offset < len(tokens):
 78 |                             passages.append(tokens[offset:offset+MAX_LENGTH])
 79 |                             offset += MAX_LENGTH
 80 | 
 81 |                 psgs = [tokenizer.decode(tokens) for tokens in passages if np.sum([t not in [0, 2] for t in tokens])>0]
 82 |                 text = SPECIAL_SEPARATOR.join(psgs)
 83 |                 output_lines.append((title, text))
 84 |                 tot += 1
 85 | 
 86 |                 if len(output_lines) == 1000000:
 87 |                     c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
 88 |                     output_lines = []
 89 |                     print ("Finished saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60))
 90 | 
 91 |         if len(output_lines) > 0:
 92 |             c.executemany("INSERT INTO documents VALUES (?,?)", output_lines)
 93 |             print ("Finished saving %dM documents (%dmin)" % (tot / 1000000, (time.time()-start_time)/60))
 94 | 
 95 |         self.connection.commit()
 96 |         self.connection.close()
 97 | 
 98 |     def get_text_from_title(self, title):
 99 |         """Fetch the raw text of the doc for 'doc_id'."""
100 |         cursor = self.connection.cursor()
101 |         cursor.execute("SELECT text FROM documents WHERE title = ?", (title,))
102 |         results = cursor.fetchall()
103 |         results = [r for r in results]
104 |         cursor.close()
105 |         assert results is not None and len(results)==1, f"`topic` in your data ({title}) is likely not a valid title in the DB."
106 |         results = [{"title": title, "text": para} for para in results[0][0].split(SPECIAL_SEPARATOR)]
107 |         assert len(results)>0, f"`topic` in your data ({title}) is likely not a valid title in the DB."
108 |         return results
109 | 
110 | class Retrieval(object):
111 | 
112 |     def __init__(self, db, cache_path, embed_cache_path,
113 |                  retrieval_type="gtr-t5-large", batch_size=None):
114 |         self.db = db
115 |         self.cache_path = cache_path
116 |         self.embed_cache_path = embed_cache_path
117 |         self.retrieval_type = retrieval_type
118 |         self.batch_size = batch_size
119 |         assert retrieval_type=="bm25" or retrieval_type.startswith("gtr-")
120 | 
121 |         self.encoder = None
122 |         self.load_cache()
123 |         self.add_n = 0
124 |         self.add_n_embed = 0
125 | 
126 |     def load_encoder(self):
127 |         from sentence_transformers import SentenceTransformer
128 |         encoder = SentenceTransformer("sentence-transformers/" + self.retrieval_type)
129 |         # encoder = encoder.cuda()
130 |         encoder = encoder.eval()
131 |         self.encoder = encoder
132 |         assert self.batch_size is not None
133 | 
134 |     def load_cache(self):
135 |         if os.path.exists(self.cache_path):
136 |             with open(self.cache_path, "r") as f:
137 |                 self.cache = json.load(f)
138 |         else:
139 |             self.cache = {}
140 |         if os.path.exists(self.embed_cache_path):
141 |             with open(self.embed_cache_path, "rb") as f:
142 |                 self.embed_cache = pkl.load(f)
143 |         else:
144 |             self.embed_cache = {}
145 | 
146 |     def save_cache(self):
147 |         if self.add_n > 0:
148 |             if os.path.exists(self.cache_path):
149 |                 with open(self.cache_path, "r") as f:
150 |                     new_cache = json.load(f)
151 |                 self.cache.update(new_cache)
152 | 
153 |             with open(self.cache_path, "w") as f:
154 |                 json.dump(self.cache, f)
155 | 
156 |         if self.add_n_embed > 0:
157 |             if os.path.exists(self.embed_cache_path):
158 |                 with open(self.embed_cache_path, "rb") as f:
159 |                     new_cache = pkl.load(f)
160 |                 self.embed_cache.update(new_cache)
161 | 
162 |             with open(self.embed_cache_path, "wb") as f:
163 |                 pkl.dump(self.embed_cache, f)
164 | 
165 |     def get_bm25_passages(self, topic, query, passages, k):
166 |         if topic in self.embed_cache:
167 |             bm25 = self.embed_cache[topic]
168 |         else:
169 |             bm25 = BM25Okapi([psg["text"].replace("<s>", "").replace("</s>", "").split() for psg in passages])
170 |             self.embed_cache[topic] = bm25
171 |             self.add_n_embed += 1
172 |         scores = bm25.get_scores(query.split())
173 |         indices = np.argsort(-scores)[:k]
174 |         return [passages[i] for i in indices]
175 | 
176 |     def get_gtr_passages(self, topic, retrieval_query, passages, k):
177 |         if self.encoder is None:
178 |             self.load_encoder()
179 |         if topic in self.embed_cache:
180 |             passage_vectors = self.embed_cache[topic]
181 |         else:
182 |             inputs = [psg["title"] + " " + psg["text"].replace("<s>", "").replace("</s>", "") for psg in passages]
183 |             passage_vectors = self.encoder.encode(inputs, batch_size=self.batch_size, device=self.encoder.device)
184 |             self.embed_cache[topic] = passage_vectors
185 |             self.add_n_embed += 1
186 |         query_vectors = self.encoder.encode([retrieval_query],
187 |                                             batch_size=self.batch_size,
188 |                                             device=self.encoder.device)[0]
189 |         scores = np.inner(query_vectors, passage_vectors)
190 |         indices = np.argsort(-scores)[:k]
191 |         return [passages[i] for i in indices]
192 | 
193 |     def get_passages(self, topic, question, k):
194 |         retrieval_query = topic + " " + question.strip()
195 |         cache_key = topic + "#" + retrieval_query
196 | 
197 |         if cache_key not in self.cache:
198 |             passages = self.db.get_text_from_title(topic)
199 |             if self.retrieval_type=="bm25":
200 |                 self.cache[cache_key] = self.get_bm25_passages(topic, retrieval_query, passages, k)
201 |             else:
202 |                 self.cache[cache_key] = self.get_gtr_passages(topic, retrieval_query, passages, k)
203 |             assert
len(self.cache[cache_key]) in [k, len(passages)] 204 | self.add_n += 1 205 | 206 | 207 | return self.cache[cache_key] 208 | 209 | 210 | 211 | 212 | 213 | -------------------------------------------------------------------------------- /massw/metrics.py: -------------------------------------------------------------------------------- 1 | """Compute cosine similarity, ROUGE, BLEU, METEOR, and BERTScore metrics. 2 | 3 | See example usage in the `__main__` block at the end of the file. 4 | """ 5 | 6 | import json 7 | from typing import List, Union 8 | 9 | import evaluate 10 | import numpy as np 11 | from sentence_transformers import SentenceTransformer 12 | 13 | 14 | class CosineSimilarity: 15 | """Compute cosine similarity between two ordered list of texts.""" 16 | 17 | def __init__(self): 18 | """Initialize the SentenceTransformer model.""" 19 | self.encoder = SentenceTransformer( 20 | 'intfloat/multilingual-e5-large-instruct') 21 | 22 | def get_detailed_instruct(self, query: str) -> str: 23 | """Generate a detailed instruct for the query.""" 24 | return f"Instruct: Retrieve semantically similar text.\nQuery: {query}" 25 | 26 | def get_embeddings(self, texts: List[str], is_query: bool): 27 | """Compute embeddings for the given texts.""" 28 | if is_query: 29 | texts = [self.get_detailed_instruct(query) for query in texts] 30 | embeddings = self.encoder.encode( 31 | texts, 32 | convert_to_numpy=True, 33 | normalize_embeddings=True, 34 | ) 35 | return embeddings 36 | 37 | def compute( 38 | self, 39 | predictions: List[str], 40 | references: Union[List[str], List[List[str]]], 41 | ): 42 | """Compute the cosine similarity between predictions and references.""" 43 | length = len(predictions) 44 | if isinstance(references[0], list): 45 | new_predictions = [] 46 | for pred, refs in zip(predictions, references): 47 | new_predictions.extend([pred] * len(refs)) 48 | new_references = [ref for refs in references for ref in refs] 49 | predictions, references = new_predictions, new_references 50 | predictions_embeddings = self.get_embeddings(predictions, 51 | is_query=True) 52 | references_embeddings = self.get_embeddings(references, is_query=False) 53 | # Compute pairwise cosine similarity 54 | cosine_similarities = [] 55 | for pred, ref in zip(predictions_embeddings, references_embeddings): 56 | cosine_similarities.append(np.dot(pred, ref)) # Already normalized 57 | cosine_similarities = np.array(cosine_similarities) 58 | cosine_similarities = cosine_similarities.reshape(length, -1) 59 | cosine_similarities = np.max(cosine_similarities, axis=1) 60 | return {"cosine": float(np.mean(cosine_similarities))} 61 | 62 | 63 | class NAHit: 64 | """Compute the precision, recall, and F1 score for N/A hit metric.""" 65 | 66 | def is_na(self, s: str): 67 | """Check if the string is N/A.""" 68 | if s.lower() in ["n/a", "na", "not applicable"]: 69 | return True 70 | if len(s.split()) < 3: 71 | return True 72 | return False 73 | 74 | def compute( 75 | self, 76 | predictions: List[str], 77 | references: Union[List[str], List[List[str]]], 78 | ): 79 | """Compute the precision, recall, and F1 score for N/A hit metric.""" 80 | predictions_na = [self.is_na(pred) for pred in predictions] 81 | if isinstance(references[0], list): 82 | references_na = [] 83 | references_na = [ 84 | all(self.is_na(ref) for ref in refs) for refs in references 85 | ] 86 | else: 87 | references_na = [self.is_na(ref) for ref in references] 88 | precision = evaluate.load("precision") 89 | recall = evaluate.load("recall") 90 | f1 = evaluate.load("f1") 
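        # The *_na lists above are boolean flags (a string counts as N/A when
        # it matches an explicit N/A variant or has fewer than three words, per
        # is_na). The HF precision/recall/f1 metrics below treat these flags as
        # binary labels, with N/A (True) as the positive class; pred_ratio and
        # ref_ratio report the fraction of N/A answers on each side.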
91 | return { 92 | "precision": 93 | precision.compute( 94 | predictions=predictions_na, 95 | references=references_na, 96 | )["precision"], 97 | "recall": 98 | recall.compute( 99 | predictions=predictions_na, 100 | references=references_na, 101 | )["recall"], 102 | "f1": 103 | f1.compute( 104 | predictions=predictions_na, 105 | references=references_na, 106 | )["f1"], 107 | "pred_ratio": 108 | sum(predictions_na) / len(predictions_na), 109 | "ref_ratio": 110 | sum(references_na) / len(references_na), 111 | } 112 | 113 | 114 | cs = CosineSimilarity() 115 | bertscore = evaluate.load("bertscore") 116 | bleurt = evaluate.load("bleurt", 117 | module_type="metric", 118 | checkpoint="BLEURT-20-D12", 119 | config_name="BLEURT-20-D12") 120 | rouge = evaluate.load("rouge") 121 | bleu = evaluate.load("bleu") 122 | meteor = evaluate.load("meteor") 123 | nahit = NAHit() 124 | 125 | 126 | def compute_bleurt_score(predictions, references): 127 | """Compute BLEURT score for the given predictions and references.""" 128 | if isinstance(references[0], list): 129 | grouped_references = list(zip(*references)) 130 | scores = [] 131 | for grouped_reference in grouped_references: 132 | score = bleurt.compute( 133 | predictions=predictions, 134 | references=grouped_reference) 135 | scores.append(score["scores"]) 136 | scores = np.array(scores) # (num_refs, num_preds) 137 | score = np.mean(np.max(scores, axis=0)) 138 | else: 139 | score = bleurt.compute( 140 | predictions=predictions, 141 | references=references) 142 | score = np.mean(score["scores"]) 143 | return score 144 | 145 | 146 | def compute_metrics(predictions: List[str], 147 | references: List[List[str]], 148 | metric_names=None): 149 | """Compute cosine similarity, ROUGE, BLEU, METEOR, and BERTScore.""" 150 | if metric_names is None: 151 | metric_names = [ 152 | "cosine", 153 | "rouge", 154 | "bleu", 155 | "meteor", 156 | "bleurt", 157 | "bertscore", 158 | "nahit" 159 | ] 160 | metrics = {} 161 | if "nahit" in metric_names: 162 | metrics["nahit"] = nahit.compute( 163 | predictions=predictions, 164 | references=references, 165 | ) 166 | # Remove N/A predictions and references 167 | if isinstance(references[0], list): 168 | references_na = [ 169 | all(nahit.is_na(ref) for ref in refs) for refs in references 170 | ] 171 | else: 172 | references_na = [nahit.is_na(ref) for ref in references] 173 | predictions_na = [nahit.is_na(pred) for pred in predictions] 174 | both_not_na = [ 175 | not pred_na and not ref_na 176 | for pred_na, ref_na in zip(predictions_na, references_na) 177 | ] 178 | predictions = [ 179 | pred for pred, not_na in zip(predictions, both_not_na) if not_na 180 | ] 181 | references = [ 182 | ref for ref, not_na in zip(references, both_not_na) if not_na 183 | ] 184 | 185 | metric_computation_functions = { 186 | "cosine": cs, 187 | "rouge": rouge, 188 | "bleu": bleu, 189 | "meteor": meteor, 190 | "bertscore": bertscore, 191 | "bleurt": bleurt 192 | } 193 | 194 | for metric_name in metric_names: 195 | if metric_name in metric_computation_functions: 196 | if metric_name == "bertscore": 197 | score = metric_computation_functions[metric_name].compute( 198 | predictions=predictions, 199 | references=references, 200 | lang="en" 201 | ) 202 | metrics[metric_name] = { 203 | "precision": np.array(score["precision"]).mean(), 204 | "recall": np.array(score["recall"]).mean(), 205 | "f1": np.array(score["f1"]).mean() 206 | } 207 | elif metric_name == "bleurt": 208 | score = compute_bleurt_score(predictions, references) 209 | metrics[metric_name] = 
{"bleurt": score} 210 | else: 211 | metrics[metric_name] = \ 212 | metric_computation_functions[metric_name].compute( 213 | predictions=predictions, 214 | references=references 215 | ) 216 | 217 | return metrics 218 | 219 | 220 | def flatten_metrics(metric_dict: dict): 221 | """Flatten the metric dictionary for easy display.""" 222 | flat_metrics = {} 223 | if "meteor" in metric_dict: 224 | flat_metrics["METEOR"] = metric_dict["meteor"]["meteor"] 225 | if "cosine" in metric_dict: 226 | flat_metrics["Cosine Embedding"] = metric_dict["cosine"]["cosine"] 227 | if "bleu" in metric_dict: 228 | flat_metrics["BLEU"] = metric_dict["bleu"]["bleu"] 229 | flat_metrics["Precision-1"] = metric_dict["bleu"]["precisions"][0] 230 | flat_metrics["Precision-2"] = metric_dict["bleu"]["precisions"][1] 231 | flat_metrics["Length Ratio"] = metric_dict["bleu"]["length_ratio"] 232 | if "rouge" in metric_dict: 233 | flat_metrics["ROUGE-1"] = metric_dict["rouge"]["rouge1"] 234 | flat_metrics["ROUGE-2"] = metric_dict["rouge"]["rouge2"] 235 | if "nahit" in metric_dict: 236 | flat_metrics["N/A-precision"] = metric_dict["nahit"]["precision"] 237 | flat_metrics["N/A-recall"] = metric_dict["nahit"]["recall"] 238 | flat_metrics["N/A-f1"] = metric_dict["nahit"]["f1"] 239 | flat_metrics["N/A in pred"] = metric_dict["nahit"]["pred_ratio"] 240 | flat_metrics["N/A in ref"] = metric_dict["nahit"]["ref_ratio"] 241 | if "bertscore" in metric_dict: 242 | flat_metrics["BERTScore-precision"] = metric_dict["bertscore"][ 243 | "precision"] 244 | flat_metrics["BERTScore-recall"] = metric_dict["bertscore"]["recall"] 245 | flat_metrics["BERTScore-f1"] = metric_dict["bertscore"]["f1"] 246 | if "bleurt" in metric_dict: 247 | flat_metrics["BLEURT"] = metric_dict["bleurt"]["bleurt"] 248 | return flat_metrics 249 | 250 | 251 | if __name__ == "__main__": 252 | predictions_demo = ["The cat sat on the mat.", "The dog ate my homework."] 253 | references_demo = [["The cat sat on the mat.", "The cat sat on the desk."], 254 | ["The dog ate my homework.", "The dog ate my lunch."]] 255 | 256 | # Compute metrics 257 | metrics_demo = compute_metrics(predictions=predictions_demo, 258 | references=references_demo) 259 | 260 | # Print results 261 | print(json.dumps(metrics_demo, indent=2)) 262 | -------------------------------------------------------------------------------- /data/few_shot_examples.json: -------------------------------------------------------------------------------- 1 | { 2 | "idea_generation": [ 3 | {"user": "Given the context: 'Formulating efficient SQL queries is a challenging and iterative process requiring tuning and execution cycles. Current methods for providing insights about SQL query properties prior to execution depend on database instance statistics or query execution plans.', generate key ideas that could advance this area of study.", 4 | "assistant": "The authors propose a new approach for predicting SQL query properties, including the query answer size, run-time, and error class, relying on data-driven machine learning techniques and large query workloads, instead of database stats or execution plans."}, 5 | {"user": "Given the context: 'Two-sample tests are utilized to determine if two sets of samples come from the same distribution. Current tests typically use kernels which are spatially homogeneous and adapt only in lengthscale. 
Such kernels are not well-suited for dealing with high dimensions and complex data.', generate key ideas that could advance this area of study.", 6 | "assistant": "The authors suggest a new kernel-based two-sample test that is parameterized by deep neural networks. This test is designed specifically to adapt to variations in distribution shape and smoothness across different spaces, making it better suited for high dimensional and complex data."} 7 | ], 8 | "method_recommendation": [ 9 | {"user": "Given the context: 'Formulating efficient SQL queries is a challenging and iterative process requiring tuning and execution cycles. Current methods for providing insights about SQL query properties prior to execution depend on database instance statistics or query execution plans.' and the key idea: 'The authors propose a new approach for predicting SQL query properties, including the query answer size, run-time, and error class, relying on data-driven machine learning techniques and large query workloads, instead of database stats or execution plans.', recommend the most suitable method to validate this idea.", 10 | "assistant": "The authors employ data-driven machine learning techniques, alongside extensive query workloads to model and predict properties of SQL queries, in contrast to traditional techniques dependent on database-related statistics or execution plans."}, 11 | {"user": "Given the context: 'Two-sample tests are utilized to determine if two sets of samples come from the same distribution. Current tests typically use kernels which are spatially homogeneous and adapt only in lengthscale. Such kernels are not well-suited for dealing with high dimensions and complex data.' and the key idea: 'The authors suggest a new kernel-based two-sample test that is parameterized by deep neural networks. This test is designed specifically to adapt to variations in distribution shape and smoothness across different spaces, making it better suited for high dimensional and complex data.', recommend the most suitable method to validate this idea.", 12 | "assistant": "Deep neural networks are trained to maximize the power of these two-sample tests, allowing them to be particularly effective with complex data. The authors perform experiments on benchmark and real-world data to validate their method."} 13 | ], 14 | "outcome_prediction": [ 15 | {"user": "Based on the context: 'Formulating efficient SQL queries is a challenging and iterative process requiring tuning and execution cycles. 
Current methods for providing insights about SQL query properties prior to execution depend on database instance statistics or query execution plans.', the key idea: 'The authors propose a new approach for predicting SQL query properties, including the query answer size, run-time, and error class, relying on data-driven machine learning techniques and large query workloads, instead of database stats or execution plans.', and the recommended method: 'The authors employ data-driven machine learning techniques, alongside extensive query workloads to model and predict properties of SQL queries, in contrast to traditional techniques dependent on database-related statistics or execution plans.', predict the potential outcome of this research.", 16 | "assistant": "Empirical results demonstrate that the predictive accuracy of several query properties is higher when utilizing the new approach based on machine learning techniques, compared to existing methods."}, 17 | {"user": "Based on the context: 'Two-sample tests are utilized to determine if two sets of samples come from the same distribution. Current tests typically use kernels which are spatially homogeneous and adapt only in lengthscale. Such kernels are not well-suited for dealing with high dimensions and complex data.', the key idea: 'The authors suggest a new kernel-based two-sample test that is parameterized by deep neural networks. This test is designed specifically to adapt to variations in distribution shape and smoothness across different spaces, making it better suited for high dimensional and complex data.', and the recommended method: 'Deep neural networks are trained to maximize the power of these two-sample tests, allowing them to be particularly effective with complex data. The authors perform experiments on benchmark and real-world data to validate their method.', predict the potential outcome of this research.", 18 | "assistant": "The authors demonstrate the superior performance of their deep kernels in hypothesis testing on benchmark and real-world data, indicating they improve upon previous methods."} 19 | ], 20 | "future_work_recommendation": [ 21 | {"user": "Based on the context: 'Formulating efficient SQL queries is a challenging and iterative process requiring tuning and execution cycles. 
Current methods for providing insights about SQL query properties prior to execution depend on database instance statistics or query execution plans.', the key idea: 'The authors propose a new approach for predicting SQL query properties, including the query answer size, run-time, and error class, relying on data-driven machine learning techniques and large query workloads, instead of database stats or execution plans.', the method: 'The authors employ data-driven machine learning techniques, alongside extensive query workloads to model and predict properties of SQL queries, in contrast to traditional techniques dependent on database-related statistics or execution plans.', and the outcome: 'Empirical results demonstrate that the predictive accuracy of several query properties is higher when utilizing the new approach based on machine learning techniques, compared to existing methods.', suggest projected impact for this research.", 22 | "assistant": "The introduction of a data-driven machine learning approach to predict SQL query properties can significantly accelerate and improve query formulation and analysis by providing insights prior to their execution, potentially transforming the way SQL queries are formed."}, 23 | {"user": "Based on the context: 'Two-sample tests are utilized to determine if two sets of samples come from the same distribution. Current tests typically use kernels which are spatially homogeneous and adapt only in lengthscale. Such kernels are not well-suited for dealing with high dimensions and complex data.', the key idea: 'The authors suggest a new kernel-based two-sample test that is parameterized by deep neural networks. This test is designed specifically to adapt to variations in distribution shape and smoothness across different spaces, making it better suited for high dimensional and complex data.', the method: 'Deep neural networks are trained to maximize the power of these two-sample tests, allowing them to be particularly effective with complex data. The authors perform experiments on benchmark and real-world data to validate their method.', and the outcome: 'The authors demonstrate the superior performance of their deep kernels in hypothesis testing on benchmark and real-world data, indicating they improve upon previous methods.', suggest projected impact for this research.", 24 | "assistant": "The ideas and methods proposed in this paper have the potential to significantly improve the efficacy of two-sample tests in the future, particularly in instances where the data is high dimensional and complex. As a special case, this scheme includes popular classifier-based two-sample tests but improves upon them."} 25 | ], 26 | "title_prediction": [ 27 | {"user": "Given the context: 'Formulating efficient SQL queries is a challenging and iterative process requiring tuning and execution cycles. 
Current methods for providing insights about SQL query properties prior to execution depend on database instance statistics or query execution plans.', the key idea: 'The authors propose a new approach for predicting SQL query properties, including the query answer size, run-time, and error class, relying on data-driven machine learning techniques and large query workloads, instead of database stats or execution plans.', the method: 'The authors employ data-driven machine learning techniques, alongside extensive query workloads to model and predict properties of SQL queries, in contrast to traditional techniques dependent on database-related statistics or execution plans.', the outcome: 'Empirical results demonstrate that the predictive accuracy of several query properties is higher when utilizing the new approach based on machine learning techniques, compared to existing methods.', and the future impact: 'The introduction of a data-driven machine learning approach to predict SQL query properties can significantly accelerate and improve query formulation and analysis by providing insights prior to their execution, potentially transforming the way SQL queries are formed.', predict the title of this research paper. The title should be concise and reflective of the core aspects.", 28 | "assistant": "Facilitating SQL Query Composition and Analysis"}, 29 | {"user": "Given the context: 'Two-sample tests are utilized to determine if two sets of samples come from the same distribution. Current tests typically use kernels which are spatially homogeneous and adapt only in lengthscale. Such kernels are not well-suited for dealing with high dimensions and complex data.', the key idea: 'The authors suggest a new kernel-based two-sample test that is parameterized by deep neural networks. This test is designed specifically to adapt to variations in distribution shape and smoothness across different spaces, making it better suited for high dimensional and complex data.', the method: 'Deep neural networks are trained to maximize the power of these two-sample tests, allowing them to be particularly effective with complex data. The authors perform experiments on benchmark and real-world data to validate their method.', the outcome: 'The authors demonstrate the superior performance of their deep kernels in hypothesis testing on benchmark and real-world data, indicating they improve upon previous methods.', and the future impact: 'The ideas and methods proposed in this paper have the potential to significantly improve the efficacy of two-sample tests in the future, particularly in instances where the data is high dimensional and complex. As a special case, this scheme includes popular classifier-based two-sample tests but improves upon them.', predict the title of this research paper. The title should be concise and reflective of the core aspects.", 30 | "assistant": "Learning Deep Kernels for Non-Parametric Two-Sample Tests"} 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /data/cot_examples.json: -------------------------------------------------------------------------------- 1 | { 2 | "idea_generation": [ 3 | { 4 | "user": "Given the context: 'Formulating efficient SQL queries is a challenging and iterative process requiring tuning and execution cycles. 
Current methods for providing insights about SQL query properties prior to execution depend on database instance statistics or query execution plans.', generate key ideas that could advance this area of study. Let's think step by step. You should first present your reasoning. After that, the final prediction should start after the marker 'Prediction:'.", 5 | "assistant": "Considering the limitations of traditional methods that depend heavily on static database statistics and execution plans, a shift towards dynamic, predictive models can offer real-time insights and adjustments. Prediction: The authors propose a new approach for predicting SQL query properties, including the query answer size, run-time, and error class, relying on data-driven machine learning techniques and large query workloads, instead of database stats or execution plans." 6 | }, 7 | { 8 | "user": "Given the context: 'Two-sample tests are utilized to determine if two sets of samples come from the same distribution. Current tests typically use kernels which are spatially homogeneous and adapt only in lengthscale. Such kernels are not well-suited for dealing with high dimensions and complex data.', generate key ideas that could advance this area of study. Let's think step by step. You should first present your reasoning. After that, the final prediction should start after the marker 'Prediction:'.", 9 | "assistant": "Traditional kernels struggle with high-dimensional spaces due to their homogeneous nature. Incorporating adaptability and depth could enhance sensitivity to complex data distributions. Prediction: The authors suggest a new kernel-based two-sample test that is parameterized by deep neural networks. This test is designed specifically to adapt to variations in distribution shape and smoothness across different spaces, making it better suited for high dimensional and complex data." 10 | } 11 | ], 12 | "method_recommendation": [ 13 | { 14 | "user": "Given the context: 'Formulating efficient SQL queries is a challenging and iterative process requiring tuning and execution cycles. Current methods for providing insights about SQL query properties prior to execution depend on database instance statistics or query execution plans.' and the key idea: 'The authors propose a new approach for predicting SQL query properties, including the query answer size, run-time, and error class, relying on data-driven machine learning techniques and large query workloads, instead of database stats or execution plans.', recommend the most suitable method to validate this idea. Let's think step by step. You should first present your reasoning. After that, the final prediction should start after the marker 'Prediction:'.", 15 | "assistant": "To validate the proposed approach, leveraging a large dataset of diverse SQL queries would allow for robust training and testing of the machine learning model, ensuring its efficacy across various scenarios. Prediction: The authors employ data-driven machine learning techniques, alongside extensive query workloads to model and predict properties of SQL queries, in contrast to traditional techniques dependent on database-related statistics or execution plans." 16 | }, 17 | { 18 | "user": "Given the context: 'Two-sample tests are utilized to determine if two sets of samples come from the same distribution. Current tests typically use kernels which are spatially homogeneous and adapt only in lengthscale. Such kernels are not well-suited for dealing with high dimensions and complex data.' 
and the key idea: 'The authors suggest a new kernel-based two-sample test that is parameterized by deep neural networks. This test is designed specifically to adapt to variations in distribution shape and smoothness across different spaces, making it better suited for high dimensional and complex data.', recommend the most suitable method to validate this idea. Let's think step by step. You should first present your reasoning. After that, the final prediction should start after the marker 'Prediction:'.", 19 | "assistant": "Experimental validation using diverse datasets, including those with high-dimensional attributes, would effectively demonstrate the model's adaptability and superiority over traditional methods. Prediction: Deep neural networks are trained to maximize the power of these two-sample tests, allowing them to be particularly effective with complex data. The authors perform experiments on benchmark and real-world data to validate their method." 20 | } 21 | ], 22 | "outcome_prediction": [ 23 | { 24 | "user": "Based on the context: 'Formulating efficient SQL queries is a challenging and iterative process requiring tuning and execution cycles. Current methods for providing insights about SQL query properties prior to execution depend on database instance statistics or query execution plans.', the key idea: 'The authors propose a new approach for predicting SQL query properties, including the query answer size, run-time, and error class, relying on data-driven machine learning techniques and large query workloads, instead of database stats or execution plans.', and the recommended method: 'The authors employ data-driven machine learning techniques, alongside extensive query workloads to model and predict properties of SQL queries, in contrast to traditional techniques dependent on database-related statistics or execution plans.', predict the potential outcome of this research. Let's think step by step. You should first present your reasoning. After that, the final prediction should start after the marker 'Prediction:'.", 25 | "assistant": "Given the novel approach's reliance on real-time data and machine learning, it is expected to surpass traditional methods in accuracy and adaptability, providing actionable insights more reliably. Prediction: Empirical results demonstrate that the predictive accuracy of several query properties is higher when utilizing the new approach based on machine learning techniques, compared to existing methods." 26 | }, 27 | { 28 | "user": "Based on the context: 'Two-sample tests are utilized to determine if two sets of samples come from the same distribution. Current tests typically use kernels which are spatially homogeneous and adapt only in lengthscale. Such kernels are not well-suited for dealing with high dimensions and complex data.', the key idea: 'The authors suggest a new kernel-based two-sample test that is parameterized by deep neural networks. This test is designed specifically to adapt to variations in distribution shape and smoothness across different spaces, making it better suited for high dimensional and complex data.', and the recommended method: 'Deep neural networks are trained to maximize the power of these two-sample tests, allowing them to be particularly effective with complex data. The authors perform experiments on benchmark and real-world data to validate their method.', predict the potential outcome of this research. Let's think step by step. You should first present your reasoning. 
After that, the final prediction should start after the marker 'Prediction:'.", 29 | "assistant": "The integration of deep learning for kernel adaptation promises significant improvements in test performance, especially in complex scenarios where traditional kernels fall short. Prediction: The authors demonstrate the superior performance of their deep kernels in hypothesis testing on benchmark and real-world data, indicating they improve upon previous methods." 30 | } 31 | ], 32 | "future_work_recommendation": [ 33 | { 34 | "user": "Based on the context: 'Formulating efficient SQL queries is a challenging and iterative process requiring tuning and execution cycles. Current methods for providing insights about SQL query properties prior to execution depend on database instance statistics or query execution plans.', the key idea: 'The authors propose a new approach for predicting SQL query properties, including the query answer size, run-time, and error class, relying on data-driven machine learning techniques and large query workloads, instead of database stats or execution plans.', the method: 'The authors employ data-driven machine learning techniques, alongside extensive query workloads to model and predict properties of SQL queries, in contrast to traditional techniques dependent on database-related statistics or execution plans.', and the outcome: 'Empirical results demonstrate that the predictive accuracy of several query properties is higher when utilizing the new approach based on machine learning techniques, compared to existing methods.', suggest projected impact for this research. Let's think step by step. You should first present your reasoning. After that, the final prediction should start after the marker 'Prediction:'.", 35 | "assistant": "Given the success of this predictive model, exploring its applicability to other database systems or incorporating it into database management tools could further broaden its impact, making database operations more efficient and intuitive. Prediction: The introduction of a data-driven machine learning approach to predict SQL query properties can significantly accelerate and improve query formulation and analysis by providing insights prior to their execution, potentially transforming the way SQL queries are formed." 36 | }, 37 | { 38 | "user": "Based on the context: 'Two-sample tests are utilized to determine if two sets of samples come from the same distribution. Current tests typically use kernels which are spatially homogeneous and adapt only in lengthscale. Such kernels are not well-suited for dealing with high dimensions and complex data.', the key idea: 'The authors suggest a new kernel-based two-sample test that is parameterized by deep neural networks. This test is designed specifically to adapt to variations in distribution shape and smoothness across different spaces, making it better suited for high dimensional and complex data.', the method: 'Deep neural networks are trained to maximize the power of these two-sample tests, allowing them to be particularly effective with complex data. The authors perform experiments on benchmark and real-world data to validate their method.', and the outcome: 'The authors demonstrate the superior performance of their deep kernels in hypothesis testing on benchmark and real-world data, indicating they improve upon previous methods.', suggest projected impact for this research. Let's think step by step. You should first present your reasoning. 
After that, the final prediction should start after the marker 'Prediction:'.", 39 | "assistant": "Further advancements could include developing even more adaptable neural network architectures or extending these methods to other statistical tests to harness similar benefits in different contexts. Prediction: The ideas and methods proposed in this paper have the potential to significantly improve the efficacy of two-sample tests in the future, particularly in instances where the data is high dimensional and complex. As a special case, this scheme includes popular classifier-based two-sample tests but improves upon them." 40 | } 41 | ], 42 | "title_prediction": [ 43 | { 44 | "user": "Given the context: 'Formulating efficient SQL queries is a challenging and iterative process requiring tuning and execution cycles. Current methods for providing insights about SQL query properties prior to execution depend on database instance statistics.', the key idea: 'The authors propose a new approach for predicting SQL query properties, including the query answer size, run-time, and error class.', the method: 'The authors employ data-driven machine learning techniques, alongside extensive query workloads to model and predict properties of SQL queries.', the outcome: 'Empirical results demonstrate that the predictive accuracy of several query properties is higher when utilizing the new approach based on machine learning techniques.', and the future impact: 'The introduction of a data-driven machine learning approach to predict SQL query properties can significantly accelerate and improve query formulation and analysis by providing insights prior to their execution.', predict the title of this research paper. Let's think step by step. You should first present your reasoning. After that, the final prediction should start after the marker 'Prediction:'.", 45 | "assistant": "The paper's focus is on enhancing the predictability and efficiency of SQL query processing using advanced machine learning models. Prediction: Facilitating SQL Query Composition and Analysis Through Data-Driven Predictive Modeling." 46 | }, 47 | { 48 | "user": "Given the context: 'Two-sample tests are utilized to determine if two sets of samples come from the same distribution. Current tests typically use kernels which are spatially homogeneous and adapt only in lengthscale. Such kernels are not well-suited for dealing with high dimensions and complex data.', the key idea: 'The authors suggest a new kernel-based two-sample test that is parameterized by deep neural networks. This test is designed specifically to adapt to variations in distribution shape and smoothness across different spaces.', the method: 'Deep neural networks are trained to maximize the power of these two-sample tests. The authors perform experiments on benchmark and real-world data to validate their method.', the outcome: 'The authors demonstrate the superior performance of their deep kernels in hypothesis testing on benchmark and real-world data.', and the future impact: 'The ideas and methods proposed in this paper have the potential to significantly improve the efficacy of two-sample tests in the future.', predict the title of this research paper. Let's think step by step. You should first present your reasoning. 
After that, the final prediction should start after the marker 'Prediction:'.",
49 |             "assistant": "This paper introduces a novel, adaptable kernel model enhanced by deep learning techniques for two-sample testing, which could set a new standard for statistical tests in complex data environments. Prediction: Learning Deep Kernels for Non-Parametric Two-Sample Tests."
50 |         }
51 |     ]
52 | }
53 | 
--------------------------------------------------------------------------------
/factscore/atomic_facts.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import re
4 | import functools
5 | import string
6 | import spacy
7 | import sys
8 | import nltk
9 | import openai
10 | from rank_bm25 import BM25Okapi
11 | import os
12 | import time
13 | from nltk.tokenize import sent_tokenize
14 | 
15 | from factscore.openai_lm import OpenAIModel
16 | 
17 | nltk.download("punkt")
18 | 
19 | 
20 | class AtomicFactGenerator(object):
21 |     def __init__(self, key_path, demon_dir, gpt3_cache_file=None):
22 |         self.nlp = spacy.load("en_core_web_sm")
23 |         self.is_bio = True
24 |         self.demon_path = os.path.join(demon_dir, "demons.json" if self.is_bio else "demons_complex.json")
25 | 
26 |         self.openai_lm = OpenAIModel("ChatGPT", cache_file=gpt3_cache_file, key_path=key_path)
27 | 
28 |         # get the demos
29 |         with open(self.demon_path, 'r') as f:
30 |             self.demons = json.load(f)
31 | 
32 |         tokenized_corpus = [doc.split(" ") for doc in self.demons.keys()]
33 |         self.bm25 = BM25Okapi(tokenized_corpus)
34 | 
35 |     def save_cache(self):
36 |         self.openai_lm.save_cache()
37 | 
38 |     def run(self, generation, cost_estimate=None):
39 |         """Convert the generation into a set of atomic facts. Return a total words cost if cost_estimate != None."""
40 |         assert isinstance(generation, str), "generation must be a string"
41 |         paragraphs = [para.strip() for para in generation.split("\n") if len(para.strip()) > 0]
42 |         return self.get_atomic_facts_from_paragraph(paragraphs, cost_estimate=cost_estimate)
43 | 
44 |     def get_atomic_facts_from_paragraph(self, paragraphs, cost_estimate=None):
45 |         sentences = []
46 |         para_breaks = []
47 |         for para_idx, paragraph in enumerate(paragraphs):
48 |             if para_idx > 0:
49 |                 para_breaks.append(len(sentences))
50 | 
51 |             initials = detect_initials(paragraph)
52 | 
53 |             curr_sentences = sent_tokenize(paragraph)
54 |             curr_sentences_2 = sent_tokenize(paragraph)
55 | 
56 |             curr_sentences = fix_sentence_splitter(curr_sentences, initials)
57 |             curr_sentences_2 = fix_sentence_splitter(curr_sentences_2, initials)
58 | 
59 |             # sanity check: the sentence splitter fixing algorithm should be deterministic
60 |             assert curr_sentences == curr_sentences_2, (paragraph, curr_sentences, curr_sentences_2)
61 | 
62 |             sentences += curr_sentences
63 | 
64 |         atoms_or_estimate = self.get_init_atomic_facts_from_sentence([sent for i, sent in enumerate(sentences) if not (not self.is_bio and (
65 |             (i==0 and (sent.startswith("Sure") or sent.startswith("Here are"))) or
66 |             (i==len(sentences)-1 and (sent.startswith("Please") or sent.startswith("I hope") or sent.startswith("Here are")))))], cost_estimate=cost_estimate)
67 | 
68 |         if cost_estimate:
69 |             return atoms_or_estimate
70 |         else:
71 |             atoms = atoms_or_estimate
72 | 
73 |         atomic_facts_pairs = []
74 |         for i, sent in enumerate(sentences):
75 |             if not self.is_bio and (
76 |                     (i==0 and (sent.startswith("Sure") or sent.startswith("Here are"))) or
77 |                     (i==len(sentences)-1 and (sent.startswith("Please") or sent.startswith("I hope") or sent.startswith("Here are")))):
78 |                 atomic_facts_pairs.append((sent, []))
79 |             elif self.is_bio and sent.startswith("This sentence does not contain any facts"):
80 |                 atomic_facts_pairs.append((sent, []))
81 |             elif sent.startswith("Sure") or sent.startswith("Please") or (i==0 and sent.startswith("Here are")):
82 |                 atomic_facts_pairs.append((sent, []))
83 |             else:
84 |                 atomic_facts_pairs.append((sent, atoms[sent]))
85 | 
86 |         # postprocess_atomic_facts will fix minor issues from InstructGPT
87 |         # it is supposed to handle the sentence splitter issue too, but since
88 |         # we have fixed the sentence splitter issue already,
89 |         # the new para_breaks should be identical to the original para_breaks
90 |         if self.is_bio:
91 |             atomic_facts_pairs, para_breaks = postprocess_atomic_facts(atomic_facts_pairs, list(para_breaks), self.nlp)
92 | 
93 |         return atomic_facts_pairs, para_breaks
94 | 
95 | 
96 |     def get_init_atomic_facts_from_sentence(self, sentences, cost_estimate=None):
97 |         """Get the initial atomic facts from the sentences. Return a total words cost if cost_estimate != None."""
98 | 
99 |         is_bio = self.is_bio
100 |         demons = self.demons
101 | 
102 |         k = 1 if is_bio else 0
103 |         n = 7 if is_bio else 8
104 | 
105 |         prompts = []
106 |         prompt_to_sent = {}
107 |         atoms = {}
108 |         for sentence in sentences:
109 |             if sentence in atoms:
110 |                 continue
111 |             top_matches = best_demos(sentence, self.bm25, list(demons.keys()), k)  # k most similar demonstrations by BM25
112 |             prompt = ""
113 | 
114 |             for i in range(n):
115 |                 prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(list(demons.keys())[i])
116 |                 for fact in demons[list(demons.keys())[i]]:
117 |                     prompt = prompt + "- {}\n".format(fact)
118 |                 prompt = prompt + "\n"
119 | 
120 |             for match in top_matches:
121 |                 prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(match)
122 |                 for fact in demons[match]:
123 |                     prompt = prompt + "- {}\n".format(fact)
124 |                 prompt = prompt + "\n"
125 |             prompt = prompt + "Please breakdown the following sentence into independent facts: {}\n".format(sentence)
126 |             prompts.append(prompt)
127 |             prompt_to_sent[prompt] = sentence
128 | 
129 |         if cost_estimate:
130 |             total_words_estimate = 0
131 |             for prompt in prompts:
132 |                 if cost_estimate == "consider_cache" and (prompt.strip() + "_0") in self.openai_lm.cache_dict:
133 |                     continue
134 |                 total_words_estimate += len(prompt.split())
135 |             return total_words_estimate
136 |         else:
137 |             for prompt in prompts:
138 |                 output, _ = self.openai_lm.generate(prompt)
139 |                 atoms[prompt_to_sent[prompt]] = text_to_sentences(output)
140 | 
141 |         for key, value in demons.items():
142 |             if key not in atoms:
143 |                 atoms[key] = value
144 | 
145 |         return atoms
146 | 
147 | 
148 | def best_demos(query, bm25, demons_sents, k):
149 |     tokenized_query = query.split(" ")
150 |     top_matches = bm25.get_top_n(tokenized_query, demons_sents, k)
151 |     return top_matches
152 | 
153 | 
154 | # transform InstructGPT output into sentences
155 | def text_to_sentences(text):
156 |     sentences = text.split("- ")[1:]
157 |     sentences = [sent.strip() for sent in sentences if sent.strip()]  # strip() also removes any trailing newline; drop empty chunks
158 |     if len(sentences) > 0:
159 |         if sentences[-1][-1] != '.':
160 |             sentences[-1] = sentences[-1] + '.'
161 |     else:
162 |         sentences = []
163 |     return sentences
164 | 
165 | 
166 | def normalize_answer(s):
167 |     """Lower text and remove punctuation, articles and extra whitespace."""
168 |     def remove_articles(text):
169 |         regex = re.compile(r'\b(a|an|the)\b', re.UNICODE)
170 |         return re.sub(regex, ' ', text)
171 |     def white_space_fix(text):
172 |         return ' '.join(text.split())
173 |     def remove_punc(text):
174 |         exclude = set(string.punctuation)
175 |         return ''.join(ch for ch in text if ch not in exclude)
176 |     def lower(text):
177 |         return text.lower()
178 |     return white_space_fix(remove_articles(remove_punc(lower(s))))
179 | 
180 | MONTHS = ["January", "February", "March", "April", "May", "June", "July", "August", "September", "October", "November", "December"]
181 | MONTHS = [m.lower() for m in MONTHS]
182 | 
183 | def is_num(text):
184 |     try:
185 |         int(text)
186 |         return True
187 |     except Exception:
188 |         return False
189 | 
190 | def is_date(text):
191 |     text = normalize_answer(text)
192 |     for token in text.split(" "):
193 |         if (not is_num(token)) and token not in MONTHS:
194 |             return False
195 |     return True
196 | 
197 | def extract_numeric_values(text):
198 |     pattern = r'\b\d+\b'  # regular expression pattern for integers
199 |     numeric_values = re.findall(pattern, text)  # find all numeric values in the text
200 |     return set([value for value in numeric_values])  # return the unique numeric strings as a set
201 | 
202 | 
203 | def detect_entities(text, nlp):
204 |     doc = nlp(text)
205 |     entities = set()
206 | 
207 |     def _add_to_entities(text):
208 |         if "-" in text:
209 |             for _text in text.split("-"):
210 |                 entities.add(_text.strip())
211 |         else:
212 |             entities.add(text)
213 | 
214 | 
215 |     for ent in doc.ents:
216 |         # spacy often has errors with other types of entities
217 |         if ent.label_ in ["DATE", "TIME", "PERCENT", "MONEY", "QUANTITY", "ORDINAL", "CARDINAL"]:
218 | 
219 |             if is_date(ent.text):
220 |                 _add_to_entities(ent.text)
221 |             else:
222 |                 for token in ent.text.split():
223 |                     if is_date(token):
224 |                         _add_to_entities(token)
225 | 
226 |     for new_ent in extract_numeric_values(text):
227 |         if not np.any([new_ent in ent for ent in entities]):
228 |             entities.add(new_ent)
229 | 
230 |     return entities
231 | 
232 | def postprocess_atomic_facts(_atomic_facts, para_breaks, nlp):
233 | 
234 |     verbs = ["born.", " appointed.", " characterized.", " described.", " known.", " member.", " advocate.", "served.", "elected."]
235 |     permitted_verbs = ["founding member."]
236 | 
237 |     atomic_facts = []
238 |     new_atomic_facts = []
239 |     new_para_breaks = []
240 | 
241 |     for i, (sent, facts) in enumerate(_atomic_facts):
242 |         sent = sent.strip()
243 |         if len(sent.split())==1 and i not in para_breaks and i > 0:
244 |             assert i not in para_breaks
245 |             atomic_facts[-1][0] += " " + sent
246 |             atomic_facts[-1][1] += facts
247 |         else:
248 |             if i in para_breaks:
249 |                 new_para_breaks.append(len(atomic_facts))
250 |             atomic_facts.append([sent, facts])
251 | 
252 |     for sent, facts in atomic_facts:
253 |         entities = detect_entities(sent, nlp)
254 |         covered_entities = set()
255 | 
256 |         new_facts = []
257 |         for i, fact in enumerate(facts):
258 |             if any([fact.endswith(verb) for verb in verbs]) and not any([fact.endswith(verb) for verb in permitted_verbs]):
259 |                 if any([fact[:-1] in other_fact for j, other_fact in enumerate(facts) if j != i]):
260 |                     continue
261 |             sent_entities = detect_entities(fact, nlp)
262 |             covered_entities |= set([e for e in sent_entities if e in entities])
263 |             new_entities = sent_entities - entities
264 |             if len(new_entities) > 0:
265 |                 do_pass = False
266 |                 for new_ent in new_entities:
267 |                     pre_ent = None
268 |                     for ent in entities:
269 |                         if ent.startswith(new_ent):
270 |                             pre_ent = ent
271 |                             break
272 |                     if pre_ent is None:
273 |                         do_pass = True
274 |                         break
275 |                     fact = fact.replace(new_ent, pre_ent)
276 |                     covered_entities.add(pre_ent)
277 |                 if do_pass:
278 |                     continue
279 |             if fact in new_facts:
280 |                 continue
281 |             new_facts.append(fact)
282 |         try:
283 |             assert entities==covered_entities
284 |         except Exception:
285 |             new_facts = facts  # there is a bug in the spacy entity linker, so just go with the previous facts
286 | 
287 |         new_atomic_facts.append((sent, new_facts))
288 | 
289 |     return new_atomic_facts, new_para_breaks
290 | 
291 | def is_integer(s):
292 |     try:
293 |         int(s)
294 |         return True
295 |     except Exception:
296 |         return False
297 | 
298 | def detect_initials(text):
299 |     pattern = r"[A-Z]\. ?[A-Z]\."
300 |     match = re.findall(pattern, text)
301 |     return match
302 | 
303 | def fix_sentence_splitter(curr_sentences, initials):
304 |     for initial in initials:
305 |         if not np.any([initial in sent for sent in curr_sentences]):
306 |             alpha1, alpha2 = [t.strip() for t in initial.split(".") if len(t.strip())>0]
307 |             for i, (sent1, sent2) in enumerate(zip(curr_sentences, curr_sentences[1:])):
308 |                 if sent1.endswith(alpha1 + ".") and sent2.startswith(alpha2 + "."):
309 |                     # merge sentence i and i+1
310 |                     curr_sentences = curr_sentences[:i] + [curr_sentences[i] + " " + curr_sentences[i+1]] + curr_sentences[i+2:]
311 |                     break
312 |     sentences = []
313 |     combine_with_previous = None
314 |     for sent_idx, sent in enumerate(curr_sentences):
315 |         if len(sent.split())<=1 and sent_idx==0:
316 |             assert not combine_with_previous
317 |             combine_with_previous = True
318 |             sentences.append(sent)
319 |         elif len(sent.split())<=1:
320 |             assert sent_idx > 0
321 |             sentences[-1] += " " + sent
322 |             combine_with_previous = False
323 |         elif sent[0].isalpha() and not sent[0].isupper() and sent_idx > 0:
324 |             assert sent_idx > 0, curr_sentences
325 |             sentences[-1] += " " + sent
326 |             combine_with_previous = False
327 |         elif combine_with_previous:
328 |             assert sent_idx > 0
329 |             sentences[-1] += " " + sent
330 |             combine_with_previous = False
331 |         else:
332 |             assert not combine_with_previous
333 |             sentences.append(sent)
334 |     return sentences
335 | 
336 | 
337 | def main():
338 |     generator = AtomicFactGenerator("api.key", "demos", gpt3_cache_file=None)
339 |     atomic_facts, para_breaks = generator.run("Thierry Henry (born 17 August 1977) is a French professional football coach, pundit, and former player. He is considered one of the greatest strikers of all time, and one the greatest players of the Premier League history. He has been named Arsenal F.C's greatest ever player.\n\nHenry made his professional debut with Monaco in 1994 before signing for defending Serie A champions Juventus. However, limited playing time, coupled with disagreements with the club's hierarchy, led to him signing for Premier League club Arsenal for £11 million in 1999.")
340 | 
341 |     print(atomic_facts)
342 |     print(para_breaks)
343 | 
344 | if __name__ == "__main__":
345 |     main()
--------------------------------------------------------------------------------