├── assets └── forecasting_system.png ├── notebooks ├── results │ ├── data │ │ ├── answers.pickle │ │ ├── questions_data.pickle │ │ ├── base_predictions.pickle │ │ ├── crowd_predictions.pickle │ │ ├── finetuned_predictions.pickle │ │ └── finetuned_other_predictions.pickle │ └── results.ipynb └── demo │ └── sample_questions.pickle ├── llm_forecasting ├── prompts │ ├── system.py │ ├── alignment.py │ ├── relevance.py │ ├── ensemble_reasoning.py │ ├── base_eval.py │ ├── prompts.py │ ├── data_wrangling.py │ ├── summarization.py │ └── search_query.py ├── utils │ ├── logging_utils.py │ ├── model_utils.py │ ├── utils.py │ ├── article_utils.py │ ├── validation_utils.py │ ├── api_utils.py │ ├── metrics_utils.py │ ├── db_utils.py │ ├── time_utils.py │ ├── string_utils.py │ └── data_utils.py ├── config │ ├── keys.py │ └── constants.py ├── alignment.py ├── data_scraping.py ├── evaluation.py ├── summarize.py ├── model_eval.py └── ensemble.py ├── pyproject.toml ├── scripts ├── fine_tune │ └── fine_tune.py ├── data_scraping │ ├── cset.py │ ├── gjopen.py │ ├── polymarket.py │ └── manifold.py └── training_data │ └── training_point_generation.py └── README.md /assets/forecasting_system.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/assets/forecasting_system.png -------------------------------------------------------------------------------- /notebooks/results/data/answers.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/answers.pickle -------------------------------------------------------------------------------- /notebooks/demo/sample_questions.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/demo/sample_questions.pickle -------------------------------------------------------------------------------- /notebooks/results/data/questions_data.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/questions_data.pickle -------------------------------------------------------------------------------- /notebooks/results/data/base_predictions.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/base_predictions.pickle -------------------------------------------------------------------------------- /notebooks/results/data/crowd_predictions.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/crowd_predictions.pickle -------------------------------------------------------------------------------- /notebooks/results/data/finetuned_predictions.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/finetuned_predictions.pickle -------------------------------------------------------------------------------- /notebooks/results/data/finetuned_other_predictions.pickle: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/finetuned_other_predictions.pickle -------------------------------------------------------------------------------- /llm_forecasting/prompts/system.py: -------------------------------------------------------------------------------- 1 | SYSTEM_SUPERFORECASTER_0 = """You are an expert superforecaster, familiar with the work of Tetlock and others. 2 | Your mission is to generate accurate predictions for forecasting questions. 3 | Aggregate the information provided by the user. Make sure to give detailed reasonings.""" 4 | -------------------------------------------------------------------------------- /llm_forecasting/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | 4 | def setup_file_logger(logger, file_name): 5 | """ 6 | Set up a custom logger that writes to a file. 7 | 8 | Args: 9 | file_name (str): Name of the file to write to. 10 | logger_name (str): Name of the logger. 11 | """ 12 | logger.setLevel(logging.INFO) 13 | file_handler = logging.FileHandler(file_name) 14 | formatter = logging.Formatter( 15 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s" 16 | ) 17 | file_handler.setFormatter(formatter) 18 | logger.addHandler(file_handler) 19 | -------------------------------------------------------------------------------- /llm_forecasting/config/keys.py: -------------------------------------------------------------------------------- 1 | AWS_ACCESS_KEY = "" 2 | AWS_SECRET_KEY = "" 3 | 4 | METACULUS_KEY = "" 5 | MANIFOLD_KEY = "" 6 | POLYMARKET_KEY = "" 7 | CRYPTO_PRIVATE_KEY = "" 8 | EMAIL = "" 9 | GJOPEN_CSET_PASSWORD = "" 10 | 11 | 12 | NEWSCASTCHER_KEY = "" 13 | 14 | OPENAI_KEY = "" 15 | ANTHROPIC_KEY = "" 16 | TOGETHER_KEY = "" 17 | GOOGLE_AI_KEY = "" 18 | HF_ACCESS_TOKEN = "" 19 | 20 | keys = { 21 | "AWS_ACCESS_KEY": AWS_ACCESS_KEY, 22 | "AWS_SECRET_KEY": AWS_SECRET_KEY, 23 | 24 | "METACULUS_KEY": METACULUS_KEY, 25 | "MANIFOLD_KEY": MANIFOLD_KEY, 26 | "POLYMARKET_KEY": POLYMARKET_KEY, 27 | "CRYPTO_PRIVATE_KEY": CRYPTO_PRIVATE_KEY, 28 | "EMAIL": EMAIL, 29 | "GJOPEN_CSET_PASSWORD": GJOPEN_CSET_PASSWORD, 30 | 31 | "NEWSCASTCHER_KEY": NEWSCASTCHER_KEY, 32 | 33 | "OPENAI_KEY": OPENAI_KEY, 34 | "ANTHROPIC_KEY": ANTHROPIC_KEY, 35 | "TOGETHER_KEY": TOGETHER_KEY, 36 | "GOOGLE_AI_KEY": GOOGLE_AI_KEY, 37 | "HF_ACCESS_TOKEN": HF_ACCESS_TOKEN, 38 | } 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llm_forecasting" 7 | version = "0.0.1" 8 | dependencies = [ 9 | "requests==2.26.0", 10 | "pandas==1.5.3", 11 | "numpy==1.24.3", 12 | "scipy", 13 | "matplotlib", 14 | "openai==1.6.1", 15 | "tqdm", 16 | "together", 17 | "torch", 18 | "bardapi", 19 | "tiktoken", 20 | "langchain", 21 | "transformers", 22 | "boto3", 23 | "anthropic", 24 | "gnews==0.3.6", 25 | "newspaper4k", 26 | "newscatcherapi", 27 | "sentence_transformers", 28 | "markdown2", 29 | "google-generativeai", 30 | "jsonlines", 31 | "selenium", 32 | "aws-wsgi==0.2.7", 33 | "python-dotenv", 34 | "google-cloud-secret-manager", 35 | "black", 36 | "black[jupyter]", 37 | "autopep8", 38 | "flake8", 39 | "pytest", 40 | "pytest-asyncio", 41 | "rich", 42 | ] 43 | 44 | [tool.setuptools.packages.find] 45 | where = 
["llm_forecasting"] 46 | -------------------------------------------------------------------------------- /llm_forecasting/prompts/alignment.py: -------------------------------------------------------------------------------- 1 | ALIGNMENT_PROMPT = ( 2 | """Question: 3 | {question} 4 | 5 | Background: 6 | {background} 7 | 8 | Resolution Criteria: 9 | {resolution_criteria} 10 | 11 | Model’s Thinking: 12 | {reasoning} 13 | 14 | Task: 15 | Evaluate the alignment between the model's thinking and its prediction. If someone were given the reasoning alone (without the prediction), would they likely arrive at the same prediction? 16 | 17 | Alignment Ratings: 18 | 1 — Very Not Aligned 19 | 2 — Not Aligned 20 | 3 — Slightly Not Aligned 21 | 4 — Slightly Aligned 22 | 5 — Aligned 23 | 6 — Very Aligned 24 | 25 | Please use these ratings to indicate the degree of alignment between the model's reasoning and its prediction. 26 | 27 | Note: If the response indicates that this question is old or it's already been resolved, give it an alignment rating of 1. 28 | 29 | I want your answer to follow this format: 30 | 31 | Thinking: {{ insert your thinking here }} 32 | Rating: {{ insert your alignment rating here (a number between 1 and 6) }}""", 33 | ("QUESTION", "BACKGROUND", "RESOLUTION_CRITERIA", "REASONING"), 34 | ) 35 | -------------------------------------------------------------------------------- /llm_forecasting/prompts/relevance.py: -------------------------------------------------------------------------------- 1 | RELEVANCE_PROMPT_0 = ( 2 | """Please consider the following forecasting question and its background information. 3 | After that, I will give you a news article and ask you to rate its relevance with respect to the forecasting question. 4 | 5 | Question: 6 | {question} 7 | 8 | Question Background: 9 | {background} 10 | 11 | Question Resolution Criteria: 12 | {resolution_criteria} 13 | 14 | Article: 15 | {article} 16 | 17 | Please rate the relevance of the article to the question, at the scale of 1-6 18 | 1 -- irrelevant 19 | 2 -- slightly relevant 20 | 3 -- somewhat relevant 21 | 4 -- relevant 22 | 5 -- highly relevant 23 | 6 -- most relevant 24 | 25 | Guidelines: 26 | - You don't need to access any external sources. Just consider the information provided. 27 | - Focus on the content of the article, not the title. 28 | - If the text content is an error message about JavaScript, paywall, cookies or other technical issues, output a score of 1. 29 | 30 | Your response should look like the following: 31 | Thoughts: {{ insert your thinking }} 32 | Rating: {{ insert your rating }}""", 33 | ("QUESTION", "BACKGROUND", "RESOLUTION_CRITERIA", "ARTICLE"), 34 | ) 35 | -------------------------------------------------------------------------------- /llm_forecasting/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | # Related third-party imports 2 | import tiktoken 3 | 4 | # Local application/library specific imports 5 | from config.constants import OAI_SOURCE, MODEL_NAME_TO_SOURCE 6 | 7 | 8 | def count_tokens(text, model_name): 9 | """ 10 | Count the number of tokens for a given text. 11 | 12 | Args: 13 | - text (str): The input text whose tokens need to be counted. 14 | - model_name (str): Name of the OpenAI model to be used for token counting. 15 | 16 | Returns: 17 | - int: Number of tokens in the text for the specified model. 
18 | """ 19 | model_source = infer_model_source(model_name) 20 | if model_source == OAI_SOURCE: 21 | enc = tiktoken.encoding_for_model(model_name) 22 | token_length = len(enc.encode(text)) 23 | else: 24 | token_length = len(text) / 3 25 | 26 | return token_length 27 | 28 | 29 | def infer_model_source(model_name): 30 | """ 31 | Infer the model source from the model name. 32 | 33 | Args: 34 | - model_name (str): The name of the model. 35 | """ 36 | if "ft:gpt" in model_name: # fine-tuned GPT-3 or 4 37 | return OAI_SOURCE 38 | if model_name not in MODEL_NAME_TO_SOURCE: 39 | raise ValueError(f"Invalid model name: {model_name}") 40 | return MODEL_NAME_TO_SOURCE[model_name] 41 | -------------------------------------------------------------------------------- /llm_forecasting/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | from collections import Counter 3 | 4 | 5 | def flatten_list(nested_list): 6 | flat_list = [item for sublist in nested_list for item in sublist] 7 | return flat_list 8 | 9 | 10 | def most_frequent_item(lst): 11 | """ 12 | Return the most frequent item in the given list. 13 | 14 | If there are multiple items with the same highest frequency, one of them is 15 | returned. 16 | 17 | Args: 18 | lst (list): The list from which to find the most frequent item. 19 | 20 | Returns: 21 | The most frequent item in the list. 22 | """ 23 | if not lst: 24 | return None # Return None if the list is empty 25 | # Count the frequency of each item in the list 26 | count = Counter(lst) 27 | # Find the item with the highest frequency 28 | most_common = count.most_common(1) 29 | return most_common[0][0] # Return the item (not its count) 30 | 31 | 32 | def indices_of_N_largest_numbers(list_of_numbers, N=3): 33 | """ 34 | Return the indices of the N largest numbers in the given list of numbers. 35 | 36 | Args: 37 | list_of_numbers (list): The list of numbers from which to find the N 38 | largest numbers. 39 | N (int, optional): The number of largest numbers to find. Defaults to 3. 40 | 41 | Returns: 42 | list: The indices of the N largest numbers in the given list of numbers. 43 | """ 44 | # Get the indices of the N largest numbers 45 | indices = sorted(range(len(list_of_numbers)), key=lambda i: list_of_numbers[i])[-N:] 46 | return indices 47 | -------------------------------------------------------------------------------- /llm_forecasting/utils/article_utils.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | from datetime import datetime 3 | 4 | # Local application/library-specific imports 5 | from utils import db_utils 6 | from config.keys import keys 7 | from config.constants import S3_BUCKET_NAME, S3 8 | 9 | # Set up constants 10 | AWS_ACCESS_KEY = keys["AWS_ACCESS_KEY"] 11 | AWS_SECRET_KEY = keys["AWS_SECRET_KEY"] 12 | 13 | 14 | def article_object_to_dict(article): 15 | """ 16 | Convert an article object to a dictionary 17 | 18 | Args: 19 | article (Article): An article object (such as NewscatcherArticle) 20 | 21 | Returns: 22 | article_dict (dict): A dictionary containing the article's attributes 23 | such as title, text, authors, etc. 
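    Example (illustrative; any object whose attributes are strings, numbers,
    lists, or datetimes is converted the same way):

        >>> from types import SimpleNamespace
        >>> art = SimpleNamespace(title="Fed holds rates steady", relevance=5.0)
        >>> article_object_to_dict(art)
        {'title': 'Fed holds rates steady', 'relevance': 5.0}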
24 | """ 25 | article_dict = {} 26 | for attribute in article.__dict__: 27 | field = getattr(article, attribute) 28 | if ( 29 | isinstance(field, str) # title, text, etc 30 | or isinstance(field, int) 31 | or isinstance(field, float) # relevance ratings 32 | or isinstance(field, list) # authors list 33 | ): 34 | article_dict[attribute] = field 35 | if isinstance(field, datetime): # datetime, etc 36 | article_dict[attribute] = field.strftime("%Y-%m-%d") 37 | return article_dict 38 | 39 | 40 | def article_object_list_to_dict(article_list): 41 | """ 42 | Convert a list of article objects to a list of dictionaries 43 | """ 44 | return [article_object_to_dict(article) for article in article_list] 45 | 46 | 47 | def upload_articles_to_s3(article_list, s3_path="system/info-hp"): 48 | """ 49 | Upload a list of articles to S3 50 | 51 | Args: 52 | article_list (list): A list of article objects (such as NewscatcherArticle) 53 | s3_path (str): The path to save the articles to in S3 54 | 55 | Returns: 56 | None 57 | """ 58 | articles_dict = article_object_list_to_dict(article_list) 59 | s3_filename = f"{s3_path}/articles.pickle" 60 | db_utils.upload_data_structure_to_s3( 61 | s3=S3, data_structure=articles_dict, bucket=S3_BUCKET_NAME, s3_path=s3_filename 62 | ) 63 | -------------------------------------------------------------------------------- /llm_forecasting/utils/validation_utils.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import logging 3 | import sys 4 | 5 | # Related third-party imports 6 | import openai 7 | 8 | # Set up logging 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def is_valid_openai_key(api_key): 14 | """ 15 | Check if the given OpenAI API key is valid. 16 | 17 | Args: 18 | api_key (str): OpenAI API key to be validated. 19 | 20 | Returns: 21 | bool: Whether the given API key is valid. 22 | """ 23 | try: 24 | # Set the API key 25 | openai.api_key = api_key 26 | # Make a test request (e.g., listing available models) 27 | openai.Model.list() 28 | # If the above line didn't raise an exception, the key is valid 29 | return True 30 | except openai.error.AuthenticationError: 31 | logger.error("Invalid API key.") 32 | return False 33 | except openai.error.OpenAIError as e: 34 | logger.error(f"An error occurred in validating OpenAI API key: {str(e)}") 35 | return False 36 | 37 | 38 | def is_valid_openai_model(model_name, api_key): 39 | """ 40 | Check if the model name is valid, given a valid OpenAI API key. 41 | 42 | Args: 43 | - model_name (str): Name of the model to be validated, such as "gpt-4" 44 | - api_key (str): OpenAI API key, assumed to be valid. 45 | 46 | Returns: 47 | - bool: Whether the given model name is valid. 48 | """ 49 | try: 50 | openai.api_key = api_key 51 | # Attempt to retrieve information about the model 52 | _ = openai.Model.retrieve(model_name) 53 | return True 54 | except openai.error.OpenAIError as e: 55 | logger.error(f"An error occurred in validing the model name: {str(e)}") 56 | return False 57 | 58 | 59 | def validate_key_and_model(key, model): 60 | """ 61 | Check if the given OpenAI API key and model name are valid. 62 | 63 | Args: 64 | - key (str): OpenAI API key to be validated. 65 | - model (str): Name of the model to be validated, such as "gpt-4" 66 | 67 | If either the key or model is invalid, exit the program. 
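    Example (illustrative; the key below is a placeholder, and an invalid
    key or model name terminates the process via sys.exit(1)):

        >>> validate_key_and_model("sk-<your-key>", "gpt-4")  # doctest: +SKIP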
68 | """ 69 | if not is_valid_openai_key(key) or not is_valid_openai_model(model, key): 70 | sys.exit(1) 71 | -------------------------------------------------------------------------------- /llm_forecasting/prompts/ensemble_reasoning.py: -------------------------------------------------------------------------------- 1 | ENSEMBLE_PROMPT_0 = ( 2 | """I need your assistance with making a forecast. Here is the question and its metadata. 3 | Question: {question} 4 | 5 | Background: {background} 6 | 7 | Resolution criteria: {resolution_criteria} 8 | 9 | Today's date: {date_begin} 10 | Question close date: {date_end} 11 | 12 | I have retrieved the following information about this question. 13 | Retrieved Info: 14 | {retrieved_info} 15 | 16 | In addition, I have generated a collection of other responses and reasonings from other forecasters: 17 | {base_reasonings} 18 | 19 | Your goal is to aggregate the information and make a final prediction. 20 | 21 | Instructions: 22 | 1. Provide reasons why the answer might be no. 23 | {{ Insert your thoughts here }} 24 | 25 | 2. Provide reasons why the answer might be yes. 26 | {{ Insert your thoughts here }} 27 | 28 | 3. Aggregate your considerations. 29 | {{ Insert your aggregated considerations here }} 30 | 31 | 4. Output your prediction (a number between 0 and 1) with an asterisk at the beginning and end of the decimal. 32 | {{ Insert the probability here }}""", 33 | ( 34 | "QUESTION", 35 | "BACKGROUND", 36 | "RESOLUTION_CRITERIA", 37 | "RETRIEVED_INFO", 38 | "DATES", 39 | "BASE_REASONINGS", 40 | ), 41 | ) 42 | 43 | 44 | ENSEMBLE_PROMPT_1 = ( 45 | """I need your assistance with making a forecast. Here is the question and its metadata. 46 | Question: {question} 47 | 48 | Background: {background} 49 | 50 | Resolution criteria: {resolution_criteria} 51 | 52 | Today's date: {date_begin} 53 | Question close date: {date_end} 54 | 55 | I have retrieved the following information about this question. 56 | Retrieved Info: 57 | {retrieved_info} 58 | 59 | In addition, I have generated a collection of other responses and reasonings from other forecasters: 60 | {base_reasonings} 61 | 62 | Your goal is to aggregate the information and make a final prediction. 63 | 64 | Instructions: 65 | 1. Think step by step: {{ Insert your step by step consideration }} 66 | 67 | 2. Aggregate your considerations: {{ Aggregate your considerations }} 68 | 69 | 3. Output your prediction (a number between 0 and 1) with an asterisk at the beginning and end of the decimal. 
70 | {{ Insert the probability here }}""", 71 | ( 72 | "QUESTION", 73 | "BACKGROUND", 74 | "RESOLUTION_CRITERIA", 75 | "RETRIEVED_INFO", 76 | "DATES", 77 | "BASE_REASONINGS", 78 | ), 79 | ) 80 | -------------------------------------------------------------------------------- /llm_forecasting/alignment.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import logging 3 | 4 | # Local application/library-specific imports 5 | import model_eval 6 | import ranking 7 | from utils import string_utils 8 | from prompts.prompts import PROMPT_DICT 9 | 10 | # Set up logging 11 | logging.basicConfig(level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | def get_alignment_scores( 16 | reasonings, 17 | alignment_prompt=PROMPT_DICT["alignment"]["0"], 18 | model_name="gpt-3.5-turbo-1106", 19 | temperature=0, 20 | question=None, 21 | background=None, 22 | resolution_criteria=None, 23 | ): 24 | """ 25 | Compute the alignment score of each reasoning for each model. 26 | 27 | The alignment score assesses if the reasoning is consistent with the 28 | model's prediction, i.e. if one were given the reasoning alone, would she 29 | also predict a similar probability. 30 | 31 | Args: 32 | reasonings (list[list[str]]): A list containing a list of reasonings. 33 | alignment_prompt(dict, optional): Alignment prompt to use. 34 | model_name (str, optional): Model used to compute score. 35 | question (str, optional): Forecasting question. 36 | background (str, optional): Background of question. 37 | resolution_criteria(str, optional): Resolution criteria of question. 38 | 39 | Returns: 40 | list[list[int]]: A list containing a list of scores. 41 | """ 42 | alignment_scores = [] 43 | for model_reasonings in reasonings: 44 | alignment_scores_ = [] 45 | for reasoning in model_reasonings: 46 | prompt = string_utils.get_prompt( 47 | alignment_prompt[0], 48 | alignment_prompt[1], 49 | question=question, 50 | background=background, 51 | resolution_criteria=resolution_criteria, 52 | reasoning=reasoning, 53 | ) 54 | try: 55 | alignment_response = model_eval.get_response_from_model( 56 | model_name=model_name, 57 | prompt=prompt, 58 | max_tokens=2000, 59 | temperature=temperature, 60 | ) 61 | alignment_score = ranking.extract_rating_from_response( 62 | alignment_response 63 | ) 64 | alignment_scores_.append(alignment_score) 65 | except Exception as e: 66 | logger.error(f"Error message: {e}") 67 | logger.info("Failed to calculate alignment score") 68 | continue 69 | alignment_scores.append(alignment_scores_) 70 | return alignment_scores 71 | -------------------------------------------------------------------------------- /llm_forecasting/prompts/base_eval.py: -------------------------------------------------------------------------------- 1 | KNOWLEDGE_PROMPT = ( 2 | """Question: {question} 3 | 4 | The question was posed on {date_begin} and closed on {date_end}. 5 | Instructions: 6 | - Please output "1" if the answer is Yes", "0" if the answer is No or "IDK" if you don't know the answer. Do not return anything else. 7 | - Do not guess. 8 | Answer: {{ Insert answer here }}""", 9 | ("QUESTION", "DATES"), 10 | ) 11 | 12 | GEMINI_KNOLWEDGE_PROMPT = ( 13 | """Here was a forecasting question posed in {date_begin} and resolved in {date_end}. 14 | Question: {question} (Only Output: 1 (Yes), 0 (No), or IDK if you don't know the answer.) 
15 | Answer: {{ Insert answer here }} """, 16 | ("QUESTION", "DATES"), 17 | ) 18 | 19 | ZERO_SHOT_PROMPT = ( 20 | """You are an expert superforecaster, familiar with the work of Tetlock and others. Make a prediction of the probability that the question will be resolved as true. You MUST give a probability estimate between 0 and 1 UNDER ALL CIRCUMSTANCES. If for some reason you can't answer, pick the base rate, but return a number between 0 and 1. 21 | 22 | Question: 23 | {question} 24 | 25 | Question Background: 26 | {background} 27 | 28 | Resolution Criteria: 29 | {resolution_criteria} 30 | 31 | Today's date: {date_begin} 32 | Question close date: {date_end} 33 | 34 | Output your answer (a number between 0 and 1) with an asterisk at the beginning and end of the decimal. Do not output anything else. 35 | Answer: {{ Insert answer here }}""", 36 | ("QUESTION", "BACKGROUND", "RESOLUTION_CRITERIA", "DATES"), 37 | ) 38 | 39 | BINARY_SCRATCH_PAD_PROMPT = ( 40 | """Question: 41 | {question} 42 | 43 | Question Background: 44 | {background} 45 | 46 | Resolution Criteria: 47 | {resolution_criteria} 48 | 49 | Today's date: {date_begin} 50 | Question close date: {date_end} 51 | 52 | 53 | Instructions: 54 | 1. Provide reasons why the answer might be no. 55 | {{ Insert your thoughts }} 56 | 57 | 2. Provide reasons why the answer might be yes. 58 | {{ Insert your thoughts }} 59 | 60 | 3. Aggregate your considerations. 61 | {{ Insert your aggregated considerations }} 62 | 63 | 4. Output your answer (a number between 0 and 1) with an asterisk at the beginning and end of the decimal. 64 | {{ Insert your answer }}""", 65 | ("QUESTION", "BACKGROUND", "RESOLUTION_CRITERIA", "DATES"), 66 | ) 67 | 68 | GEMINI_BINARY_SCRATCH_PAD_PROMPT = ( 69 | """Question: 70 | {question} 71 | 72 | Question Background: 73 | {background} 74 | 75 | Resolution Criteria: 76 | {resolution_criteria} 77 | 78 | Today's date: {date_begin} 79 | Question close date: {date_end} 80 | 81 | Output why the answer might be no, why the answer might be yes, aggredate your considerations, and then your answer (a number between 0 and 1) with an asterisk at the beginning and end of the decimal.""", 82 | ("QUESTION", "BACKGROUND", "RESOLUTION_CRITERIA", "DATES"), 83 | ) 84 | -------------------------------------------------------------------------------- /scripts/fine_tune/fine_tune.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import logging 3 | import json 4 | import openai 5 | from openai import OpenAI 6 | 7 | # Set up logging 8 | logging.basicConfig(level=logging.INFO) 9 | logger = logging.getLogger(__name__) 10 | 11 | client = OpenAI( 12 | api_key="", 13 | organization="", 14 | ) 15 | 16 | 17 | def create_jsonl_for_finetuning(training_data, file_path): 18 | """ 19 | Writes training data to a JSONL file for fine-tuning purposes. 20 | 21 | Args: 22 | training_data (list): A list of tuples containing user and assistant messages. 23 | file_path (str): Path to save the JSONL file. 24 | 25 | Returns: 26 | None: Logs the completion of file writing. 
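    Example (illustrative; the single training pair below is made up):

        >>> pairs = [("Question: Will X happen by 2025?", "*0.05*")]
        >>> create_jsonl_for_finetuning(pairs, "train.jsonl")  # doctest: +SKIP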
27 | """ 28 | with open(file_path, "w") as jsonl_file: 29 | for user, assistant in training_data: 30 | message = { 31 | "messages": [ 32 | {"role": "user", "content": user}, 33 | {"role": "assistant", "content": assistant}, 34 | ] 35 | } 36 | example = json.dumps(message) 37 | jsonl_file.write(example + "\n") 38 | logger.info(f"|training_data| saved to {file_path} as jsonl") 39 | return None 40 | 41 | 42 | def upload_oai_file(file_path): 43 | """ 44 | Uploads a file to OpenAI API for fine-tuning. 45 | 46 | Args: 47 | file_path (str): Path of the file to be uploaded. 48 | 49 | Returns: 50 | object: Returns a file object as a response from OpenAI API. 51 | """ 52 | with open(file_path, "rb") as file_data: 53 | file_obj = client.files.create(file=file_data, purpose="fine-tune") 54 | 55 | logging.info(f"Uploaded dataset {file_path} to OpenAI API.") 56 | return file_obj 57 | 58 | 59 | def create_oai_finetuning_job(model_name, train_file_id, model_suffix): 60 | """ 61 | Creates a fine-tuning job with OpenAI. 62 | 63 | Args: 64 | model_name (str): Name of the base model to fine-tune. 65 | train_file_id (str): ID of the training file uploaded to OpenAI. 66 | model_suffix (str): Suffix to be added to the fine-tuned model. 67 | 68 | Returns: 69 | object: Returns a fine-tuning job object. 70 | """ 71 | model = client.fine_tuning.jobs.create( 72 | model=model_name, 73 | training_file=train_file_id, 74 | suffix=model_suffix, 75 | ) 76 | return model 77 | 78 | 79 | def check_on_finetuning_job(ft_id): 80 | """ 81 | Checks the status of a fine-tuning job. 82 | 83 | Args: 84 | ft_id (str): ID of the fine-tuning job. 85 | 86 | Returns: 87 | None: Fetches and logs the latest events of the fine-tuning job. 88 | """ 89 | # Use this to check on the output of your job 90 | openai.FineTuningJob.list_events(id=ft_id, limit=5) 91 | return None 92 | -------------------------------------------------------------------------------- /llm_forecasting/prompts/prompts.py: -------------------------------------------------------------------------------- 1 | from prompts.relevance import * 2 | from prompts.base_reasoning import * 3 | from prompts.search_query import * 4 | from prompts.summarization import * 5 | from prompts.ensemble_reasoning import * 6 | from prompts.alignment import * 7 | from prompts.system import * 8 | from prompts.data_wrangling import * 9 | from prompts.base_eval import * 10 | 11 | PROMPT_DICT = { 12 | "binary": { 13 | "scratch_pad": { 14 | "0": BINARY_SCRATCH_PAD_PROMPT_0, 15 | "1": BINARY_SCRATCH_PAD_PROMPT_1, 16 | "2": BINARY_SCRATCH_PAD_PROMPT_2, 17 | "2_tok": BINARY_SCRATCH_PAD_PROMPT_2_TOKENS, 18 | "3": BINARY_SCRATCH_PAD_PROMPT_3, 19 | "new_0": BINARY_SCRATCH_PAD_PROMPT_NEW_0, 20 | "new_1": BINARY_SCRATCH_PAD_PROMPT_NEW_1, 21 | "new_2": BINARY_SCRATCH_PAD_PROMPT_NEW_2, 22 | "new_3": BINARY_SCRATCH_PAD_PROMPT_NEW_3, 23 | "new_4": BINARY_SCRATCH_PAD_PROMPT_NEW_4, 24 | "new_5": BINARY_SCRATCH_PAD_PROMPT_NEW_5, 25 | "new_6": BINARY_SCRATCH_PAD_PROMPT_NEW_6, 26 | "new_7": BINARY_SCRATCH_PAD_PROMPT_NEW_7, 27 | }, 28 | "rar": { 29 | "0": RAR_PROMPT_0, 30 | "1": BINARY_SCRATCH_PAD_PROMPT_1_RAR, 31 | "2": BINARY_SCRATCH_PAD_PROMPT_2_RAR, 32 | }, 33 | "emotion": { 34 | "0": EMOTION_PROMPT_0, 35 | }, 36 | }, 37 | "ranking": { 38 | "0": RELEVANCE_PROMPT_0, 39 | }, 40 | "alignment": { 41 | "0": ALIGNMENT_PROMPT, 42 | }, 43 | "search_query": { 44 | "0": SEARCH_QUERY_PROMPT_0, 45 | "1": SEARCH_QUERY_PROMPT_1, 46 | "2": SEARCH_QUERY_PROMPT_2, 47 | "3": SEARCH_QUERY_PROMPT_3, 48 | "4": SEARCH_QUERY_PROMPT_4, 49 | 
"5": SEARCH_QUERY_PROMPT_5, 50 | "6": SEARCH_QUERY_PROMPT_6, 51 | "7": SEARCH_QUERY_PROMPT_7, 52 | "8": SEARCH_QUERY_PROMPT_8, 53 | }, 54 | "summarization": { 55 | "0": SUMMARIZATION_PROMPT_0, 56 | "1": SUMMARIZATION_PROMPT_1, 57 | "2": SUMMARIZATION_PROMPT_2, 58 | "3": SUMMARIZATION_PROMPT_3, 59 | "4": SUMMARIZATION_PROMPT_4, 60 | "5": SUMMARIZATION_PROMPT_5, 61 | "6": SUMMARIZATION_PROMPT_6, 62 | "7": SUMMARIZATION_PROMPT_7, 63 | "8": SUMMARIZATION_PROMPT_8, 64 | "9": SUMMARIZATION_PROMPT_9, 65 | "10": SUMMARIZATION_PROMPT_10, 66 | "11": SUMMARIZATION_PROMPT_11, 67 | }, 68 | "meta_reasoning": { 69 | "0": ENSEMBLE_PROMPT_0, 70 | "1": ENSEMBLE_PROMPT_1, 71 | }, 72 | "system": { 73 | "0": SYSTEM_SUPERFORECASTER_0, 74 | }, 75 | "data_wrangling": { 76 | "is_bad_title": IS_BAD_TITLE_PROMPT, 77 | "reformat": REFORMAT_PROMPT, 78 | "assign_category": ASSIGN_CATEGORY_PROMPT, 79 | }, 80 | "base_eval": { 81 | "knowledge": KNOWLEDGE_PROMPT, 82 | "gemini_knowledge": GEMINI_KNOLWEDGE_PROMPT, 83 | "zero_shot": ZERO_SHOT_PROMPT, 84 | "scratch_pad": BINARY_SCRATCH_PAD_PROMPT, 85 | "gemini_scratch_pad": GEMINI_BINARY_SCRATCH_PAD_PROMPT, 86 | }, 87 | } 88 | -------------------------------------------------------------------------------- /llm_forecasting/prompts/data_wrangling.py: -------------------------------------------------------------------------------- 1 | ASSIGN_CATEGORY_PROMPT = ( 2 | """Question: {question} 3 | 4 | Background: {background} 5 | 6 | Options: 7 | ['Science & Tech', 8 | 'Healthcare & Biology', 9 | 'Economics & Business', 10 | 'Environment & Energy', 11 | 'Politics & Governance', 12 | 'Education & Research', 13 | 'Arts & Recreation', 14 | 'Security & Defense', 15 | 'Social Sciences', 16 | 'Sports', 17 | 'Other'] 18 | 19 | Instruction: Assign a category for the given question. 20 | 21 | Rules: 22 | 1. Make sure you only return one of the options from the option list. 23 | 2. Only output the category, and do not output any other words in your response. 24 | 3. You have to pick a string from the above categories. 25 | 26 | Answer:""", 27 | ("QUESTION", "BACKGROUND"), 28 | ) 29 | 30 | IS_BAD_TITLE_PROMPT = ( 31 | """I'm trying to assess the quality of an old forecasting dataset. 32 | 33 | Here is a forecasting question from the dataset: {question}. 34 | 35 | Please flag questions that don't sound like binary forecasting questions by outputting "flag". If it sounds like a reasonable question, output "ok". 36 | 37 | Examples of strings that should be flagged: 38 | "Will I finish my homework tonight?" 39 | "Metaculus party 2023" 40 | "Will Hell freeze over?" 41 | "Heads or tails" 42 | "Will this video reach 100k views by the EOD?" 43 | Examples of strings that should not be flagged: 44 | "Will Megan Markle and Prince Harry have a baby by the end of the year?" 45 | "Will the Brain Preservation Foundation's Large Mammal preservation prize be won by Feb 9th, 2017?" 46 | "Will there be more novel new drugs approved by the FDA in 2016 than in 2015?" 47 | 48 | If a question is already resolved, that doesn't mean it should be flagged. When in doubt, mark it as "ok". 49 | 50 | Your response should take the following structure: 51 | Insert thinking: 52 | {{ insert your concise thoughts here }} 53 | Classification: 54 | {{ insert "flag" or "ok"}}""", 55 | ("QUESTION", "BACKGROUND"), 56 | ) 57 | 58 | REFORMAT_PROMPT = ( 59 | """I have questions that need to be transformed for clarity. 60 | 61 | Here are some examples: 62 | Example 1: 63 | Before: Who will win the 2022-2023 Premier League? 
(Leicester City) 64 | After: *Will Leicester City win the 2022-2023 Premier League?* 65 | 66 | Example 2: 67 | Before: What coalition will govern Berlin after the 2023 repeat state election? (SPD+Greens) 68 | After: *Will SPD+Greens govern Berlin after the 2023 repeat state election?* 69 | 70 | Example 3: 71 | Before: If Republicans win control of the House of Representatives in the 2022 election, who will be the next Majority Whip of the U.S. House of Representatives? (Rep. Jim Banks) 72 | After: *If Republicans win control of the House of Representatives in the 2022 election, will Jim Banks be the next Majority Whip of the U.S. House of Representatives?* 73 | 74 | Example 4: 75 | Before: Economic Trouble: Will a country’s currency depreciate 15% or more in the second half of 2022? (Thai Baht ฿) 76 | After: *Economic Trouble: Will the Thai Baht ฿ currency depreciate 15% or more in the second half of 2022?* 77 | 78 | Example 5: 79 | Before: How many of the claims from study 3 of "Behavioral nudges reduce failure to appear for court" (Science, 2020) replicate? After: *Will exactly 2 claims from study 3 of "Behavioral nudges reduce failure to appear for court" (Science, 2020) replicate?* 80 | 81 | Can you now transform this question for clarity: {question} 82 | 83 | Please place stars around the transformed question. 84 | 85 | Your output should take the following structure: 86 | Before: {insert the original question} 87 | After: *{insert the transformed question}*""", 88 | ("QUESTION",), 89 | ) 90 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLM Forecasting 2 | 3 | This repository hosts the code for the paper: [Approaching Human-Level Forecasting with Language Models](https://arxiv.org/abs/2402.18563). 4 | 5 |
 6 | ![Forecasting system overview](assets/forecasting_system.png)
 7 | 
8 | 9 | Our system is designed to make automated, _simulated_ forecasts by following these steps: 10 | 1. **Search Query Generation**: A language model (LM) is prompted to create search queries to retrieve articles published before a certain date from a news API. 11 | 2. **Assess Article Relevancy**: A LM rates the relevancy of the retrieved articles and filters out non-relevant ones. 12 | 3. **Summarize Articles**: A LM is prompted to retain the salient information relevant to the question from the filtered articles. 13 | 4. **Reason and Predict**: A LM (base or fine-tuned) is prompted multiple times to produce reasoning and predictions based on the article summaries. 14 | 5. **Forecast Aggregation**: An aggregation method is applied to all the predictions to obtain a final forecast. 15 | 16 | We've designed our system to be easily scalable to other news APIs and language models. 17 | 18 | ## Development Guide 19 | This guide will help you set up your development environment for the project. Make sure you follow these steps under the root directory of the project (where **pyproject.toml** is located). 20 | 21 | ### Prerequisites 22 | **Miniconda or Anaconda**: 23 | Ensure you have Miniconda or Anaconda installed as they include the Conda package manager. 24 | 25 | ### Setting Up Your Environment 26 | 1. **Create and Activate a Conda Environment**
27 | Start by creating a Conda environment specifically for this project. 28 | ``` 29 | conda create -n myenv python=3.11 30 | conda activate myenv 31 | ``` 32 | Replace **myenv** with your preferred environment name and adjust **python=3.11** if needed. 33 | 34 | 35 | 2. **Install Setuptools and Upgrade Build**
36 | Within your Conda environment: 37 | ``` 38 | conda install setuptools 39 | pip install --upgrade build 40 | ``` 41 | 42 | 3. **Install Package in Editable Mode**
43 | Install the package in editable mode to allow changes in your code to be immediately reflected in the installed package. This also installs all dependencies listed in `pyproject.toml`. 44 | ``` 45 | pip install --editable . 46 | ``` 47 | 48 | ### Usage 49 | Once the setup is complete, you can start using the package. For example, you can import modules from the forecasting subdirectory as follows: 50 | ``` 51 | import information_retrieval 52 | import utils 53 | from prompts.prompts import PROMPT_DICT 54 | ``` 55 | 56 | ### Deactivating Your Environment 57 | When you're finished working, you can deactivate the Conda environment: 58 | ``` 59 | conda deactivate 60 | ``` 61 | 62 | ## Demo 63 | See the [system_demo](https://github.com/dannyallover/llm_forecasting/blob/main/notebooks/demo/system_demo.ipynb) for an example of how to run our system. 64 | 65 | ## Dataset 66 | The cleaned and formatted dataset can be found on [huggingface](https://huggingface.co/datasets/YuehHanChen/forecasting), as well as the [raw dataset](https://huggingface.co/datasets/YuehHanChen/forecasting_raw). 67 | 68 | ## Contributing 69 | We welcome contributions to this repository. If you'd like to contribute, please follow these steps: 70 | 71 | 1. Fork the repository. 72 | 2. Create a new branch for your feature or bug fix: `git checkout -b my-new-feature` 73 | 3. Make your changes and commit them: `git commit -am 'Add some feature'` 74 | 4. Push your changes to your forked repository: `git push origin my-new-feature` 75 | 5. Open a pull request against the main repository. 76 | 77 | Please ensure that your code adheres to the project's coding conventions and that you include tests for any new functionality or bug fixes. 78 | 79 | ## Citation 80 | If our codebase, dataset, or paper proves fruitful for your work, please consider citing us. 81 | ``` 82 | @misc{halawi2024approaching, 83 | title={Approaching Human-Level Forecasting with Language Models}, 84 | author={Danny Halawi and Fred Zhang and Chen Yueh-Han and Jacob Steinhardt}, 85 | year={2024}, 86 | eprint={2402.18563}, 87 | archivePrefix={arXiv}, 88 | primaryClass={cs.LG} 89 | } 90 | ``` 91 | -------------------------------------------------------------------------------- /llm_forecasting/utils/api_utils.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import logging 3 | import time 4 | 5 | # Related third-party imports 6 | import requests 7 | 8 | # Set up logging 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | def request_with_retries( 14 | method, url, headers, params=None, data=None, max_retries=5, delay=30 15 | ): 16 | """ 17 | Make an API request (GET or POST) with retries in case of rate-limiting 18 | (HTTP 429) or other defined conditions and return the JSON content or log 19 | an error and return None. 20 | 21 | Args: 22 | method (str): HTTP method ('GET' or 'POST'). 23 | url (str): The API endpoint. 24 | headers (dict): Headers for the API request. 25 | params (dict, optional): Parameters for the API request. 26 | data (dict, optional): JSON data for the API request (used for POST). 27 | max_retries (int, optional): Maximum number of retries. Defaults to 5. 28 | delay (int, optional): Delay (in seconds) between retries. Defaults to 29 | 30. 30 | 31 | Returns: 32 | dict or None: The JSON response content as a dictionary or None if an 33 | error occurred.
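    Example (illustrative sketch; the URL, header, and parameters are
    placeholders rather than a real endpoint):

        >>> content = request_with_retries(
        ...     "GET",
        ...     "https://example.com/api/questions",
        ...     headers={"Authorization": "Token <API_KEY>"},
        ...     params={"limit": 20},
        ... )  # doctest: +SKIP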
34 | """ 35 | for _ in range(max_retries): 36 | try: 37 | if method == "GET": 38 | response = requests.get(url, headers=headers, params=params) 39 | elif method == "POST": 40 | response = requests.post(url, headers=headers, json=data) 41 | else: 42 | logging.error(f"Unsupported method: {method}") 43 | return None 44 | 45 | if response.status_code == 429: 46 | time.sleep(delay) 47 | continue 48 | 49 | response.raise_for_status() 50 | return response.json() 51 | 52 | except requests.RequestException as e: 53 | logging.error(f"Request error: {e}") 54 | return None 55 | 56 | logging.error(f"Exceeded max retries for URL: {url}") 57 | return None 58 | 59 | 60 | def get_response_content(url, headers, params=None, max_retries=5, delay=30): 61 | """ 62 | Create a wrapper function that issues a GET API request, utilizing a 63 | generic retry mechanism. 64 | """ 65 | return request_with_retries( 66 | "GET", url, headers, params=params, max_retries=max_retries, delay=delay 67 | ) 68 | 69 | 70 | def post_request_with_retries(endpoint, headers, payload, retries=5): 71 | """ 72 | Create a wrapper function that makes a POST API request using the generic 73 | retry mechanism. 74 | """ 75 | response = request_with_retries( 76 | "POST", endpoint, headers, data=payload, max_retries=retries 77 | ) 78 | if ( 79 | response 80 | and "detail" in response 81 | and "Expected available in" in response["detail"] 82 | ): 83 | wait_seconds = int(response["detail"].split(" ")[-2]) + 1 84 | time.sleep(wait_seconds) 85 | return request_with_retries( 86 | "POST", endpoint, headers, data=payload, max_retries=retries 87 | ) 88 | return response 89 | 90 | 91 | def fetch_all_questions(base_url, headers, params): 92 | """ 93 | Fetch all questions from the API using pagination. 94 | 95 | Args: 96 | - base_url (str): The base URL of the API. 97 | - headers (dict): The headers to use for the requests. 98 | - params (dict): The parameters to use for the requests. 99 | 100 | Returns: 101 | - list: List of all questions fetched from the API. 102 | """ 103 | all_questions = [] 104 | 105 | current_url = base_url 106 | while current_url: 107 | logging.info(f"Fetching data from {current_url} with params: {params}") 108 | 109 | data = get_response_content(current_url, headers, params) 110 | if not data: 111 | break 112 | 113 | if "results" in data: 114 | all_questions.extend(data["results"]) 115 | else: 116 | all_questions.append(data) 117 | 118 | current_url = data.get("next") 119 | 120 | return all_questions 121 | -------------------------------------------------------------------------------- /llm_forecasting/utils/metrics_utils.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import logging 3 | 4 | # Related third-party imports 5 | import numpy as np 6 | from numpy.linalg import norm 7 | import torch 8 | 9 | # Local application/library-specific imports 10 | from utils import time_utils 11 | 12 | # Set up logging 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def brier_score(probabilities, answer_idx): 18 | """ 19 | Calculate the Brier score for a set of probabilities and the correct answer 20 | index. 21 | 22 | Args: 23 | - probabilities (numpy array): The predicted probabilities for each class. 24 | - answer_idx (int): Index of the correct answer. 25 | 26 | Returns: 27 | - float: The Brier score. 
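    Example (binary case; index 1 is the realized outcome):

        >>> import numpy as np
        >>> round(float(brier_score(np.array([0.3, 0.7]), 1)), 4)
        0.09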
28 | """ 29 | answer = np.zeros_like(probabilities) 30 | answer[answer_idx] = 1 31 | return ((probabilities - answer) ** 2).sum() / 2 32 | 33 | 34 | def calculate_cosine_similarity_bert(text_list, tokenizer, model): 35 | """ 36 | Calculate the average cosine similarity between texts in a given list using 37 | embeddings. 38 | 39 | Define Bert outside the function: 40 | from transformers import BertTokenizer, BertModel 41 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 42 | model = BertModel.from_pretrained('bert-base-uncased') 43 | 44 | Parameters: 45 | text_list (List[str]): A list of strings where each string is a text 46 | document. 47 | 48 | Returns: 49 | float: The average cosine similarity between each pair of texts in the list. 50 | Returns 0 if the list contains less than two text documents. 51 | """ 52 | if len(text_list) < 2: 53 | return 0 54 | 55 | # Function to get embeddings 56 | def get_embedding(text): 57 | inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512) 58 | outputs = model(**inputs) 59 | return torch.mean(outputs.last_hidden_state, dim=1) 60 | 61 | # Generating embeddings for each text 62 | embeddings = [get_embedding(text) for text in text_list] 63 | 64 | # Calculating cosine similarity between each pair of embeddings 65 | similarity_scores = [] 66 | for i in range(len(embeddings)): 67 | for j in range(i + 1, len(embeddings)): 68 | similarity = torch.nn.functional.cosine_similarity( 69 | embeddings[i], embeddings[j] 70 | ) 71 | similarity_scores.append(similarity.item()) 72 | 73 | # Calculating average similarity 74 | average_similarity = np.mean(similarity_scores) 75 | 76 | return average_similarity 77 | 78 | 79 | def cosine_similarity(u, v): 80 | """ 81 | Compute the cosine similarity between two vectors. 82 | """ 83 | return np.dot(u, v) / (norm(u) * norm(v)) 84 | 85 | 86 | def get_average_forecast(date_pred_list): 87 | """ 88 | Retrieve the average forecast value from the list of predictions. 89 | 90 | Args: 91 | - date_pred_list (list of tuples): list contain tuples of (date str, pred). 92 | 93 | Returns: 94 | - float: The average prediction. 95 | """ 96 | if not date_pred_list or len(date_pred_list) == 0: 97 | return 0.5 # Return a default value of 0.5 if there is no history 98 | return sum(tup[1] for tup in date_pred_list) / len(date_pred_list) 99 | 100 | 101 | def compute_bs_and_crowd_bs(pred, date_pred_list, retrieve_date, answer): 102 | """ 103 | Computes Brier scores for individual prediction and community prediction. 104 | 105 | Parameters: 106 | - pred (float): The individual's probability prediction for an event. 107 | - date_pred_list (list of tuples): A list of tuples containing dates 108 | and community predictions. Each tuple is in the format (date, prediction). 109 | - retrieve_date (date): The date for which the community prediction is to be retrieved. 110 | - answer (int): The actual outcome of the event, where 0 indicates the event 111 | did not happen, and 1 indicates it did. 112 | 113 | Returns: 114 | - bs (float): The Brier score for the individual prediction. 115 | - bs_comm (float): The Brier score for the community prediction closest to the specified retrieve_date. 
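    Example (illustrative; the prediction, crowd history, and dates are made
    up, and the date format must match what time_utils.find_closest_date expects):

        >>> history = [("2023-06-01", 0.40), ("2023-06-15", 0.55)]
        >>> bs, bs_comm = compute_bs_and_crowd_bs(0.8, history, "2023-06-10", 1)  # doctest: +SKIP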
116 | """ 117 | pred_comm = time_utils.find_closest_date(retrieve_date, date_pred_list)[-1] 118 | bs = brier_score([1 - pred, pred], answer) 119 | bs_comm = brier_score([1 - pred_comm, pred_comm], answer) 120 | 121 | return bs, bs_comm 122 | -------------------------------------------------------------------------------- /llm_forecasting/utils/db_utils.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | from io import StringIO 3 | import logging 4 | import os 5 | import pickle 6 | import random 7 | import string 8 | 9 | # Related third-party imports 10 | import boto3 11 | import pandas as pd 12 | 13 | # Set up logging 14 | logging.basicConfig(level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def initialize_s3_client(aws_access_key_id, aws_secret_access_key): 19 | """ 20 | Initialize an Amazon S3 client using provided AWS credentials. 21 | 22 | Args: 23 | - aws_access_key_id (str): AWS access key for authentication. 24 | - aws_secret_access_key (str): AWS secret access key for authentication. 25 | 26 | Returns: 27 | - boto3.client: Initialized S3 client. 28 | """ 29 | return boto3.client( 30 | "s3", 31 | aws_access_key_id=aws_access_key_id, 32 | aws_secret_access_key=aws_secret_access_key, 33 | ) 34 | 35 | 36 | def upload_data_structure_to_s3(s3, data_structure, bucket, s3_path): 37 | """ 38 | Upload a local file to a specified path in an Amazon S3 bucket. 39 | 40 | Args: 41 | s3 (boto3.client): An initialized S3 client instance. 42 | data_structure (.): Data structure to save. 43 | bucket (str): Name of the target S3 bucket. 44 | s3_path (str): Desired filename within the S3 bucket. 45 | """ 46 | try: 47 | extension = s3_path.split(".")[-1] 48 | hash = "".join(random.choices(string.ascii_letters + string.digits, k=5)) 49 | temp_file = f"temp{hash}.{extension}" 50 | with open(temp_file, "wb") as f: 51 | pickle.dump(data_structure, f) 52 | 53 | s3.upload_file(temp_file, bucket, s3_path) 54 | os.remove(temp_file) 55 | logging.info(f"Successfully uploaded data to {bucket}/{s3_path}") 56 | except Exception as e: 57 | logging.error(f"Error uploading data to {bucket}/{s3_path}. Error: {e}") 58 | 59 | 60 | def upload_file_to_s3(s3, local_file, bucket, s3_path): 61 | """ 62 | Upload a local file to a specified path in an Amazon S3 bucket. 63 | 64 | Args: 65 | - s3 (boto3.client): An initialized S3 client instance. 66 | - local_file (str): Path of the local file to upload. 67 | - bucket (str): Name of the target S3 bucket. 68 | - s3_path (str): Desired path within the S3 bucket. 69 | 70 | """ 71 | try: 72 | s3.upload_file(local_file, bucket, s3_path) 73 | logging.info(f"Successfully uploaded {local_file} to {bucket}/{s3_path}") 74 | except Exception as e: 75 | logging.error(f"Error uploading {local_file} to {bucket}/{s3_path}. Error: {e}") 76 | 77 | 78 | def read_pickle_from_s3(s3, bucket, s3_path): 79 | """ 80 | Fetch and deserialize a pickle file from an S3 bucket. 81 | 82 | This can be used in conjunction with `upload_data_structure_to_s3` to 83 | upload and download arbitrary data structures to/from S3. 84 | 85 | Args: 86 | - s3 (boto3.client): An initialized S3 client. 87 | - bucket (str): Name of the S3 bucket containing the pickle file. 88 | - s3_path (str): Path of the pickle file within the S3 bucket. 89 | 90 | Returns: 91 | - object: Deserialized object from the pickle file. 
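    Example (illustrative; the bucket name is a placeholder, and the key mirrors
    the path used by article_utils.upload_articles_to_s3):

        >>> s3 = initialize_s3_client("<aws-access-key>", "<aws-secret-key>")  # doctest: +SKIP
        >>> articles = read_pickle_from_s3(s3, "my-bucket", "system/info-hp/articles.pickle")  # doctest: +SKIP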
92 | """ 93 | obj = s3.get_object(Bucket=bucket, Key=s3_path) 94 | data = pickle.loads(obj["Body"].read()) 95 | return data 96 | 97 | 98 | def read_pickle_files_from_s3_folder(s3, bucket, s3_folder_path): 99 | # List objects in the specified S3 folder 100 | objects = s3.list_objects(Bucket=bucket, Prefix=s3_folder_path) 101 | 102 | pickle_files = [] 103 | # Loop through the objects and filter for pickle files 104 | for obj in objects.get("Contents", []): 105 | key = obj["Key"] 106 | if key.endswith(".pickle"): 107 | # Download the pickle file 108 | response = s3.get_object(Bucket=bucket, Key=key) 109 | # Read the pickle file from the response 110 | pickle_data = pickle.loads(response["Body"].read()) 111 | pickle_files.append(pickle_data) 112 | 113 | return pickle_files 114 | 115 | 116 | def read_csv_from_s3(s3, bucket, s3_path): 117 | """ 118 | Fetch and read a CSV file from an S3 bucket into a pandas DataFrame. 119 | 120 | Args: 121 | - s3 (boto3.client): An initialized S3 client. 122 | - bucket (str): Name of the S3 bucket containing the CSV file. 123 | - s3_path (str): Path of the CSV file within the S3 bucket. 124 | 125 | Returns: 126 | - pd.DataFrame: DataFrame populated with the CSV data. 127 | """ 128 | obj = s3.get_object(Bucket=bucket, Key=s3_path) 129 | df = pd.read_csv(StringIO(obj["Body"].read().decode("utf-8"))) 130 | return df 131 | -------------------------------------------------------------------------------- /notebooks/results/results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "aab054a6", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import pickle\n", 11 | "import numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "340d2fa4", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "def calculate_normalized_weighted_trimmed_mean(predictions):\n", 22 | " # Step 1: Find the median\n", 23 | " median_prediction = np.median(predictions)\n", 24 | "\n", 25 | " # Step 2: Determine the prediction farthest from the median\n", 26 | " distances = np.abs(predictions - median_prediction)\n", 27 | " max_distance = np.max(distances)\n", 28 | "\n", 29 | " # Step 3: Down-weight the furthest prediction by half\n", 30 | " weights = np.ones(len(predictions))\n", 31 | " weights[distances == max_distance] *= 0.5\n", 32 | "\n", 33 | " # Step 4: Distribute the saved weight among other predictions\n", 34 | " saved_weight = (1.0 - 0.5) / (len(predictions) - 1)\n", 35 | " weights[distances != max_distance] += saved_weight\n", 36 | "\n", 37 | " # Step 5: Calculate the weighted mean\n", 38 | " weighted_mean = np.average(predictions, weights=weights)\n", 39 | "\n", 40 | " return weighted_mean" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": 3, 46 | "id": "dc954c5e", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "with open(\"data/answers.pickle\", \"rb\") as file:\n", 51 | " answers = pickle.load(file)\n", 52 | "with open(\"data/base_predictions.pickle\", \"rb\") as file:\n", 53 | " base_predictions = pickle.load(file)\n", 54 | "with open(\"data/finetuned_predictions.pickle\", \"rb\") as file:\n", 55 | " finetuned_predictions = pickle.load(file)\n", 56 | "with open(\"data/finetuned_other_predictions.pickle\", \"rb\") as file:\n", 57 | " finetuned_other_predictions = pickle.load(file)\n", 58 | "with open(\"data/crowd_predictions.pickle\", \"rb\") as file:\n", 59 | " 
community_predictions = pickle.load(file)" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 4, 65 | "id": "be79c2e2", 66 | "metadata": {}, 67 | "outputs": [ 68 | { 69 | "name": "stdout", 70 | "output_type": "stream", 71 | "text": [ 72 | "Base Brier Score: 0.1863574497732625\n", 73 | "Finetuned Brier Score: 0.18005945446102836\n", 74 | "Finetuned and Base Brier Score: 0.17988713531620448\n", 75 | "Crowd Brier Score: 0.1486199294280867\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "base_brier_score = 0\n", 81 | "finetuned_brier_score = 0\n", 82 | "finetuned_and_base_brier_score = 0\n", 83 | "crowd_brier_score = 0\n", 84 | "n = 0\n", 85 | "for i in range(5): # num retrieval dates\n", 86 | " for j in range(len(finetuned_predictions[i])):\n", 87 | " answer = answers[i][j]\n", 88 | " \n", 89 | " base_preds = base_predictions[i][j]\n", 90 | " finetuned_preds = finetuned_predictions[i][j]\n", 91 | " finetuned_other_preds = finetuned_other_predictions[i][j]\n", 92 | " crowd_pred = community_predictions[i][j]\n", 93 | " \n", 94 | " base_pred = calculate_normalized_weighted_trimmed_mean(base_preds)\n", 95 | " finetuned_pred = np.mean(finetuned_preds + finetuned_other_preds)\n", 96 | " finetuned_and_base_pred = calculate_normalized_weighted_trimmed_mean(base_preds + finetuned_preds + finetuned_other_preds)\n", 97 | " \n", 98 | " base_brier_score += (base_pred - answer) ** 2\n", 99 | " finetuned_brier_score += (finetuned_pred - answer) ** 2\n", 100 | " finetuned_and_base_brier_score += (finetuned_and_base_pred - answer) ** 2\n", 101 | " crowd_brier_score += (crowd_pred - answer) ** 2\n", 102 | " n += 1\n", 103 | "\n", 104 | "print(\"Base Brier Score:\", base_brier_score/n)\n", 105 | "print(\"Finetuned Brier Score:\", finetuned_brier_score/n)\n", 106 | "print(\"Finetuned and Base Brier Score:\", finetuned_and_base_brier_score/n)\n", 107 | "print(\"Crowd Brier Score:\", crowd_brier_score/n)" 108 | ] 109 | } 110 | ], 111 | "metadata": { 112 | "kernelspec": { 113 | "display_name": "Python 3 (ipykernel)", 114 | "language": "python", 115 | "name": "python3" 116 | }, 117 | "language_info": { 118 | "codemirror_mode": { 119 | "name": "ipython", 120 | "version": 3 121 | }, 122 | "file_extension": ".py", 123 | "mimetype": "text/x-python", 124 | "name": "python", 125 | "nbconvert_exporter": "python", 126 | "pygments_lexer": "ipython3", 127 | "version": "3.11.0" 128 | } 129 | }, 130 | "nbformat": 4, 131 | "nbformat_minor": 5 132 | } 133 | -------------------------------------------------------------------------------- /scripts/data_scraping/cset.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import argparse 3 | import datetime 4 | import json 5 | import logging 6 | import os 7 | import random 8 | import time 9 | 10 | # Related third-party imports 11 | import jsonlines 12 | from selenium.webdriver.common.by import By 13 | 14 | # Local application/library-specific imports 15 | import data_scraping 16 | from config.keys import keys 17 | 18 | logger = logging.getLogger(__name__) 19 | MAX_CONSECUTIVE_NOT_FOUND = 1000 20 | # Writing to file for debugging purposes. It will be deleted once the script is done. 21 | FILE_PATH = "cset_dump.jsonl" 22 | # Use your own executable_path (download from https://chromedriver.chromium.org/). 
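# Illustrative alternative (not used by this script): the webdriver-manager
# package can download and cache a matching driver automatically, e.g.
#   from webdriver_manager.chrome import ChromeDriverManager
#   CHROMEDRIVER_PATH = ChromeDriverManager().install()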
23 | CHROMEDRIVER_PATH = "/Users/apple/Downloads/chromedriver-mac-x64/chromedriver"
24 | CSET_EMAIL = keys["EMAIL"]
25 | CSET_PASSWORD = keys["GJOPEN_CSET_PASSWORD"]
26 | 
27 | 
28 | def main(n_days):
29 |     """
30 |     Scrape, process, and upload question data from CSET (https://www.infer-pub.com/)
31 | 
32 |     Args:
33 |         n_days (int): Number of days to look back for questions.
34 | 
35 |     Returns:
36 |         list: A list of processed question data.
37 |     """
38 |     driver = data_scraping.initialize_and_login(
39 |         signin_page="https://www.infer-pub.com/users/sign_in",
40 |         email=CSET_EMAIL,
41 |         password=CSET_PASSWORD,
42 |         executable_path=CHROMEDRIVER_PATH,
43 |     )
44 | 
45 |     question_counter = 0
46 |     consecutive_not_found_or_skipped = 0
47 |     while True:
48 |         question_counter += 1
49 |         url = f"https://www.infer-pub.com/questions/{question_counter}"
50 | 
51 |         try:
52 |             driver.get(url)
53 |             trend_graph_element = driver.find_element(
54 |                 By.CSS_SELECTOR,
55 |                 "div[data-react-class='FOF.Forecast.QuestionTrendGraph']",
56 |             )
57 |             props = json.loads(trend_graph_element.get_attribute("data-react-props"))
58 |             props["extracted_articles_urls"] = data_scraping.get_source_links(
59 |                 driver, url
60 |             )
61 | 
62 |             with jsonlines.open(FILE_PATH, mode="a") as writer:
63 |                 writer.write(props)
64 |             consecutive_not_found_or_skipped = 0
65 |         except BaseException:
66 |             if data_scraping.question_not_found(driver):
67 |                 logger.info(f"Question {question_counter} not found")
68 |             else:
69 |                 logger.info(f"Skipping question {question_counter}")
70 |             consecutive_not_found_or_skipped += 1
71 |             if consecutive_not_found_or_skipped > MAX_CONSECUTIVE_NOT_FOUND:
72 |                 logger.info("Reached maximum consecutive not found.")
73 |                 break
74 | 
75 |         time.sleep(random.uniform(0, 2))  # random delay between requests
76 | 
77 |     data = []
78 |     with open(FILE_PATH, "r") as file:
79 |         for line in file:
80 |             json_line = json.loads(line)
81 |             data.append(json_line)
82 | 
83 |     # Remove duplicated dicts
84 |     unique_tuples = {data_scraping.make_hashable(d) for d in data}
85 |     all_questions = [data_scraping.unhashable_to_dict(t) for t in unique_tuples]
86 | 
87 |     if n_days is not None:
88 |         date_limit = datetime.datetime.now() - datetime.timedelta(days=n_days)
89 |         all_questions = [
90 |             q
91 |             for q in all_questions
92 |             if datetime.datetime.fromisoformat(q["question"]["created_at"][:-1])
93 |             >= date_limit
94 |         ]
95 | 
96 |     logger.info(f"Number of cset questions fetched: {len(all_questions)}")
97 | 
98 |     for i in range(len(all_questions)):
99 |         try:
100 |             all_questions[i]["community_prediction"] = all_questions[i].pop(
101 |                 "trend_graph_probabilities"
102 |             )
103 |         except BaseException:
104 |             all_questions[i]["community_prediction"] = []
105 |         all_questions[i]["url"] = "https://www.infer-pub.com/questions/" + str(
106 |             all_questions[i]["question"]["id"]
107 |         )
108 |         all_questions[i]["title"] = all_questions[i]["question"]["name"]
109 |         all_questions[i]["close_time"] = all_questions[i]["question"].pop("closed_at")
110 |         all_questions[i]["created_time"] = all_questions[i]["question"].pop(
111 |             "created_at"
112 |         )
113 |         all_questions[i]["background"] = all_questions[i]["question"]["description"]
114 |         all_questions[i]["data_source"] = "cset"
115 | 
116 |         if all_questions[i]["question"]["state"] != "resolved":
117 |             all_questions[i]["resolution"] = "Not resolved."
118 | all_questions[i]["is_resolved"] = False 119 | else: 120 | all_questions[i]["resolution"] = all_questions[i]["question"]["answers"][0][ 121 | "probability" 122 | ] 123 | all_questions[i]["is_resolved"] = True 124 | 125 | all_questions[i]["question_type"] = "binary" 126 | answer_set = set( 127 | [answer["name"] for answer in all_questions[i]["question"]["answers"]] 128 | ) 129 | if answer_set == {"Yes", "No"} or answer_set == {"Yes"}: 130 | all_questions[i]["question_type"] = "binary" 131 | else: 132 | all_questions[i]["question_type"] = "multiple_choice" 133 | 134 | logger.info("Uploading to s3...") 135 | question_types = ["binary", "multiple_choice"] 136 | data_scraping.upload_scraped_data(all_questions, "cset", question_types, n_days) 137 | 138 | # Delete the file after script completion 139 | if os.path.exists(FILE_PATH): 140 | os.remove(FILE_PATH) 141 | logger.info(f"Deleted the file: {FILE_PATH}") 142 | else: 143 | logger.info(f"The file {FILE_PATH} does not exist") 144 | 145 | 146 | if __name__ == "__main__": 147 | parser = argparse.ArgumentParser(description="Fetch cset data.") 148 | parser.add_argument( 149 | "--n_days", 150 | type=int, 151 | help="Fetch markets created in the last N days", 152 | default=None, 153 | ) 154 | args = parser.parse_args() 155 | main(args.n_days) 156 | -------------------------------------------------------------------------------- /scripts/data_scraping/gjopen.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import argparse 3 | import datetime 4 | import json 5 | import logging 6 | import os 7 | import random 8 | import time 9 | 10 | # Related third-party imports 11 | import jsonlines 12 | from selenium.webdriver.common.by import By 13 | 14 | # Local application/library-specific imports 15 | import data_scraping 16 | from config.keys import keys 17 | 18 | logger = logging.getLogger(__name__) 19 | MAX_CONSECUTIVE_NOT_FOUND = 1000 20 | # Writing to file for debugging purposes. It will be deleted once the script is done. 21 | FILE_PATH = "gjopen_dump.jsonl" 22 | # Use your own executable_path (download from https://chromedriver.chromium.org/). 23 | CHROMEDRIVER_PATH = "/Users/apple/Downloads/chromedriver-mac-x64/chromedriver" 24 | GJOPEN_EMAIL = keys["EMAIL"] 25 | GJOPEN_PASSWORD = keys["GJOPEN_CSET_PASSWORD"] 26 | 27 | 28 | def main(n_days): 29 | """ 30 | Scrape, process, and upload question data from gjopen (https://www.gjopen.com/) 31 | 32 | Args: 33 | n_days (int): Number of days to look back for questions. 34 | 35 | Returns: 36 | list: A list of processed question data. 
37 | """ 38 | driver = data_scraping.initialize_and_login( 39 | signin_page="https://www.gjopen.com/users/sign_in", 40 | email=GJOPEN_EMAIL, 41 | password=GJOPEN_PASSWORD, 42 | executable_path=CHROMEDRIVER_PATH, 43 | ) 44 | 45 | question_counter = 0 46 | consecutive_not_found_or_skipped = 0 47 | while True: 48 | question_counter += 1 49 | url = f"https://www.gjopen.com/questions/{question_counter}" 50 | 51 | try: 52 | driver.get(url) 53 | trend_graph_element = driver.find_element( 54 | By.CSS_SELECTOR, 55 | "div[data-react-class='FOF.Forecast.QuestionTrendGraph']", 56 | ) 57 | props = json.loads(trend_graph_element.get_attribute("data-react-props")) 58 | props["extracted_articles_urls"] = data_scraping.get_source_links( 59 | driver, url 60 | ) 61 | 62 | with jsonlines.open(FILE_PATH, mode="a") as writer: 63 | writer.write(props) 64 | consecutive_not_found_or_skipped = 0 65 | except BaseException: 66 | if data_scraping.question_not_found(driver): 67 | logger.info(f"Question {question_counter} not found") 68 | else: 69 | logger.info(f"Skipping question {question_counter}") 70 | consecutive_not_found_or_skipped += 1 71 | if consecutive_not_found_or_skipped > MAX_CONSECUTIVE_NOT_FOUND: 72 | logger.info("Reached maximum consecutive not found.") 73 | break 74 | 75 | time.sleep(random.uniform(0, 2)) # random delay between requests 76 | 77 | data = [] 78 | with open(FILE_PATH, "r") as file: 79 | for line in file: 80 | json_line = json.loads(line) 81 | data.append(json_line) 82 | 83 | # Remove duplicated dicts 84 | unique_tuples = {data_scraping.make_hashable(d) for d in data} 85 | all_questions = [data_scraping.unhashable_to_dict(t) for t in unique_tuples] 86 | 87 | if n_days is not None: 88 | date_limit = datetime.datetime.now() - datetime.timedelta(days=n_days) 89 | all_questions = [ 90 | q 91 | for q in all_questions 92 | if datetime.datetime.fromisoformat(q["question"]["created_at"][:-1]) 93 | >= date_limit 94 | ] 95 | 96 | logger.info(f"Number of gjopen questions fetched: {len(all_questions)}") 97 | 98 | for i in range(len(all_questions)): 99 | try: 100 | all_questions[i]["community_prediction"] = all_questions[i].pop( 101 | "trend_graph_probabilities" 102 | ) 103 | except BaseException: 104 | all_questions[i]["community_prediction"] = [] 105 | all_questions[i]["url"] = "https://www.gjopen.com/questions/" + str( 106 | all_questions[i]["question"]["id"] 107 | ) 108 | all_questions[i]["title"] = all_questions[i]["question"]["name"] 109 | all_questions[i]["close_time"] = all_questions[i]["question"].pop("closed_at") 110 | all_questions[i]["created_time"] = all_questions[i]["question"].pop( 111 | "created_at" 112 | ) 113 | all_questions[i]["background"] = all_questions[i]["question"]["description"] 114 | all_questions[i]["data_source"] = "gjopen" 115 | 116 | if all_questions[i]["question"]["state"] != "resolved": 117 | all_questions[i]["resolution"] = "Not resolved." 
118 | all_questions[i]["is_resolved"] = False 119 | else: 120 | all_questions[i]["resolution"] = all_questions[i]["question"]["answers"][0][ 121 | "probability" 122 | ] 123 | all_questions[i]["is_resolved"] = True 124 | 125 | all_questions[i]["question_type"] = "binary" 126 | answer_set = set( 127 | [answer["name"] for answer in all_questions[i]["question"]["answers"]] 128 | ) 129 | if answer_set == {"Yes", "No"} or answer_set == {"Yes"}: 130 | all_questions[i]["question_type"] = "binary" 131 | else: 132 | all_questions[i]["question_type"] = "multiple_choice" 133 | 134 | logger.info("Uploading to s3...") 135 | question_types = ["binary", "multiple_choice"] 136 | data_scraping.upload_scraped_data(all_questions, "gjopen", question_types, n_days) 137 | 138 | # Delete the file after script completion 139 | if os.path.exists(FILE_PATH): 140 | os.remove(FILE_PATH) 141 | logger.info(f"Deleted the file: {FILE_PATH}") 142 | else: 143 | logger.info(f"The file {FILE_PATH} does not exist") 144 | 145 | 146 | if __name__ == "__main__": 147 | parser = argparse.ArgumentParser(description="Fetch gjopen data.") 148 | parser.add_argument( 149 | "--n_days", 150 | type=int, 151 | help="Fetch markets created in the last N days", 152 | default=None, 153 | ) 154 | args = parser.parse_args() 155 | main(args.n_days) 156 | -------------------------------------------------------------------------------- /llm_forecasting/prompts/summarization.py: -------------------------------------------------------------------------------- 1 | SUMMARIZATION_PROMPT_0 = ( 2 | """Summarize the article below, ensuring to include details pertinent to the subsequent question. 3 | 4 | Question: {question} 5 | Question Background: {background} 6 | 7 | Article: 8 | --- 9 | {article} 10 | ---""", 11 | ("QUESTION", "BACKGROUND"), 12 | ) 13 | 14 | SUMMARIZATION_PROMPT_1 = ( 15 | """I will present a forecasting question and a related article. 16 | 17 | Question: {question} 18 | Question Background: {background} 19 | 20 | Article: 21 | --- 22 | {article} 23 | --- 24 | 25 | A forecaster prefers a list of bullet points containing facts, observations, details, analysis, etc., over reading a full article. 26 | 27 | Your task is to distill the article as a list of bullet points that would help a forecaster in his deliberation.""", 28 | ("QUESTION", "BACKGROUND"), 29 | ) 30 | 31 | SUMMARIZATION_PROMPT_2 = ( 32 | """I want to make the following article shorter (condense it to no more than 500 words). 33 | 34 | Article: 35 | --- 36 | {article} 37 | --- 38 | 39 | When doing this task for me, please do not remove any details that would be helpful for making considerations about the following forecasting question. 40 | 41 | Forecasting Question: {question} 42 | Question Background: {background}""", 43 | ("QUESTION", "BACKGROUND"), 44 | ) 45 | 46 | SUMMARIZATION_PROMPT_3 = ( 47 | """I will present a forecasting question and a related article. 48 | 49 | Forecasting Question: {question} 50 | Question Background: {background} 51 | 52 | Article: 53 | --- 54 | {article} 55 | --- 56 | 57 | Use the article to write a list of bullet points that help a forecaster in their deliberation. 58 | 59 | Guidelines: 60 | - Ensure each bullet point contains specific, detailed information. 61 | - Avoid vague statements; instead, focus on summarizing key observations, data, predictions, or analysis presented in the article. 
62 | - Also, extract points that directly or indirectly contribute to a better understanding or prediction of the specified question.""", 63 | ("QUESTION", "BACKGROUND"), 64 | ) 65 | 66 | 67 | SUMMARIZATION_PROMPT_4 = ( 68 | """Summarize the below article. 69 | 70 | Article: 71 | --- 72 | {article} 73 | ---""", 74 | ("",), 75 | ) 76 | 77 | 78 | SUMMARIZATION_PROMPT_5 = ( 79 | """I will present a forecasting question and a related article. 80 | 81 | Question: {question} 82 | Question Background: {background} 83 | 84 | Article: 85 | --- 86 | {article} 87 | --- 88 | 89 | I want to shorten the following article (condense it to no more than 500 words). When doing this task for me, please do not remove any details that would be helpful for making considerations about the forecasting question.""", 90 | ("QUESTION", "BACKGROUND"), 91 | ) 92 | 93 | 94 | SUMMARIZATION_PROMPT_6 = ( 95 | """Create a summary of the following article that assists in making a prediction for the following question. 96 | 97 | Forecasting Question: {question} 98 | Question Background: {background} 99 | 100 | Article: 101 | --- 102 | {article} 103 | --- 104 | 105 | Guidelines for Summary: 106 | - Include bullet points that extract key facts, observations, and analyses directly relevant to the forecasting question. 107 | - Then include analysis that connects the article's content to the forecasting question. 108 | - Strive for a balance between brevity and completeness, aiming for a summary that is informative yet efficient for a forecaster's analysis.""", 109 | ("QUESTION", "BACKGROUND"), 110 | ) 111 | 112 | 113 | SUMMARIZATION_PROMPT_7 = ( 114 | """Create a summary of the following article that assists in making a prediction for the following question. 115 | 116 | Prediction Question: {question} 117 | Question Background: {background} 118 | 119 | Article: 120 | --- 121 | {article} 122 | --- 123 | 124 | Guidelines for Summary: 125 | - Include bullet points that extract key facts, observations, and analyses directly relevant to the forecasting question. 126 | - Where applicable, highlight direct or indirect connections between the article's content and the forecasting question. 127 | - Strive for a balance between brevity and completeness, aiming for a summary that is informative yet efficient for a forecaster's analysis.""", 128 | ("QUESTION", "BACKGROUND"), 129 | ) 130 | 131 | 132 | SUMMARIZATION_PROMPT_8 = ( 133 | """I want to make the following article shorter (condense it to no more than 100 words). 134 | 135 | Article: 136 | --- 137 | {article} 138 | --- 139 | 140 | When doing this task for me, please do not remove any details that would be helpful for making considerations about the following forecasting question. 141 | 142 | Forecasting Question: {question} 143 | Question Background: {background}""", 144 | ("QUESTION", "BACKGROUND"), 145 | ) 146 | 147 | SUMMARIZATION_PROMPT_9 = ( 148 | """I want to make the following article shorter (condense it to no more than 100 words). 149 | 150 | Article: 151 | --- 152 | {article} 153 | --- 154 | 155 | When doing this task for me, please do not remove any details that would be helpful for making considerations about the following forecasting question. 156 | 157 | Forecasting Question: {question} 158 | Question Background: {background}""", 159 | ("QUESTION", "BACKGROUND"), 160 | ) 161 | 162 | SUMMARIZATION_PROMPT_10 = ( 163 | """I will present a forecasting question and a related article. 
164 | 165 | Forecasting Question: {question} 166 | Question Background: {background} 167 | 168 | Article: 169 | --- 170 | {article} 171 | --- 172 | 173 | Use the article to write a list of bullet points that help a forecaster in their deliberation. 174 | 175 | Guidelines: 176 | - Ensure each bullet point contains specific, detailed information. 177 | - Avoid vague statements; instead, focus on summarizing key observations, data, predictions, or analysis presented in the article. 178 | - Your list should never exceed 5 bullet points.""", 179 | ("QUESTION", "BACKGROUND"), 180 | ) 181 | 182 | 183 | SUMMARIZATION_PROMPT_11 = ( 184 | """I will present a forecasting question and a related article. 185 | 186 | Question: {question} 187 | Question Background: {background} 188 | 189 | Article: 190 | --- 191 | {article} 192 | --- 193 | 194 | A forecaster prefers a list of bullet points containing facts, observations, details, analysis, etc., over reading a full article. 195 | 196 | Your task is to distill the article as a list of bullet points that would help a forecaster in his deliberation. Ensure, that you make your list as concise and short as possible without removing critical information.""", 197 | ("QUESTION", "BACKGROUND"), 198 | ) 199 | -------------------------------------------------------------------------------- /llm_forecasting/data_scraping.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import requests 3 | import logging 4 | 5 | from selenium import webdriver 6 | from selenium.webdriver.chrome.options import Options 7 | from selenium.webdriver.chrome.service import Service 8 | from selenium.webdriver.common.by import By 9 | from selenium.webdriver.support import expected_conditions as EC 10 | from selenium.webdriver.support.ui import WebDriverWait 11 | 12 | from config.constants import S3, S3_BUCKET_NAME 13 | from utils import db_utils 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def upload_scraped_data(data, source, question_types, n_days_or_not=None): 19 | """ 20 | Upload data (scraped by the script) by type to S3, partitioned by 21 | the specified source and date range if specified. 22 | 23 | Args: 24 | data (list): List of question data to process. 25 | source (str): The source identifier for categorizing uploaded data. 26 | question_types (list): List of question types to filter and upload. 27 | n_days_or_not (int or None): Number of days to look back for questions, 28 | or None for all available data. 29 | """ 30 | today_date = datetime.datetime.now().strftime("%Y_%m_%d") 31 | 32 | for q_type in question_types: 33 | questions = [q for q in data if q["question_type"] == q_type] 34 | q_type = q_type.lower() 35 | if n_days_or_not: 36 | file_name = f"n_days_updated_{q_type}_questions_{today_date}.pickle" 37 | s3_path = f"{source}/n_days_updated_{q_type}_questions/{file_name}" 38 | else: 39 | file_name = f"updated_{q_type}_questions_{today_date}.pickle" 40 | s3_path = f"{source}/updated_{q_type}_questions/{file_name}" 41 | 42 | db_utils.upload_data_structure_to_s3(S3, questions, S3_BUCKET_NAME, s3_path) 43 | 44 | 45 | def fetch_question_description(headers, question_id): 46 | """ 47 | Fetch and return the description of a specific question from Metaculus. 48 | 49 | The function queries the Metaculus API for a given question ID and 50 | extracts the question's description from the response. 51 | 52 | Parameters: 53 | - headers (dict): Headers to include in the API request. 
54 | - question_id (int): The ID of the question to fetch. 55 | 56 | Returns: 57 | - str: Description of the question, or an empty string if not found. 58 | """ 59 | endpoint = f"https://www.metaculus.com/api2/questions/{question_id}" 60 | response = requests.get(endpoint, headers=headers) 61 | return response.json().get("description", "") 62 | 63 | 64 | # Use your own executable_path (download from https://chromedriver.chromium.org/). 65 | def initialize_and_login( 66 | signin_page, 67 | email, 68 | password, 69 | executable_path="/Users/apple/Downloads/chromedriver-mac-x64/chromedriver", 70 | ): 71 | """ 72 | Initialize the WebDriver and log into the website. 73 | 74 | Returns: 75 | selenium.webdriver.Chrome: An instance of the Chrome WebDriver after logging in. 76 | """ 77 | # Webdriver options 78 | chrome_options = Options() 79 | chrome_options.add_argument( 80 | "user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 " 81 | "(KHTML, like Gecko) Chrome/88.0.4324.150 Safari/537.36" 82 | ) 83 | service = Service(executable_path=executable_path) 84 | driver = webdriver.Chrome(service=service, options=chrome_options) 85 | 86 | # log in 87 | driver.get(signin_page) 88 | driver.find_element(By.ID, "user_email").send_keys(email) 89 | driver.find_element(By.ID, "user_password").send_keys(password) 90 | driver.find_element(By.NAME, "commit").click() 91 | 92 | return driver 93 | 94 | 95 | def question_not_found(driver): 96 | """ 97 | Check if a specific question is not found on the website. 98 | 99 | Args: 100 | driver (webdriver.Chrome): The Selenium Chrome WebDriver instance. 101 | 102 | Returns: 103 | bool: True if the question is not found, False otherwise. 104 | """ 105 | try: 106 | return ( 107 | driver.find_element(By.CLASS_NAME, "flash-message").text 108 | == "Could not find that question." 109 | ) 110 | except BaseException: 111 | return False 112 | 113 | 114 | def get_source_links(driver, url): 115 | """ 116 | Retrieve source links from a given question page. 117 | 118 | Args: 119 | driver (webdriver.Chrome): The Selenium Chrome WebDriver instance. 120 | url (str): The URL of the question page. 121 | 122 | Returns: 123 | list: A list of retrieved source links. 124 | """ 125 | source_links = set() 126 | try: 127 | driver.get(url) # Make sure to navigate to the page first 128 | driver.find_element(By.XPATH, '//a[text() = "Source Links"]').click() 129 | WebDriverWait(driver, 2).until( 130 | EC.visibility_of_element_located((By.ID, "links-table")) 131 | ) 132 | rows = driver.find_elements( 133 | By.XPATH, '//table[contains(@id, "links-table")]/tbody/tr' 134 | ) 135 | for entry in rows: 136 | try: 137 | url_elem = entry.find_element(By.TAG_NAME, "a") 138 | source_links.add(url_elem.get_attribute("href")) 139 | except BaseException: 140 | # If no tag is found in this table data, it will skip to 141 | # the next 142 | continue 143 | except: # Catch any other exception 144 | logger.info("Failed to get links.") 145 | return list(source_links) 146 | 147 | 148 | def make_hashable(e): 149 | """ 150 | Convert elements, including dictionaries, into a hashable form. 151 | 152 | Args: 153 | e (Any): The element to be converted. 154 | 155 | Returns: 156 | tuple: A tuple representing the hashable form of the element. 
157 | """ 158 | if isinstance(e, dict): 159 | return tuple((key, make_hashable(val)) for key, val in sorted(e.items())) 160 | elif isinstance(e, list): 161 | return tuple(make_hashable(x) for x in e) 162 | else: 163 | return e 164 | 165 | 166 | def unhashable_to_dict(t): 167 | """ 168 | Convert tuples back into their original dictionary or list forms. 169 | 170 | Args: 171 | t (tuple): The tuple to be converted. 172 | 173 | Returns: 174 | Any: The original dictionary or list form of the tuple. 175 | """ 176 | if isinstance(t, tuple): 177 | try: 178 | return dict((k, unhashable_to_dict(v)) for k, v in t) 179 | except ValueError: 180 | return [unhashable_to_dict(x) for x in t] 181 | else: 182 | return t 183 | -------------------------------------------------------------------------------- /scripts/training_data/training_point_generation.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import argparse 3 | import asyncio 4 | import logging 5 | import random 6 | 7 | # Local application/library specific imports 8 | from utils import data_utils 9 | from prompts.prompts import PROMPT_DICT 10 | import evaluation 11 | 12 | # Set up logging 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | TRAINING_RETRIEVAL_CONFIG = { 17 | # Number of search query keywords per question. 18 | "NUM_SEARCH_QUERY_KEYWORDS": [4, 5, 6], 19 | "MAX_WORDS_NEWSCATCHER": [5], 20 | "MAX_WORDS_GNEWS": [7, 8, 9], 21 | "SEARCH_QUERY_MODEL_NAME": ["gpt-4-1106-preview"], 22 | "SEARCH_QUERY_TEMPERATURE": [0.0], 23 | "SEARCH_QUERY_PROMPT_TEMPLATES": [ 24 | [PROMPT_DICT["search_query"]["0"]], 25 | [PROMPT_DICT["search_query"]["1"]], 26 | [PROMPT_DICT["search_query"]["0"], PROMPT_DICT["search_query"]["1"]], 27 | ], 28 | "NUM_ARTICLES_PER_QUERY": [5, 7, 9], 29 | "SUMMARIZATION_MODEL_NAME": ["gpt-3.5-turbo-1106"], 30 | "SUMMARIZATION_TEMPERATURE": [0.2], 31 | "SUMMARIZATION_PROMPT_TEMPLATE": [ 32 | PROMPT_DICT["summarization"]["0"], 33 | PROMPT_DICT["summarization"]["1"], 34 | PROMPT_DICT["summarization"]["9"], 35 | ], 36 | "RANKING_MODEL_NAME": ["gpt-3.5-turbo-1106"], 37 | "RANKING_TEMPERATURE": [0.0], 38 | "RANKING_PROMPT_TEMPLATE": [PROMPT_DICT["ranking"]["0"]], 39 | "RANKING_RELEVANCE_THRESHOLD": [4], 40 | "SORT_BY": ["relevancy"], 41 | "RANKING_METHOD": ["llm-rating"], 42 | "RANKING_METHOD_LLM": ["title_250_tokens"], 43 | "NUM_SUMMARIES_THRESHOLD": [15, 20, 25], 44 | } 45 | 46 | TRAINING_REASONING_CONFIG = { 47 | "BASE_REASONING_MODEL_NAMES": [["claude-2.1", "gpt-4-1106-preview"]], 48 | "BASE_REASONING_TEMPERATURE": [0.2, 0.3, 0.4], 49 | "BASE_REASONING_PROMPT_TEMPLATES": [[], []], 50 | "AGGREGATION_METHOD": ["meta"], 51 | "AGGREGATION_PROMPT_TEMPLATE": [PROMPT_DICT["meta_reasoning"]["0"]], 52 | "AGGREGATION_TEMPERATURE": [0.2, 0.3, 0.4], 53 | "AGGREGATION_MODEL_NAME": ["claude-2.1", "gpt-4-1106-preview"], 54 | "AGGREGATION_WEIGTHTS": None, 55 | } 56 | 57 | 58 | def sample_retrieval_hyperparms(ir_config): 59 | """ 60 | Sample hyperparameters for information retrieval configuration. 61 | 62 | Args: 63 | ir_config (dict): A dictionary containing different hyperparameters for information retrieval. 64 | 65 | Returns: 66 | dict: A dictionary with the same keys as ir_config, but each key has a single randomly 67 | sampled hyperparameter. 
68 |     """
69 |     sampled_ir_config = ir_config.copy()
70 |     for key, hyperparams in ir_config.items():
71 |         sampled_ir_config[key] = random.choice(hyperparams)
72 |     return sampled_ir_config
73 | 
74 | 
75 | def sample_reasoning_hyperparams(reasoning_config, prompts_to_sample, prompt_weights):
76 |     sampled_reasoning_config = reasoning_config.copy()
77 |     for key, hyperparams in reasoning_config.items():
78 |         if key != "BASE_REASONING_PROMPT_TEMPLATES":
79 |             sampled_reasoning_config[key] = random.choice(hyperparams)
80 |         else:
81 |             # For BASE_REASONING_PROMPT_TEMPLATES, sample 5 prompts for each model
82 |             models = sampled_reasoning_config["BASE_REASONING_MODEL_NAMES"]
83 |             logger.info(f"Sampling base reasoning prompts for models: {models}")
84 |             sampled_prompts = []
85 |             for model in models:
86 |                 model_prompts = random.choices(
87 |                     prompts_to_sample, weights=prompt_weights[model], k=5
88 |                 )
89 |                 sampled_prompts.append(model_prompts)
90 |             sampled_reasoning_config[key] = sampled_prompts
91 |     return sampled_reasoning_config
92 | 
93 | 
94 | async def generate_training_points(
95 |     s3_path,
96 |     retrieval_index,
97 |     num_retrievals,
98 |     questions_after,
99 |     ir_config,
100 |     reasoning_config,
101 |     output_dir,
102 |     prompts_to_sample,
103 |     prompt_weights,
104 | ):
105 |     """
106 |     Asynchronously generates training data points.
107 | 
108 |     Args:
109 |         s3_path (str): The S3 path to retrieve initial data.
110 |         retrieval_index (int): An index specifying the location or category for retrieval.
111 |         num_retrievals (int): Number of data retrievals to perform.
112 |         questions_after (str): A filtering criterion for questions.
113 |         ir_config (dict): Configuration for information retrieval.
114 |         reasoning_config (dict): Configuration for reasoning processes.
115 |         output_dir (str): The directory where output files are stored.
116 |         prompts_to_sample (list), prompt_weights (dict): Candidate reasoning prompts and per-model sampling weights.
117 |     Description:
118 |         Retrieves training data, iterates through questions, evaluates them if necessary, processes them
119 |         based on given configurations, and saves the output.
120 | 
121 |     Returns:
122 |         None: The function is used for its side effect of processing and saving data.
123 | """ 124 | data_dict, raw_data = data_utils.get_data( 125 | s3_path, 126 | retrieval_index, 127 | num_retrievals, 128 | questions_after=questions_after, 129 | return_raw_question_data=True, 130 | ) 131 | for q_index, question in enumerate(data_dict["question_list"]): 132 | if not evaluation.to_eval(question, retrieval_index, output_dir): 133 | logger.info(f"Already processed question, {q_index}: {question}") 134 | continue 135 | 136 | logger.info(f"Starting question, {q_index}: {question}") 137 | try: 138 | ir_config_samp = sample_retrieval_hyperparms(ir_config) 139 | reasoning_config_samp = sample_reasoning_hyperparams( 140 | reasoning_config, prompts_to_sample, prompt_weights 141 | ) 142 | output, _, ranked_articles = await evaluation.retrieve_and_forecast( 143 | data_utils.format_single_question(data_dict, q_index), 144 | raw_data[q_index], 145 | ir_config=ir_config_samp, 146 | reason_config=reasoning_config_samp, 147 | return_articles=True, 148 | calculate_alignment=True, 149 | ) 150 | output["ranked_articles"] = [ 151 | (art.summary, art.relevance_rating) for art in ranked_articles 152 | ] 153 | evaluation.save_results(output, question, retrieval_index, output_dir) 154 | except Exception as e: 155 | logger.error(f"Error processing question {q_index}: {e}") 156 | 157 | return None 158 | 159 | 160 | async def main(): 161 | """ 162 | Start the training point generation job. 163 | """ 164 | parser = argparse.ArgumentParser() 165 | parser.add_argument( 166 | "--s3_path", 167 | type=str, 168 | default="training_sets/forecasting_binary_training_set.pickle", 169 | help="S3 dataset path to run (use 'default' for metaculus training set).", 170 | ) 171 | parser.add_argument( 172 | "--retrieval_index", 173 | type=int, 174 | default=1, 175 | help="Index for retrieval (1 to |num_retrievals|).", 176 | ) 177 | parser.add_argument( 178 | "--num_retrievals", 179 | type=int, 180 | default=5, 181 | help="Number of ideal retrieval dates.", 182 | ) 183 | parser.add_argument( 184 | "--questions_after", 185 | type=str, 186 | default="2015", 187 | help="The lower-bound year for questions to evaluate.", 188 | ) 189 | args = parser.parse_args() 190 | 191 | await generate_training_points( 192 | args.s3_path, 193 | args.retrieval_index, 194 | args.num_retrievals, 195 | args.questions_after, 196 | ir_config=TRAINING_RETRIEVAL_CONFIG, 197 | reasoning_config=TRAINING_REASONING_CONFIG, 198 | output_dir="data_point_generation", 199 | ) 200 | 201 | 202 | if __name__ == "__main__": 203 | asyncio.run(main()) 204 | -------------------------------------------------------------------------------- /llm_forecasting/config/constants.py: -------------------------------------------------------------------------------- 1 | # Local application/library specific imports 2 | from config.keys import keys 3 | from prompts.prompts import PROMPT_DICT 4 | from utils import db_utils 5 | 6 | OAI_SOURCE = "OAI" 7 | ANTHROPIC_SOURCE = "ANTHROPIC" 8 | TOGETHER_AI_SOURCE = "TOGETHER" 9 | GOOGLE_SOURCE = "GOOGLE" 10 | HUGGINGFACE_SOURCE = "HUGGINGFACE" 11 | 12 | DEFAULT_RETRIEVAL_CONFIG = { 13 | "NUM_SEARCH_QUERY_KEYWORDS": 3, 14 | "MAX_WORDS_NEWSCATCHER": 5, 15 | "MAX_WORDS_GNEWS": 8, 16 | "SEARCH_QUERY_MODEL_NAME": "gpt-4-1106-preview", 17 | "SEARCH_QUERY_TEMPERATURE": 0.0, 18 | "SEARCH_QUERY_PROMPT_TEMPLATES": [ 19 | PROMPT_DICT["search_query"]["0"], 20 | PROMPT_DICT["search_query"]["1"], 21 | ], 22 | "NUM_ARTICLES_PER_QUERY": 5, 23 | "SUMMARIZATION_MODEL_NAME": "gpt-3.5-turbo-1106", 24 | "SUMMARIZATION_TEMPERATURE": 0.2, 25 | 
"SUMMARIZATION_PROMPT_TEMPLATE": PROMPT_DICT["summarization"]["9"], 26 | "NUM_SUMMARIES_THRESHOLD": 10, 27 | "PRE_FILTER_WITH_EMBEDDING": True, 28 | "PRE_FILTER_WITH_EMBEDDING_THRESHOLD": 0.32, 29 | "RANKING_MODEL_NAME": "gpt-3.5-turbo-1106", 30 | "RANKING_TEMPERATURE": 0.0, 31 | "RANKING_PROMPT_TEMPLATE": PROMPT_DICT["ranking"]["0"], 32 | "RANKING_RELEVANCE_THRESHOLD": 4, 33 | "RANKING_COSINE_SIMILARITY_THRESHOLD": 0.5, 34 | "SORT_BY": "date", 35 | "RANKING_METHOD": "llm-rating", 36 | "RANKING_METHOD_LLM": "title_250_tokens", 37 | "NUM_SUMMARIES_THRESHOLD": 20, 38 | "EXTRACT_BACKGROUND_URLS": True, 39 | } 40 | 41 | DEFAULT_REASONING_CONFIG = { 42 | "BASE_REASONING_MODEL_NAMES": ["gpt-4-1106-preview"], 43 | "BASE_REASONING_TEMPERATURE": 1.0, 44 | "BASE_REASONING_PROMPT_TEMPLATES": [ 45 | [ 46 | PROMPT_DICT["binary"]["scratch_pad"]["1"], 47 | PROMPT_DICT["binary"]["scratch_pad"]["2"], 48 | ], 49 | ], 50 | "ALIGNMENT_MODEL_NAME": "gpt-3.5-turbo-1106", 51 | "ALIGNMENT_TEMPERATURE": 0, 52 | "ALIGNMENT_PROMPT": PROMPT_DICT["alignment"]["0"], 53 | "AGGREGATION_METHOD": "meta", 54 | "AGGREGATION_PROMPT_TEMPLATE": PROMPT_DICT["meta_reasoning"]["0"], 55 | "AGGREGATION_TEMPERATURE": 0.2, 56 | "AGGREGATION_MODEL_NAME": "gpt-4", 57 | "AGGREGATION_WEIGTHTS": None, 58 | } 59 | 60 | CHARS_PER_TOKEN = 4 61 | 62 | S3 = db_utils.initialize_s3_client(keys["AWS_ACCESS_KEY"], keys["AWS_SECRET_KEY"]) 63 | S3_BUCKET_NAME = "my-forecasting-bucket" 64 | 65 | MODEL_TOKEN_LIMITS = { 66 | "claude-2.1": 200000, 67 | "claude-2": 100000, 68 | "claude-3-opus-20240229": 200000, 69 | "claude-3-sonnet-20240229": 200000, 70 | "gpt-4": 8000, 71 | "gpt-3.5-turbo-1106": 16000, 72 | "gpt-3.5-turbo-16k": 16000, 73 | "gpt-3.5-turbo": 8000, 74 | "gpt-4-1106-preview": 128000, 75 | "gemini-pro": 30720, 76 | "togethercomputer/llama-2-7b-chat": 4096, 77 | "togethercomputer/llama-2-13b-chat": 4096, 78 | "togethercomputer/llama-2-70b-chat": 4096, 79 | "togethercomputer/StripedHyena-Hessian-7B": 32768, 80 | "togethercomputer/LLaMA-2-7B-32K": 32768, 81 | "mistralai/Mistral-7B-Instruct-v0.2": 32768, 82 | "mistralai/Mixtral-8x7B-Instruct-v0.1": 32768, 83 | "zero-one-ai/Yi-34B-Chat": 4096, 84 | "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768, 85 | "NousResearch/Nous-Hermes-2-Yi-34B": 32768, 86 | } 87 | 88 | MODEL_NAME_TO_SOURCE = { 89 | "claude-2.1": ANTHROPIC_SOURCE, 90 | "claude-2": ANTHROPIC_SOURCE, 91 | "claude-3-opus-20240229": ANTHROPIC_SOURCE, 92 | "claude-3-sonnet-20240229": ANTHROPIC_SOURCE, 93 | "gpt-4": OAI_SOURCE, 94 | "gpt-3.5-turbo-1106": OAI_SOURCE, 95 | "gpt-3.5-turbo-16k": OAI_SOURCE, 96 | "gpt-3.5-turbo": OAI_SOURCE, 97 | "gpt-4-1106-preview": OAI_SOURCE, 98 | "gemini-pro": GOOGLE_SOURCE, 99 | "togethercomputer/llama-2-7b-chat": TOGETHER_AI_SOURCE, 100 | "togethercomputer/llama-2-13b-chat": TOGETHER_AI_SOURCE, 101 | "togethercomputer/llama-2-70b-chat": TOGETHER_AI_SOURCE, 102 | "togethercomputer/LLaMA-2-7B-32K": TOGETHER_AI_SOURCE, 103 | "togethercomputer/StripedHyena-Hessian-7B": TOGETHER_AI_SOURCE, 104 | "mistralai/Mistral-7B-Instruct-v0.2": TOGETHER_AI_SOURCE, 105 | "mistralai/Mixtral-8x7B-Instruct-v0.1": TOGETHER_AI_SOURCE, 106 | "zero-one-ai/Yi-34B-Chat": TOGETHER_AI_SOURCE, 107 | "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": TOGETHER_AI_SOURCE, 108 | "NousResearch/Nous-Hermes-2-Yi-34B": TOGETHER_AI_SOURCE, 109 | } 110 | 111 | ANTHROPIC_RATE_LIMIT = 5 112 | 113 | IRRETRIEVABLE_SITES = [ 114 | "wsj.com", 115 | "english.alarabiya.net", 116 | "consilium.europa.eu", 117 | "abc.net.au", 118 | "thehill.com", 119 | 
"democracynow.org", 120 | "fifa.com", 121 | "si.com", 122 | "aa.com.tr", 123 | "thestreet.com", 124 | "newsweek.com", 125 | "spokesman.com", 126 | "aninews.in", 127 | "commonslibrary.parliament.uk", 128 | "cybernews.com", 129 | "lineups.com", 130 | "expressnews.com", 131 | "news-herald.com", 132 | "c-span.org/video", 133 | "investors.com", 134 | "finance.yahoo.com", # This site has a “read more” button. 135 | "metaculus.com", # newspaper4k cannot parse metaculus pages well 136 | "houstonchronicle.com", 137 | "unrwa.org", 138 | "njspotlightnews.org", 139 | "crisisgroup.org", 140 | "vanguardngr.com", # protected by Cloudflare 141 | "ahram.org.eg", # protected by Cloudflare 142 | "reuters.com", # blocked by Javascript and CAPTCHA 143 | "carnegieendowment.org", 144 | "casino.org", 145 | "legalsportsreport.com", 146 | "thehockeynews.com", 147 | "yna.co.kr", 148 | "carrefour.com", 149 | "carnegieeurope.eu", 150 | "arabianbusiness.com", 151 | "inc.com", 152 | "joburg.org.za", 153 | "timesofindia.indiatimes.com", 154 | "seekingalpha.com", 155 | "producer.com", # protected by Cloudflare 156 | "oecd.org", 157 | "almayadeen.net", # protected by Cloudflare 158 | "manifold.markets", # prevent data contamination 159 | "goodjudgment.com", # prevent data contamination 160 | "infer-pub.com", # prevent data contamination 161 | "www.gjopen.com", # prevent data contamination 162 | "polymarket.com", # prevent data contamination 163 | "betting.betfair.com", # protected by Cloudflare 164 | "news.com.au", # blocks crawler 165 | "predictit.org", # prevent data contamination 166 | "atozsports.com", 167 | "barrons.com", 168 | "forex.com", 169 | "www.cnbc.com/quotes", # stock market data: prevent data contamination 170 | "montrealgazette.com", 171 | "bangkokpost.com", 172 | "editorandpublisher.com", 173 | "realcleardefense.com", 174 | "axios.com", 175 | "mensjournal.com", 176 | "warriormaven.com", 177 | "tapinto.net", 178 | "indianexpress.com", 179 | "science.org", 180 | "businessdesk.co.nz", 181 | "mmanews.com", 182 | "jdpower.com", 183 | "hrexchangenetwork.com", 184 | "arabnews.com", 185 | "nationalpost.com", 186 | "bizjournals.com", 187 | "thejakartapost.com", 188 | ] 189 | 190 | QUESTION_CATEGORIES = [ 191 | "Science & Tech", 192 | "Healthcare & Biology", 193 | "Economics & Business", 194 | "Environment & Energy", 195 | "Politics & Governance", 196 | "Education & Research", 197 | "Arts & Recreation", 198 | "Security & Defense", 199 | "Social Sciences", 200 | "Sports", 201 | "Other", 202 | ] 203 | 204 | ( 205 | METACULUS_PLATFORM, 206 | CSET_PLATFORM, 207 | GJOPEN_PLATFORM, 208 | MANIFOLD_PLATFORM, 209 | POLYMARKET_PLATFORM, 210 | ) = ("metaculus", "cset", "gjopen", "manifold", "polymarket") 211 | 212 | ALL_PLATFORMS = [ 213 | METACULUS_PLATFORM, 214 | CSET_PLATFORM, 215 | GJOPEN_PLATFORM, 216 | MANIFOLD_PLATFORM, 217 | POLYMARKET_PLATFORM, 218 | ] 219 | 220 | END_WORDS_TO_PROBS_6 = { 221 | "No": 0.05, 222 | "Very Unlikely": 0.15, 223 | "Unlikely": 0.35, 224 | "Likely": 0.55, 225 | "Very Likely": 0.75, 226 | "Yes": 0.95, 227 | } 228 | 229 | END_WORDS_TO_PROBS_10 = { 230 | "No": 0.05, 231 | "Extremely Unlikely": 0.15, 232 | "Very Unlikely": 0.25, 233 | "Unlikely": 0.35, 234 | "Slightly Unlikely": 0.45, 235 | "Slightly Likely": 0.55, 236 | "Likely": 0.65, 237 | "Very Likely": 0.75, 238 | "Extremely Likely": 0.85, 239 | "Yes": 0.95, 240 | } 241 | 242 | TOKENS_TO_PROBS_DICT = { 243 | "six_options": END_WORDS_TO_PROBS_6, 244 | "ten_options": END_WORDS_TO_PROBS_10, 245 | } 
-------------------------------------------------------------------------------- /llm_forecasting/utils/time_utils.py: -------------------------------------------------------------------------------- 1 | # Related third-party imports 2 | from datetime import datetime, timedelta 3 | import pandas as pd 4 | import math 5 | 6 | 7 | def extract_date(datetime): 8 | """ 9 | Extract a date string from a datetime object or raw string. 10 | 11 | Args: 12 | datetime (datetime or str): A datetime object or string. 13 | If a string in the format 'YYYY-MM-DDTHH:MM:SSZ', the date part 14 | is extracted. 15 | 16 | Returns: 17 | str: A date string in the format 'YYYY-MM-DD'. 18 | """ 19 | if isinstance(datetime, str): 20 | if "T" in datetime: 21 | return datetime.split("T")[0] 22 | else: 23 | return datetime 24 | else: 25 | return str(datetime.date()) 26 | 27 | 28 | def convert_date_string_to_tuple(date_string): 29 | """ 30 | Convert a date string of the form 'year-month-day' to a tuple (year, month, day). 31 | 32 | Args: 33 | date_string (str): A string representing the date in 'year-month-day' format. 34 | 35 | Returns: 36 | tuple: A tuple containing the year, month, and day as integers. 37 | """ 38 | # Split the date string by '-' 39 | parts = date_string.split("-") 40 | # Check that the date string is in the correct format 41 | assert len(parts) == 3, "Date string must be in 'year-month-day' format." 42 | # Convert the parts to integers and return as a tuple 43 | return tuple(map(int, parts)) 44 | 45 | 46 | def safe_to_datetime(date_str): 47 | """ 48 | Safely convert a date string to a datetime object. 49 | 50 | Args: 51 | date_str (str): Date string to be converted. 52 | 53 | Returns: 54 | datetime: Converted datetime object or None if conversion fails. 55 | """ 56 | try: 57 | return pd.to_datetime(date_str.replace("Z", "+00:00")) 58 | except pd.errors.OutOfBoundsDatetime: 59 | return None 60 | 61 | 62 | def move_date_by_percentage(date_str1, date_str2, percentage): 63 | """ 64 | Compute a date that is a specified percentage between two dates. 65 | 66 | Returns the date before |date_str2| if the computed date is equal 67 | to |date_str2|. 68 | 69 | Args: 70 | date_str1 (str): Start date in "YYYY-MM-DD" format. 71 | date_str2 (str): End date in "YYYY-MM-DD" format. 72 | percentage (float): Percentage to move from start date towards end date. 73 | 74 | Returns: 75 | str: The new date in "YYYY-MM-DD" format. 76 | """ 77 | # Parse dates 78 | date1 = datetime.strptime(date_str1, "%Y-%m-%d") 79 | date2 = datetime.strptime(date_str2, "%Y-%m-%d") 80 | 81 | # Ensure date1 is earlier than date2 82 | if date1 > date2: 83 | date1, date2 = date2, date1 84 | 85 | # Calculate new date at the given percentage between them 86 | target_date = date1 + (date2 - date1) * (percentage / 100.0) 87 | 88 | # Check if target date is the same as date_str2, if so, subtract one day 89 | if target_date.strftime("%Y-%m-%d") == date_str2: 90 | target_date -= timedelta(days=1) 91 | 92 | return target_date.strftime("%Y-%m-%d") 93 | 94 | 95 | def adjust_date_by_days(date_str, days_to_adjust): 96 | """ 97 | Adjust a date string by a specified number of days. 98 | 99 | Args: 100 | date_str (str): A date string in the format 'YYYY-MM-DD'. 101 | days_to_adjust (int): The number of days to adjust the date by. Can be 102 | positive or negative. 103 | 104 | Returns: 105 | str: A new date string in the format 'YYYY-MM-DD' adjusted by the 106 | specified number of days. 
107 |     """
108 |     # Parse the date string into a datetime object
109 |     date_obj = datetime.strptime(date_str, "%Y-%m-%d")
110 | 
111 |     # Adjust the date by the given number of days
112 |     adjusted_date = date_obj + timedelta(days=days_to_adjust)
113 | 
114 |     # Convert the adjusted datetime object back into a string
115 |     new_date_str = adjusted_date.strftime("%Y-%m-%d")
116 | 
117 |     return new_date_str
118 | 
119 | 
120 | def convert_timestamp(timestamp):
121 |     """
122 |     Convert a numeric timestamp into a formatted date string.
123 | 
124 |     This function checks if the given timestamp is in milliseconds or seconds,
125 |     and converts it to a date string in the 'YYYY-MM-DD' format. It assumes that
126 |     timestamps are in milliseconds if they are greater than 1e10, which
127 |     typically corresponds to dates after the year 2001.
128 | 
129 |     Args:
130 |         timestamp (float or int): The timestamp to convert. Can be in seconds
131 |             or milliseconds.
132 | 
133 |     Returns:
134 |         str: The converted timestamp as a date string in 'YYYY-MM-DD' format.
135 |     """
136 |     # Identify if the timestamp is in milliseconds or seconds
137 |     timestamp = float(timestamp)
138 |     is_millisecond = timestamp > 1e10  # assuming data is from after 2001
139 |     if is_millisecond:
140 |         timestamp = timestamp / 1000  # Convert from milliseconds to seconds
141 | 
142 |     # Convert to formatted date string
143 |     return datetime.utcfromtimestamp(timestamp).strftime("%Y-%m-%d")
144 | 
145 | 
146 | def is_more_recent(first_date_str, second_date_str, or_equal_to=False):
147 |     """
148 |     Determine if |second_date_str| is more recent than |first_date_str|.
149 | 
150 |     Args:
151 |         first_date_str (str): A string representing the first date to compare against. Expected format: 'YYYY-MM-DD'.
152 |         second_date_str (str): A string representing the second date. Expected format: 'YYYY-MM-DD'.
153 | 
154 |     Returns:
155 |         bool: True if the second date is more recent than the first date, False otherwise.
156 |     """
157 |     first_date_obj = datetime.strptime(first_date_str, "%Y-%m-%d")
158 |     second_date_obj = datetime.strptime(second_date_str, "%Y-%m-%d")
159 |     if or_equal_to:
160 |         return second_date_obj >= first_date_obj
161 |     return second_date_obj > first_date_obj
162 | 
163 | 
164 | def is_less_than_N_days_apart(date1_str, date2_str, N=3):
165 |     """
166 |     Check if the difference between two dates is less than N days.
167 | 
168 |     :param date1_str: First date as a string in 'YYYY-MM-DD' format.
169 |     :param date2_str: Second date as a string in 'YYYY-MM-DD' format.
170 |     :param N: Number of days for comparison.
171 |     :return: True if the difference is less than N days, otherwise False.
172 |     """
173 |     date_obj1 = datetime.strptime(date1_str, "%Y-%m-%d")
174 |     date_obj2 = datetime.strptime(date2_str, "%Y-%m-%d")
175 |     return (date_obj2 - date_obj1) < timedelta(days=N)
176 | 
177 | 
178 | def find_pred_with_closest_date(date_str, date_pred_list):
179 |     """
180 |     Finds the tuple with the date closest to the given reference date from a list of (date, prediction) tuples.
181 | 
182 |     Parameters:
183 |     - date_str (str): Reference date in 'YYYY-MM-DD' format.
184 |     - date_pred_list (list of tuples): Each tuple contains a date string in 'YYYY-MM-DD' format and an associated prediction.
185 | 
186 |     Returns:
187 |     - tuple: The tuple with the closest date to date_str. Returns None if date_pred_list is empty.
188 | 
189 |     Raises:
190 |     - ValueError: If date_str or dates in date_pred_list are not in the correct format.
191 | """ 192 | # Convert the reference date string to a datetime object 193 | ref_date = datetime.strptime(date_str, "%Y-%m-%d") 194 | 195 | # Initialize variables to store the closest date and its difference 196 | closest_date = None 197 | min_diff = float("inf") 198 | 199 | # Iterate through the list of tuples 200 | for date_tuple in date_pred_list: 201 | # Convert the date string in the tuple to a datetime object 202 | current_date = datetime.strptime(date_tuple[0], "%Y-%m-%d") 203 | 204 | # Calculate the absolute difference in days 205 | diff = abs((current_date - ref_date).days) 206 | 207 | # Update the closest date and min_diff if this date is closer 208 | if diff < min_diff: 209 | min_diff = diff 210 | closest_date = date_tuple 211 | 212 | return closest_date 213 | 214 | 215 | def get_retrieval_date( 216 | retrieval_index, num_retrievals, date_begin, date_close, resolve_date 217 | ): 218 | """ 219 | Calculate a specific retrieval date within a given time range. 220 | 221 | The retrieval date is determined using an exponential distribution based on the total number of retrievals. 222 | If the calculated date is after the resolve date, None is returned. 223 | 224 | Args: 225 | retrieval_index (int): Index of the current retrieval (starting from 0). 226 | num_retrievals (int): Total number of retrievals planned within the date range. 227 | date_begin (str): Start date of the range in 'YYYY-MM-DD' format. 228 | date_close (str): End date of the range in 'YYYY-MM-DD' format. 229 | resolve_date (str): Date by which the retrieval should be resolved in 'YYYY-MM-DD' format. 230 | 231 | Returns: 232 | str or None: The calculated retrieval date in 'YYYY-MM-DD' format, or None if it falls after the resolve date. 233 | """ 234 | date_begin_obj = datetime.strptime(date_begin, "%Y-%m-%d") 235 | date_close_obj = datetime.strptime(date_close, "%Y-%m-%d") 236 | resolve_date_obj = datetime.strptime(resolve_date, "%Y-%m-%d") 237 | 238 | # Early return if date range is invalid or reversed 239 | if date_begin_obj >= date_close_obj or date_begin_obj > resolve_date_obj: 240 | return None 241 | 242 | # Calculate the total number of days in the range 243 | total_days = (date_close_obj - date_begin_obj).days 244 | 245 | # Calculate the retrieval date 246 | retrieval_days = math.exp((math.log(total_days) / num_retrievals) * retrieval_index) 247 | retrieval_date_obj = date_begin_obj + timedelta(days=retrieval_days) 248 | 249 | if retrieval_date_obj >= date_close_obj: 250 | retrieval_date_obj = date_close_obj - timedelta(days=1) 251 | if retrieval_date_obj >= resolve_date_obj: 252 | return None 253 | 254 | # Check against the previous retrieval date 255 | if retrieval_index > 1: 256 | previous_days = math.exp( 257 | (math.log(total_days) / num_retrievals) * (retrieval_index - 1) 258 | ) 259 | previous_date_obj = date_begin_obj + timedelta(days=previous_days) 260 | if retrieval_date_obj <= previous_date_obj: 261 | return None 262 | 263 | return retrieval_date_obj.strftime("%Y-%m-%d") 264 | -------------------------------------------------------------------------------- /llm_forecasting/evaluation.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import logging 3 | 4 | # Local application/library-specific imports 5 | import alignment 6 | from config.constants import ( 7 | DEFAULT_RETRIEVAL_CONFIG, 8 | DEFAULT_REASONING_CONFIG, 9 | S3, 10 | S3_BUCKET_NAME, 11 | ) 12 | import ensemble 13 | import ranking 14 | import summarize 15 | from utils 
import db_utils 16 | 17 | # Set up logging 18 | logging.basicConfig(level=logging.INFO) 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def to_eval(question, retrieval_number, output_dir): 23 | """ 24 | Determines if a given question needs evaluation based on its presence in an S3 storage. 25 | 26 | Args: 27 | question (str): The question to be evaluated. 28 | retrieval_number (int): A unique identifier for the retrieval process. 29 | output_dir (str): The directory path where the file is expected to be found. 30 | file_type (str, optional): The type of the file. Default is "pickle". 31 | 32 | Returns: 33 | bool: True if the question needs evaluation, False otherwise. 34 | """ 35 | question_formatted = question.replace(" ", "_").replace("/", "") 36 | file_name = f"{output_dir}/{retrieval_number}/{question_formatted}.pickle" 37 | try: 38 | # Try reading the file from S3 39 | _ = db_utils.read_pickle_from_s3(S3, S3_BUCKET_NAME, file_name) 40 | return False 41 | except BaseException: 42 | # If the file is not found in S3, return True to indicate it needs evaluation 43 | return True 44 | 45 | 46 | def save_results(save_dict, question, retrieval_number, output_dir, file_type="pickle"): 47 | """ 48 | Save a dictionary of results to an S3 storage, formatted based on a specific question. 49 | 50 | Args: 51 | save_dict (dict): The dictionary of results to be saved. 52 | question (str): The title of the question related to the results. 53 | retrieval_number (int): A unique identifier for the retrieval process. 54 | output_dir (str): The directory path where the file will be saved. 55 | file_type (str, optional): The type of the file to save. Default is "pickle". 56 | """ 57 | question_formatted = question.replace(" ", "_").replace("/", "") 58 | file_name = f"{output_dir}/{retrieval_number}/{question_formatted}.{file_type}" 59 | db_utils.upload_data_structure_to_s3(S3, save_dict, S3_BUCKET_NAME, file_name) 60 | 61 | 62 | async def retrieve_and_forecast( 63 | question_dict, 64 | question_raw, 65 | ir_config=DEFAULT_RETRIEVAL_CONFIG, 66 | reason_config=DEFAULT_REASONING_CONFIG, 67 | calculate_alignment=False, 68 | return_articles=False, 69 | ): 70 | """ 71 | Asynchronously evaluates the forecasting question using the end-to-end system. 72 | 73 | This function integrates steps of information retrieval, summarization, reasoning, 74 | and alignment scoring. 75 | 76 | Args: 77 | question_dict (dict): A dictionary containing detailed data about the question. 78 | question_raw (dict): The raw question data. 79 | ir_config (dict, optional): The configuration for information retrieval. 80 | Defaults to DEFAULT_RETRIEVAL_CONFIG. 81 | reason_config (dict, optional): The configuration for reasoning. 82 | Defaults to DEFAULT_REASONING_CONFIG. 83 | calculate_alignment (bool, optional): Flag to determine if alignment scores 84 | should be calculated. Defaults to False. 85 | return_articles (bool, optional): Flag to decide if the retrieved articles 86 | should be returned. Defaults to False. 87 | 88 | Returns: 89 | dict: The reasoning for the questions, along with metadata and intermediate 90 | outputs from the system. If return_articles is True, also includes retrieved 91 | articles. Returns None if the question does not meet evaluation criteria. 92 | """ 93 | assert isinstance( 94 | reason_config["BASE_REASONING_PROMPT_TEMPLATES"], list 95 | ), "BASE_REASONING_PROMPT_TEMPLATES must be a list." 
96 | assert len(reason_config["BASE_REASONING_PROMPT_TEMPLATES"]) == len( 97 | reason_config["BASE_REASONING_MODEL_NAMES"] 98 | ), "BASE_REASONING_PROMPT_TEMPLATES and BASE_REASONING_MODEL_NAMES must have the same length." 99 | 100 | question = question_dict["question"] 101 | background_info = question_dict["background"] 102 | resolution_criteria = question_dict["resolution_criteria"] 103 | answer = question_dict["answer"] 104 | question_dates = question_dict["question_dates"] 105 | retrieval_dates = question_dict["retrieval_dates"] 106 | urls_in_background = question_dict["urls_in_background"] 107 | # Information retrieval 108 | try: 109 | ( 110 | ranked_articles, 111 | all_articles, 112 | search_queries_list_gnews, 113 | search_queries_list_nc, 114 | ) = await ranking.retrieve_summarize_and_rank_articles( 115 | question, 116 | background_info, 117 | resolution_criteria, 118 | retrieval_dates, 119 | urls=urls_in_background, 120 | config=ir_config, 121 | return_intermediates=True, 122 | ) 123 | except Exception as e: # skip the question if failed 124 | logger.error(f"Error message: {e}") 125 | logger.info(f"IR failed at question: {question}.") 126 | return None 127 | subset_ranked_articles = ranked_articles[ 128 | : ir_config.get("NUM_SUMMARIES_THRESHOLD", 100) 129 | ].copy() 130 | all_summaries = summarize.concat_summaries(subset_ranked_articles) 131 | logger.info(f"Information retrieval complete for question: {question}.") 132 | logger.info(f"Number of summaries: {len(subset_ranked_articles)}.") 133 | 134 | # Reasoning (using ensemble) 135 | today_to_close_date = [retrieval_dates[1], question_dates[1]] 136 | ensemble_dict = await ensemble.meta_reason( 137 | question=question, 138 | background_info=background_info, 139 | resolution_criteria=resolution_criteria, 140 | today_to_close_date_range=today_to_close_date, 141 | retrieved_info=all_summaries, 142 | reasoning_prompt_templates=reason_config["BASE_REASONING_PROMPT_TEMPLATES"], 143 | base_model_names=reason_config["BASE_REASONING_MODEL_NAMES"], 144 | base_temperature=reason_config["BASE_REASONING_TEMPERATURE"], 145 | aggregation_method=reason_config["AGGREGATION_METHOD"], 146 | answer_type="probability", 147 | weights=reason_config["AGGREGATION_WEIGTHTS"], 148 | meta_model_name=reason_config["AGGREGATION_MODEL_NAME"], 149 | meta_prompt_template=reason_config["AGGREGATION_PROMPT_TEMPLATE"], 150 | meta_temperature=reason_config["AGGREGATION_TEMPERATURE"], 151 | ) 152 | 153 | alignment_scores = None 154 | if calculate_alignment: 155 | alignment_scores = alignment.get_alignment_scores( 156 | ensemble_dict["base_reasonings"], 157 | alignment_prompt=reason_config["ALIGNMENT_PROMPT"], 158 | model_name=reason_config["ALIGNMENT_MODEL_NAME"], 159 | temperature=reason_config["ALIGNMENT_TEMPERATURE"], 160 | question=question, 161 | background=background_info, 162 | resolution_criteria=resolution_criteria, 163 | ) 164 | 165 | # Compute brier score (base_predictions is a list of lists of 166 | # probabilities) 167 | base_brier_scores = [] 168 | # For each sublist (corresponding to a base model name) 169 | for base_predictions in ensemble_dict["base_predictions"]: 170 | base_brier_scores.append( 171 | [(base_prediction - answer) ** 2 for base_prediction in base_predictions] 172 | ) 173 | # Visualization (draw the HTML) 174 | base_html = visualize_utils.visualize_all( 175 | question_data=question_raw, 176 | retrieval_dates=retrieval_dates, 177 | search_queries_gnews=search_queries_list_gnews, 178 | search_queries_nc=search_queries_list_nc, 179 | 
all_articles=all_articles, 180 | ranked_articles=ranked_articles, 181 | all_summaries=all_summaries, 182 | model_names=reason_config["BASE_REASONING_MODEL_NAMES"], 183 | base_reasoning_prompt_templates=reason_config[ 184 | "BASE_REASONING_PROMPT_TEMPLATES" 185 | ], 186 | base_reasoning_full_prompts=ensemble_dict["base_reasoning_full_prompts"], 187 | base_reasonings=ensemble_dict["base_reasonings"], 188 | base_predictions=ensemble_dict["base_predictions"], 189 | base_brier_scores=base_brier_scores, 190 | ) 191 | meta_html = visualize_utils.visualize_all_ensemble( 192 | question_data=question_raw, 193 | ranked_articles=ranked_articles, 194 | all_articles=all_articles, 195 | search_queries_gnews=search_queries_list_gnews, 196 | search_queries_nc=search_queries_list_nc, 197 | retrieval_dates=retrieval_dates, 198 | meta_reasoning=ensemble_dict["meta_reasoning"], 199 | meta_full_prompt=ensemble_dict["meta_prompt"], 200 | meta_prediction=ensemble_dict["meta_prediction"], 201 | ) 202 | # Generate outputs, one dict per question 203 | output = { 204 | "question": question, 205 | "answer": int(answer), 206 | "data_source": question_raw["data_source"], 207 | "background_info": background_info, 208 | "resolution_criteria": resolution_criteria, 209 | "retrieval_dates": retrieval_dates, 210 | "search_queries_gnews": search_queries_list_gnews, 211 | "search_queries_nc": search_queries_list_nc, 212 | "retrieved_info": all_summaries, 213 | "base_model_names": reason_config["BASE_REASONING_MODEL_NAMES"], 214 | "base_reasoning_full_prompts": ensemble_dict["base_reasoning_full_prompts"], 215 | "base_reasonings": ensemble_dict["base_reasonings"], 216 | "base_predictions": ensemble_dict["base_predictions"], 217 | "alignment_scores": alignment_scores, 218 | "meta_reasoning_full_prompt": ensemble_dict["meta_prompt"], 219 | "meta_reasoning": ensemble_dict["meta_reasoning"], 220 | "meta_prediction": ensemble_dict["meta_prediction"], 221 | "community_prediction": question_dict["community_pred_at_retrieval"], 222 | "base_html": base_html, 223 | "meta_html": meta_html, 224 | "base_brier_score": base_brier_scores, 225 | "meta_brier_score": (ensemble_dict["meta_prediction"] - answer) ** 2, 226 | "community_brier_score": (question_dict["community_pred_at_retrieval"] - answer) 227 | ** 2, 228 | } 229 | if return_articles: 230 | return output, all_articles, ranked_articles 231 | return output 232 | -------------------------------------------------------------------------------- /llm_forecasting/prompts/search_query.py: -------------------------------------------------------------------------------- 1 | SEARCH_QUERY_PROMPT_0 = ( 2 | """I will provide you with a forecasting question and the background information for the question. I will then ask you to generate short search queries (up to {max_words} words each) that I'll use to find articles on Google News to help answer the question. 3 | 4 | Question: 5 | {question} 6 | 7 | Question Background: 8 | {background} 9 | 10 | Today's date: {date_begin} 11 | Question close date: {date_end} 12 | 13 | You must generate this exact amount of queries: {num_keywords} 14 | 15 | Start off by writing down sub-questions. Then use your sub-questions to help steer the search queries you produce. 16 | 17 | Your response should take the following structure: 18 | Thoughts: 19 | {{ Insert your thinking here. }} 20 | Search Queries: 21 | {{ Insert the queries here. Use semicolons to separate the queries. 
}}""", 22 | ( 23 | "QUESTION", 24 | "BACKGROUND", 25 | "DATES", 26 | "NUM_KEYWORDS", 27 | "MAX_WORDS", 28 | ), 29 | ) 30 | 31 | SEARCH_QUERY_PROMPT_1 = ( 32 | """I will provide you with a forecasting question and the background information for the question. 33 | 34 | Question: 35 | {question} 36 | 37 | Question Background: 38 | {background} 39 | 40 | Today's date: {date_begin} 41 | Question close date: {date_end} 42 | 43 | Task: 44 | - Generate brief search queries (up to {max_words} words each) to gather information on Google that could influence the forecast. 45 | 46 | You must generate this exact amount of queries: {num_keywords} 47 | 48 | Your response should take the following structure: 49 | Thoughts: 50 | {{ Insert your thinking here. }} 51 | Search Queries: 52 | {{ Insert the queries here. Use semicolons to separate the queries. }}""", 53 | ( 54 | "QUESTION", 55 | "BACKGROUND", 56 | "DATES", 57 | "NUM_KEYWORDS", 58 | "MAX_WORDS", 59 | ), 60 | ) 61 | 62 | SEARCH_QUERY_PROMPT_2 = ( 63 | """In this task, I will present a forecasting question along with relevant background information. Your goal is to create {num_keywords} concise search queries (up to {max_words} words each) to gather information that could influence the forecast. Consider different angles and aspects that might impact the outcome. 64 | 65 | Question: 66 | {question} 67 | 68 | Question Background: 69 | {background} 70 | 71 | Today's date: {date_begin} 72 | Question close date: {date_end} 73 | 74 | Now, generate {num_keywords} short search queries to search for information on Google News. 75 | You must generate this exact amount of queries: {num_keywords}. 76 | When formulating your search queries, think about various factors that could affect the forecast, such as recent trends, historical data, or external influences. 77 | 78 | Your response should take the following structure: 79 | Thoughts: 80 | {{ Insert your thinking here. }} 81 | Search Queries: 82 | {{ Insert the queries here. Use semicolons to separate the queries. }}""", 83 | ( 84 | "QUESTION", 85 | "BACKGROUND", 86 | "DATES", 87 | "NUM_KEYWORDS", 88 | "MAX_WORDS", 89 | ), 90 | ) 91 | 92 | SEARCH_QUERY_PROMPT_3 = ( 93 | """I will provide you with a forecasting question and the background information for the question. I will then ask you to generate {num_keywords} short search queries (up to {max_words} words each) that I'll use to find articles on Google News to help answer the question. 94 | 95 | Question: 96 | {question} 97 | 98 | Question Background: 99 | {background} 100 | 101 | Today's date: {date_begin} 102 | Question close date: {date_end} 103 | 104 | You must generate this exact amount of queries: {num_keywords} 105 | 106 | Your response should take the following structure: 107 | Thoughts: 108 | {{ Insert your thinking here. }} 109 | Search Queries: 110 | {{ Insert the queries here. Use semicolons to separate the queries. }}""", 111 | ( 112 | "QUESTION", 113 | "BACKGROUND", 114 | "DATES", 115 | "NUM_KEYWORDS", 116 | "MAX_WORDS", 117 | ), 118 | ) 119 | 120 | SEARCH_QUERY_PROMPT_4 = ( 121 | """Generate short search queries (up to {max_words} words) for the forecasting question below. 122 | 123 | I will use them to query Google News for articles. These search queries should result in articles that help me make an informed prediction. 
124 | 125 | Question: 126 | {question} 127 | 128 | Question Background: 129 | {background} 130 | 131 | Today's date: {date_begin} 132 | Question close date: {date_end} 133 | 134 | You must generate this exact amount of queries: {num_keywords}. 135 | 136 | Your response should take the following structure: 137 | Thoughts: 138 | {{ Insert your thinking here. }} 139 | Search Queries: 140 | {{ Insert the queries here. Use semicolons to separate the queries. }}""", 141 | ( 142 | "QUESTION", 143 | "BACKGROUND", 144 | "DATES", 145 | "NUM_KEYWORDS", 146 | "MAX_WORDS", 147 | ), 148 | ) 149 | 150 | SEARCH_QUERY_PROMPT_5 = ( 151 | """ 152 | Please provide {num_keywords} search queries to input into Google to help me research this forecasting question: 153 | 154 | Question: {question} 155 | 156 | Background: {background} 157 | 158 | Today's Date: {date_begin} 159 | Close Date: {date_end} 160 | 161 | Guidelines: 162 | - Include terms related to influential factors that could sway the outcome. 163 | - Use different keyword approaches to get balanced perspectives. 164 | - Each search query should be up to {max_words} words. 165 | 166 | You must generate this exact amount of queries: {num_keywords}. 167 | 168 | Your response should take the following structure: 169 | Thoughts: 170 | {{ Insert your thinking here. }} 171 | Search Queries: 172 | {{ Insert the queries here. Use semicolons to separate the queries. }}""", 173 | ( 174 | "QUESTION", 175 | "BACKGROUND", 176 | "DATES", 177 | "NUM_KEYWORDS", 178 | "MAX_WORDS", 179 | ), 180 | ) 181 | 182 | SEARCH_QUERY_PROMPT_6 = ( 183 | """ 184 | In this task, you will receive a forecasting question along with its background information. Your objective is to create {num_keywords} targeted search queries, each not exceeding {max_words} words, to unearth information that could shape the forecast. 185 | 186 | Question: 187 | {question} 188 | 189 | Background: 190 | {background} 191 | 192 | Current Date: 193 | {date_begin} 194 | Question Close Date: 195 | {date_end} 196 | 197 | Your job is to formulate {num_keywords} distinct and concise search queries. These queries will be used to query Google News to capture diverse perspectives and relevant data from various sources. Think about different elements that could influence the outcome. 198 | 199 | Structure your response as follows: 200 | Thoughts: 201 | {{ Insert your thinking here. }} 202 | Search Queries: 203 | {{ Insert the queries here. Use semicolons to separate the queries. }}""", 204 | ( 205 | "QUESTION", 206 | "BACKGROUND", 207 | "DATES", 208 | "NUM_KEYWORDS", 209 | "MAX_WORDS", 210 | ), 211 | ) 212 | 213 | SEARCH_QUERY_PROMPT_7 = ( 214 | """ 215 | In this task, I will present a forecasting question along with relevant background information. Your goal is to create {num_keywords} concise search queries (up to {max_words} words each) to gather information on Google that could influence the forecast. 216 | 217 | Question: 218 | {question} 219 | 220 | Question Background: 221 | {background} 222 | 223 | Today's date: {date_begin} 224 | Question close date: {date_end} 225 | 226 | Now, generate {num_keywords} short search queries to search for information on Google News. 227 | Begin by formulating sub-questions related to the main question. Use these sub-questions to guide the creation of your search queries. 228 | 229 | Your response should take the following structure: 230 | Thoughts: 231 | {{ Insert your thinking here. }} 232 | Search Queries: 233 | {{ Insert the queries here. 
Use semicolons to separate the queries. }}""", 234 | ( 235 | "QUESTION", 236 | "BACKGROUND", 237 | "DATES", 238 | "NUM_KEYWORDS", 239 | "MAX_WORDS", 240 | ), 241 | ) 242 | 243 | SEARCH_QUERY_PROMPT_8 = ( 244 | """In this task, I will present a forecasting question along with relevant background information. Your goal is to create {num_keywords} concise search queries (up to {max_words} words each) to gather information that could influence the forecast. Consider different angles and aspects that might impact the outcome. 245 | 246 | Question: 247 | {question} 248 | 249 | Question Background: 250 | {background} 251 | 252 | Today's date: {date_begin} 253 | Question close date: {date_end} 254 | 255 | Now, generate {num_keywords} short search queries to search for information on Google News. 256 | 257 | Your response should take the following structure: 258 | Thoughts: 259 | {{ Insert your thinking here. }} 260 | Search Queries: 261 | {{ Insert the queries here. Use semicolons to separate the queries. }}""", 262 | ( 263 | "QUESTION", 264 | "BACKGROUND", 265 | "DATES", 266 | "NUM_KEYWORDS", 267 | "MAX_WORDS", 268 | ), 269 | ) 270 | 271 | # To be evaluated 272 | SEARCH_QUERY_PROMPT_NO_DATE_0 = ( 273 | """Generate {num_keywords} search queries (up to {max_words} words each) for the forecasting question below. 274 | 275 | I will use them to query Google News for articles. These search queries should result in articles that help me make an informed prediction. 276 | 277 | --- 278 | Question: 279 | {question} 280 | --- 281 | 282 | --- 283 | Question Background: 284 | {background} 285 | --- 286 | 287 | Your response should take the following structure: 288 | 289 | Thoughts: 290 | {{ insert your thinking here }} 291 | 292 | Search Queries: 293 | {{ Insert the queries here. Use semicolons to separate the queries. }}""", 294 | ( 295 | "QUESTION", 296 | "BACKGROUND", 297 | "NUM_KEYWORDS", 298 | "MAX_WORDS", 299 | ), 300 | ) 301 | 302 | 303 | # To be evaluated 304 | SEARCH_QUERY_PROMPT_NO_DATE_1 = ( 305 | """I will give you a forecasting question and its background information. 306 | 307 | Your goal is to generate {num_keywords} search queries (up to {max_words} words each). 308 | The search queries wil be used to query Google News for articles. 309 | They should result in articles that help me make an informed prediction. 310 | 311 | --- 312 | Question: 313 | {question} 314 | --- 315 | 316 | --- 317 | Question Background: 318 | {background} 319 | --- 320 | 321 | Your response should take the following structure: 322 | 323 | Thoughts: 324 | {{ insert your thinking here }} 325 | 326 | Search Queries: 327 | {{ Insert the queries here. Use semicolons to separate the queries. 
}}""", 328 | ( 329 | "QUESTION", 330 | "BACKGROUND", 331 | "NUM_KEYWORDS", 332 | "MAX_WORDS", 333 | ), 334 | ) 335 | -------------------------------------------------------------------------------- /scripts/data_scraping/polymarket.py: -------------------------------------------------------------------------------- 1 | # run this script with a venv with python 3.9 so it's compatible with 2 | # py_clob_client (polymarket API python client) 3 | 4 | # Standard library imports 5 | from concurrent.futures import ThreadPoolExecutor 6 | from datetime import datetime, timedelta 7 | import argparse 8 | import ast 9 | import logging 10 | import time 11 | 12 | # Related third-party imports 13 | import requests 14 | from tqdm import tqdm 15 | 16 | # Local application/library-specific imports 17 | import data_scraping 18 | import information_retrieval 19 | from config.keys import keys 20 | from py_clob_client.client import ClobClient 21 | from py_clob_client.constants import POLYGON 22 | 23 | # Setup logging and other configurations 24 | logger = logging.getLogger(__name__) 25 | 26 | client = ClobClient( 27 | "https://clob.polymarket.com", key=keys["CRYPTO_PRIVATE_KEY"], chain_id=POLYGON 28 | ) 29 | 30 | 31 | def get_market_query(offset_value): 32 | """ 33 | Construct and return a GraphQL query string for fetching market data. 34 | 35 | Args: 36 | offset_value (int): The offset value to use in the query for pagination. 37 | 38 | Returns: 39 | str: A GraphQL query string. 40 | """ 41 | query = f"""{{ 42 | markets(offset: {offset_value}) {{ 43 | id 44 | conditionId 45 | question 46 | description 47 | category 48 | createdAt 49 | questionID 50 | outcomes 51 | outcomePrices 52 | clobTokenIds 53 | endDate 54 | volume 55 | closed 56 | }} 57 | }}""" 58 | return query 59 | 60 | 61 | def get_comment_query(market_id, offset_value): 62 | """ 63 | Construct and return a GraphQL query string for fetching comments related 64 | to a specific market ID with pagination. 65 | 66 | Args: 67 | market_id (int): The unique identifier of the market. 68 | offset_value (int): The offset value to use for pagination. 69 | 70 | Returns: 71 | str: A GraphQL query string for fetching comments. 72 | """ 73 | query = f"""{{ 74 | comments(marketID: {market_id}, offset: {offset_value}) {{ 75 | id 76 | body 77 | createdAt 78 | }} 79 | }}""" 80 | return query 81 | 82 | 83 | def generate_json_markets(field, market_id=None): 84 | """ 85 | Perform a POST request to retrieve market or comment data from the 86 | Polymarket API. 87 | 88 | Args: 89 | field (str): Specifies the type of data to fetch ('markets' or 'comments'). 90 | market_id (int, optional): The market ID for which comments are to be fetched. 91 | Required if 'field' is 'comments'. 92 | 93 | Returns: 94 | list: A list of dictionaries containing market or comment data. 
95 | """ 96 | url = "https://gamma-api.polymarket.com/query" 97 | offset_value = 0 98 | all_data = [] 99 | 100 | while True: 101 | if field == "comments": 102 | query = get_comment_query(market_id, offset_value) 103 | elif field == "markets": 104 | query = get_market_query(offset_value) 105 | else: 106 | print("Wrong field name!") 107 | 108 | try: 109 | response = requests.post(url, json={"query": query}) 110 | # Will raise an HTTPError if the HTTP request returned an 111 | # unsuccessful status code 112 | response.raise_for_status() 113 | data = response.json().get("data", {}).get(field, []) 114 | 115 | if not data: 116 | break # Exit loop if no more markets are returned 117 | 118 | all_data.extend(data) 119 | offset_value += 1000 120 | 121 | except requests.exceptions.RequestException as e: 122 | print(f"Request failed: {e}") 123 | break # Break the loop in case of request failure 124 | 125 | return all_data 126 | 127 | 128 | def question_to_url(question, base_url="https://polymarket.com/event/"): 129 | """ 130 | Convert a Polymarket question into a URL format. 131 | 132 | Args: 133 | question (str): The question text to be formatted. 134 | base_url (str, optional): The base URL to prepend to the formatted question. 135 | 136 | Returns: 137 | str: The formatted URL representing the Polymarket question. 138 | """ 139 | cleaned_question = "".join( 140 | "" if char == "$" else char 141 | for char in question.strip() 142 | if char.isalnum() or char in [" ", "-", "$"] 143 | ) 144 | 145 | # Replace spaces with hyphens and convert to lowercase 146 | url_formatted_question = cleaned_question.replace(" ", "-").lower() 147 | 148 | # Concatenate with the base URL 149 | url = base_url + url_formatted_question 150 | return url 151 | 152 | 153 | def fetch_price_history(market_id): 154 | """ 155 | Retrieve the price history of a market from the Polymarket API. 156 | 157 | Args: 158 | market_id (str): The unique identifier of the market. 159 | 160 | Returns: 161 | list: A list of dictionaries containing the price history data, or an empty list 162 | if the data retrieval fails. 163 | """ 164 | url = ( 165 | f"https://clob.polymarket.com/prices-history?interval=all&market=" 166 | f"{market_id}&fidelity=60" 167 | ) 168 | 169 | response = requests.get(url) 170 | 171 | if response.status_code == 200: 172 | data = response.json() 173 | history_data = data.get("history", []) 174 | return history_data 175 | else: 176 | print("Failed to retrieve data:", response.status_code) 177 | return [] 178 | 179 | 180 | def process_market(m): 181 | """ 182 | Process a single market dictionary by adding additional information 183 | such as comments, URLs, community predictions, and other metadata. 184 | It also renames fields to align with a specific format. 185 | 186 | Args: 187 | market (dict): A dictionary representing a single market with its initial data. 188 | 189 | Returns: 190 | dict: The processed market dictionary with additional fields and formatted data. 191 | """ 192 | m["comments"] = generate_json_markets("comments", market_id=int(m["id"])) 193 | m["url"] = question_to_url(m["question"]) 194 | 195 | # Resolution 196 | try: 197 | m["outcomes"] = ast.literal_eval(m["outcomes"]) 198 | m["outcomePrices"] = ast.literal_eval(m["outcomePrices"]) 199 | 200 | # Make sure that 'outcomes' and 'outcomePrices' have the same length 201 | if len(m["outcomes"]) != len(m["outcomePrices"]): 202 | raise ValueError( 203 | "The lengths of 'outcomes' and 'outcomePrices' do not match." 
204 |             )
205 | 
206 |         # Find the outcome with the highest price
207 |         highest_price = max(m["outcomePrices"], key=lambda x: float(x))
208 |         highest_price_index = m["outcomePrices"].index(highest_price)
209 |         resolution_outcome = m["outcomes"][highest_price_index]
210 | 
211 |         m["resolution"] = resolution_outcome
212 | 
213 |         if m["resolution"] == "Yes":
214 |             m["resolution"] = 1
215 |         elif m["resolution"] == "No":
216 |             m["resolution"] = 0
217 |     except Exception as e:
218 |         print(f"An error occurred: {e}")
219 |         m["resolution"] = "Error"
220 | 
221 |     # Question type
222 |     m["question_type"] = "multiple_choice"
223 |     if set(m["outcomes"]) == {"Yes", "No"}:
224 |         m["question_type"] = "binary"
225 | 
226 |     # Community predictions for the first outcome
227 |     try:
228 |         if m["clobTokenIds"] is not None:
229 |             # Attempt to fetch community predictions
230 |             m["community_predictions"] = fetch_price_history(
231 |                 m["clobTokenIds"].split('"')[1]
232 |             )
233 |         else:
234 |             m["community_predictions"] = []
235 |     except IndexError as e:
236 |         # Print the error and the problematic clobTokenIds
237 |         print(f"Error: {e}, clobTokenIds: {m.get('clobTokenIds')}")
238 |         m["community_predictions"] = []
239 | 
240 |     # Rename field names so they align with metaculus
241 |     m["title"] = m.pop("question")
242 |     m["close_time"] = m.pop("endDate")
243 |     m["created_time"] = m.pop("createdAt")
244 |     m["background"] = m.pop("description")
245 |     m["is_resolved"] = m.pop("closed")
246 | 
247 |     # Data source
248 |     m["data_source"] = "polymarket"
249 | 
250 |     return m
251 | 
252 | 
253 | def main(n_days):
254 |     """
255 |     Main function to fetch and process Polymarket data.
256 | 
257 |     Args:
258 |         n_days (int): Number of days in the past to limit the data fetching.
259 |             If None, fetches all available data.
260 | 
261 |     Returns:
262 |         None. The processed markets are uploaded to S3 via data_scraping.upload_scraped_data.
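    Example invocation (illustrative):
        $ python polymarket.py --n_days 30   # only markets created in the last 30 days
        $ python polymarket.py               # fetch all available markets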
263 | """ 264 | logger.info("Starting the polymarket script...") 265 | 266 | start_time = time.time() 267 | 268 | all_markets = generate_json_markets("markets") 269 | 270 | if n_days is not None: 271 | date_limit = datetime.now() - timedelta(days=n_days) 272 | all_markets = [ 273 | market 274 | for market in all_markets 275 | if datetime.fromisoformat(market["createdAt"][:-1]) >= date_limit 276 | ] 277 | 278 | logger.info(f"Number of polymarket questions fetched: {len(all_markets)}") 279 | 280 | logger.info("Start preprocess the question...") 281 | with ThreadPoolExecutor(max_workers=50) as executor: 282 | results = list( 283 | tqdm( 284 | executor.map(process_market, all_markets), 285 | total=len(all_markets), 286 | ) 287 | ) 288 | 289 | logger.info("Start extracting articles links...") 290 | 291 | for question in results: 292 | question["extracted_articles_urls"] = information_retrieval.get_urls_from_text( 293 | question["background"] 294 | ) 295 | if question["comments"]: 296 | for comment in question["comments"]: 297 | question["extracted_articles_urls"].extend( 298 | information_retrieval.get_urls_from_text(comment["body"]) 299 | ) 300 | 301 | elapsed_time = time.time() - start_time 302 | logger.info(f"Total execution time: {elapsed_time} seconds") 303 | 304 | logger.info("Uploading to s3...") 305 | question_types = ["binary", "multiple_choice"] 306 | data_scraping.upload_scraped_data(results, "polymarket", question_types, n_days) 307 | 308 | 309 | if __name__ == "__main__": 310 | parser = argparse.ArgumentParser(description="Fetch polymarket data.") 311 | parser.add_argument( 312 | "--n_days", 313 | type=int, 314 | help="Fetch questions created in the last N days", 315 | default=None, 316 | ) 317 | args = parser.parse_args() 318 | main(args.n_days) 319 | -------------------------------------------------------------------------------- /llm_forecasting/summarize.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import asyncio 3 | import logging 4 | import time 5 | 6 | # Local application/library-specific imports 7 | from config.constants import MODEL_TOKEN_LIMITS 8 | import model_eval 9 | from prompts.prompts import PROMPT_DICT 10 | from utils import model_utils 11 | 12 | # Set up logging 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def concat_summaries(articles, return_summaries_list=False): 18 | """ 19 | Combine the summaries of various articles into a single, cohesive string, 20 | ensuring to incorporate each article's title, publication date, and a 21 | sequential index for easy reference. 22 | 23 | Args: 24 | articles (list of article objects): List of article objects with 25 | a 'summary' field. 26 | return_summaries_list (bool, optional): Whether to return the list of 27 | summaries as well. Defaults to False. 28 | 29 | Returns: 30 | str: A string containing the concatenated summaries of the articles. 31 | 32 | Example output: 33 | --- 34 | ARTICLES 35 | [1] Title 1 (published on 2021-01-01) 36 | .... 37 | [2] Title 2 (published on 2021-01-02) 38 | .... 39 | ---- 40 | 41 | If return_summaries_list is True, then the function returns a tuple 42 | containing the concatenated summaries string and the list of summaries. 
43 | """ 44 | if not articles: 45 | return "---\nNo articles were retrieved for this question.\n----" 46 | article_summaries = [ 47 | f"[{index}] {article.title} (published on {(article.publish_date.date() if article.publish_date else 'unknown date')})\nSummary: {article.summary}\n" 48 | for index, article in enumerate(articles, start=1) 49 | ] 50 | concatenated_summaries_str = ( 51 | "---\nARTICLES\n" + "\n".join(article_summaries) + "----" 52 | ) 53 | if return_summaries_list: 54 | return concatenated_summaries_str, article_summaries 55 | return concatenated_summaries_str 56 | 57 | 58 | def split_text_into_chunks(text, model_name, token_limit): 59 | """ 60 | Split the text into chunks, ensuring each chunk is below the token limit. 61 | 62 | Args: 63 | text (str): Input text to be split. 64 | model_name (str): Name of the model to be used for token counting. 65 | token_limit (int): Maximum number of tokens allowed per chunk. 66 | 67 | Returns: 68 | list: List of text chunks. 69 | """ 70 | words = text.split() 71 | current_chunk = [] 72 | current_chunk_tokens = 0 73 | chunks = [] 74 | 75 | for word in words: 76 | word_tokens = model_utils.count_tokens(word, model_name) 77 | if current_chunk_tokens + word_tokens > token_limit: 78 | chunks.append(" ".join(current_chunk)) 79 | current_chunk = [word] 80 | current_chunk_tokens = word_tokens 81 | else: 82 | current_chunk.append(word) 83 | current_chunk_tokens += word_tokens 84 | 85 | if current_chunk: 86 | chunks.append(" ".join(current_chunk)) 87 | 88 | return chunks 89 | 90 | 91 | def recursive_summarize( 92 | text, 93 | model_name, 94 | prompt, 95 | output_token_length=None, 96 | temperature=0, 97 | ): 98 | """ 99 | Recursively summarize the text until the summary fits within the context 100 | window. 101 | 102 | Args: 103 | text (str): The input text to be summarized. 104 | model_name (str): Name of the model to be used for summarization. 105 | primary_prompt (str): Main prompt to guide the summarization. 106 | output_token_length (int, optional): Desired word count of the final 107 | summary. Defaults to None. 108 | temperature (float, optional): Sampling temperature for the completion. 109 | Defaults to 0. 110 | 111 | Returns: 112 | str: Summarized text. 
113 | """ 114 | start_time = time.time() 115 | 116 | total_tokens = model_utils.count_tokens(text, model_name) 117 | logger.info(f"Total number of tokens of the given text: {total_tokens}") 118 | 119 | if total_tokens <= MODEL_TOKEN_LIMITS[model_name]: 120 | if output_token_length: 121 | prompt += ( 122 | f"\n\nAlso, ensure the summary is under {output_token_length} words.\n" 123 | ) 124 | 125 | output = model_eval.get_response_from_model( 126 | model_name=model_name, 127 | prompt=prompt.format(article=text), 128 | max_tokens=output_token_length, 129 | ) 130 | 131 | end_time = time.time() 132 | elapsed_time = end_time - start_time 133 | logger.info(f"Time taken for summarization: {elapsed_time:.2f} seconds") 134 | logger.info("Finished summarizing the article!") 135 | return output 136 | else: 137 | token_limit = MODEL_TOKEN_LIMITS[model_name] - 1000 138 | chunks = split_text_into_chunks(text, model_name, token_limit) 139 | 140 | summarized_chunks = [] 141 | for chunk in chunks: 142 | summarized_chunk = recursive_summarize( 143 | chunk, 144 | model_name, 145 | prompt, 146 | output_token_length, 147 | temperature, 148 | ) 149 | summarized_chunks.append(summarized_chunk) 150 | 151 | summarized_text = " ".join(summarized_chunks) 152 | return recursive_summarize( 153 | summarized_text, 154 | model_name, 155 | prompt, 156 | output_token_length, 157 | temperature, 158 | ) 159 | 160 | 161 | async def summarize_articles( 162 | articles, 163 | model_name="gpt-3.5-turbo-1106", 164 | prompt=PROMPT_DICT["summarization"]["0"][0], 165 | temperature=0.2, 166 | update_object=True, 167 | inline_questions=[], 168 | ): 169 | """ 170 | Summarizes a list of articles asynchronously. 171 | 172 | Long articles are truncated down to token limit and summarized separately. 173 | 174 | Example usage: 175 | >> summarized_results = await summarize_articles(articles) 176 | 177 | Args: 178 | articles (list of obj): List of article objects. Each article object should have the following fields: 179 | text_cleaned (str): Full text of the article. 180 | model_name (str, optional): Name of the OpenAI model to be used for summarization (defaults to "gpt-3.5-turbo-1106"). 181 | prompt (str, optional): Prompt to use for the API call (defaults to PROMPT_DICT["summarization"]["0"][0]). 182 | This is not the full prompt, but contains a placeholder for the article text. 183 | output_token_length (int, optional): Desired word count of the final summary. Defaults to None. 184 | temperature (float, optional): Sampling temperature for the completion. Defaults to 0.2. 185 | update_object (bool, optional): Whether to update the article object with the summary (defaults to True). 186 | inline_questions (dict, optional): List containing the inline questions. Defaults to []. 187 | 188 | Returns: 189 | dict: Dictionary containing the summarized results for each article. 190 | Example output: 191 | --- 192 | { 193 | "Title 1": "Summary 1", 194 | "Title 2": "Summary 2", 195 | ... 196 | } 197 | --- 198 | Also, the article objects are updated with the summaries if update_object is True. 
199 | """ 200 | summarized_results = {} 201 | # Truncate articles that are too long 202 | for article in articles: 203 | if ( # exceeds token limit 204 | model_utils.count_tokens(article.text_cleaned, model_name) 205 | > MODEL_TOKEN_LIMITS[model_name] - 1000 206 | ): 207 | article.text_cleaned = split_text_into_chunks( 208 | article.text_cleaned, model_name, MODEL_TOKEN_LIMITS[model_name] - 1000 209 | )[0] 210 | # Summarize all articles asynchronously 211 | logger.info(f"Async summarizing {len(articles)} short articles") 212 | all_summaries = await async_summarize( 213 | articles, 214 | prompt, 215 | update_object=update_object, 216 | temperature=temperature, 217 | model_name=model_name, 218 | inline_questions=inline_questions, 219 | ) 220 | for i, article in enumerate(articles): 221 | summarized_results[article.title] = all_summaries[i] 222 | return summarized_results 223 | 224 | 225 | async def async_summarize( 226 | articles, 227 | prompt=PROMPT_DICT["summarization"]["0"][0], 228 | update_object=True, 229 | model_name="gpt-3.5-turbo-1106", 230 | temperature=0.2, 231 | inline_questions=[], 232 | ): 233 | """ 234 | Asynchronously summarizes a list of articles. 235 | Example usage: 236 | >> all_summaries = await async_summarize(articles) 237 | The order of the returned summaries is the same as the order of the input articles. 238 | All summaries are stored in the "summary" field of each article object, if update_object is True. 239 | 240 | Args: 241 | articles (list of obj): List of article objects. Each article object should have the following fields: 242 | text_cleaned (str): Full text of the article. 243 | prompt (str): Prompt to use for the API call (defaults to PROMPT_DICT["summarization"]["0"][0]). 244 | This is not the full prompt, but contains a placeholder for the article text. 245 | update_object (bool): Whether to update the article object with the summary (defaults to True). 246 | 247 | Returns: 248 | list of str: List of summaries (str) for each article. 
249 | """ 250 | if len(articles) == 0 or not articles: 251 | return [] 252 | if inline_questions: 253 | question = inline_questions["title"] 254 | background = inline_questions["background"] 255 | resolution_criteria = inline_questions["resolution_criteria"] 256 | prompts = [ 257 | prompt.format( 258 | question=question, 259 | background=background, 260 | resolution_criteria=resolution_criteria, 261 | article=article.text_cleaned, 262 | ) 263 | for article in articles 264 | ] 265 | else: 266 | prompts = [prompt.format(article=article.text_cleaned) for article in articles] 267 | 268 | summarization_tasks = [ 269 | model_eval.get_async_response( 270 | prompt, model_name=model_name, temperature=temperature 271 | ) 272 | for prompt in prompts 273 | ] 274 | all_summaries = await asyncio.gather(*summarization_tasks) 275 | if update_object: 276 | for i, article in enumerate(articles): 277 | article.summary = all_summaries[i] 278 | return all_summaries 279 | -------------------------------------------------------------------------------- /llm_forecasting/utils/string_utils.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import logging 3 | import re 4 | 5 | # Third-party imports 6 | import urllib.parse 7 | 8 | # Local application/library-specific imports 9 | from config.constants import TOKENS_TO_PROBS_DICT 10 | 11 | # Set up logging 12 | logging.basicConfig(level=logging.INFO) 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def is_string_in_list(target_string, string_list): 17 | """ 18 | Check if the target string is in the list of strings; case insensitive. 19 | """ 20 | # Convert the target string to lowercase 21 | target_string_lower = target_string.lower() 22 | 23 | # Check if the lowercase target string is in the list of lowercase strings 24 | return any(s.lower() == target_string_lower for s in string_list) 25 | 26 | 27 | def find_end_word(paragraph, end_words, window_size=50): 28 | """ 29 | Find one of the end_words in the last window_size words of the paragraph. 30 | Return the found word or None if no word is found. 31 | 32 | TODO: Lowercase the paragraph and end_words before searching so that the search is case-insensitive? 33 | 34 | Args: 35 | - paragraph (str): The paragraph to search in. 36 | - end_words (list of str): The words to search for. 37 | - window_size (int): The number of words from the end to search within. 38 | 39 | Returns: 40 | str: found word or None 41 | """ 42 | sorted_words = sorted(end_words, key=lambda s: len(s.split(" ")), reverse=True) 43 | for end_word in sorted_words: 44 | if end_word in paragraph[-window_size:]: 45 | return end_word 46 | logger.debug(f"Could not find any end word in {paragraph[-window_size:]}.") 47 | return None 48 | 49 | 50 | def get_prompt( 51 | prompt_template, 52 | fields, 53 | question=None, 54 | data_source=None, 55 | dates=None, 56 | background=None, 57 | resolution_criteria=None, 58 | num_keywords=None, 59 | retrieved_info=None, 60 | reasoning=None, 61 | article=None, 62 | summary=None, 63 | few_shot_examples=None, 64 | max_words=None, 65 | ): 66 | """ 67 | Fill in a prompt template with specific data based on provided fields. 68 | 69 | Args: 70 | prompt_template (str): The template containing placeholders for data 71 | insertion. 72 | fields (list): Placeholders within the template to be filled ('QUESTION', 73 | 'DATES', 'FEW_SHOT_EXAMPLES', 'RETRIEVED_INFO'). 74 | question (str, optional): The question text for the 'QUESTION' 75 | placeholder. 
76 | data_source (str, optional): The platform (e.g. "metaculus"). 77 | background (str, optional): Background information for the 'BACKGROUND' 78 | placeholder. 79 | resolution_criteria (str, optional): Resolution criteria for the 80 | 'RESOLUTION_CRITERIA' placeholder. 81 | dates (tuple or list of strs, optional): Start and end dates for the 82 | 'DATES' placeholder (length == 2) 83 | retrieved_info (str, optional): Information text for the 'RETRIEVED_INFO' 84 | placeholder. 85 | reasoning (str, optional): Reasoning text for the 'REASONING' placeholder. 86 | article (str, optional): Article text for the 'ARTICLE' placeholder. 87 | summary (str, optional): Summary text for the 'SUMMARY' placeholder. 88 | few_shot_examples (list, optional): List of (question, answer) tuples for 89 | the 'FEW_SHOT_EXAMPLES' placeholder. 90 | max_words (int, optional): Maximum number of words for the 'MAX_WORDS' 91 | used for search query generation. 92 | 93 | Returns: 94 | str: A string with the placeholders in the template replaced with the provided 95 | data. 96 | """ 97 | mapping = {} 98 | for f in fields: 99 | if f == "QUESTION": 100 | mapping["question"] = question 101 | elif f == "DATES": 102 | mapping["date_begin"] = dates[0] 103 | mapping["date_end"] = dates[1] 104 | elif f == "FEW_SHOT_EXAMPLES": 105 | examples = few_shot_examples or [] 106 | for j, (q, answer) in enumerate(examples, 1): 107 | mapping[f"question_{j}"] = q 108 | mapping[f"answer_{j}"] = answer 109 | elif f == "RETRIEVED_INFO": 110 | mapping["retrieved_info"] = retrieved_info 111 | elif f == "BACKGROUND": 112 | mapping["background"] = background 113 | elif f == "RESOLUTION_CRITERIA": 114 | mapping["resolution_criteria"] = resolution_criteria 115 | elif f == "REASONING": 116 | mapping["reasoning"] = reasoning 117 | elif f == "BASE_REASONINGS": 118 | mapping["base_reasonings"] = reasoning 119 | elif f == "NUM_KEYWORDS": 120 | mapping["num_keywords"] = num_keywords 121 | elif f == "MAX_WORDS": 122 | mapping["max_words"] = str(max_words) 123 | elif f == "ARTICLE": 124 | mapping["article"] = article 125 | elif f == "SUMMARY": 126 | mapping["summary"] = summary 127 | elif f == "DATA_SOURCE": 128 | mapping["data_source"] = data_source 129 | return prompt_template.format(**mapping) 130 | 131 | 132 | def extract_probability_with_stars(text): 133 | """ 134 | Extract a probability value from a given text string. 135 | 136 | The function searches for numbers enclosed in asterisks (*), interpreting 137 | them as potential probability values. If a percentage sign is found with 138 | the number, it's converted to a decimal. The function returns the last 139 | number found that is less than or equal to 1, as a probability should be. 140 | If no such number is found, a default probability of 0.5 is returned. 141 | 142 | Args: 143 | - text (str): The text string from which the probability value is to be 144 | extracted. 145 | 146 | Returns: 147 | - float: The extracted probability value, if found. Otherwise, returns 0.5. 
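    Examples (illustrative, based on the extraction rules above):
        >> extract_probability_with_stars("... so my final answer is *0.75*")
        0.75
        >> extract_probability_with_stars("I estimate the probability at *12%*.")
        0.12
        >> extract_probability_with_stars("No starred estimate appears here.")
        0.5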
148 | """ 149 | # Regular expression to find numbers between stars 150 | pattern = r"\*(.*?[\d\.]+.*?)\*" 151 | matches = re.findall(pattern, text) 152 | 153 | # Extracting the numerical values from the matches 154 | extracted_numbers = [] 155 | for match in matches: 156 | # Extract only the numerical part (ignoring potential non-numeric 157 | # characters) 158 | number_match = re.search(r"[\d\.]+", match) 159 | if number_match: 160 | try: 161 | number = float(number_match.group()) 162 | if "%" in match: 163 | number /= 100 164 | extracted_numbers.append(number) 165 | except BaseException: 166 | continue 167 | 168 | if len(extracted_numbers) > 0 and extracted_numbers[-1] <= 1: 169 | return extracted_numbers[-1] 170 | 171 | # Regular expression to find numbers between stars 172 | pattern = r"([\d\.]+.*?)\*" 173 | matches = re.findall(pattern, text) 174 | 175 | # Extracting the numerical values from the matches 176 | extracted_numbers = [] 177 | for match in matches: 178 | # Extract only the numerical part (ignoring potential non-numeric 179 | # characters) 180 | number_matches = re.findall(r"[\d\.]+", match) 181 | for num_match in number_matches: 182 | try: 183 | number = float(num_match) 184 | extracted_numbers.append(number) 185 | except BaseException: 186 | continue 187 | 188 | if len(extracted_numbers) > 0 and extracted_numbers[-1] <= 1: 189 | return extracted_numbers[-1] 190 | 191 | return 0.5 192 | 193 | 194 | def extract_prediction( 195 | response, 196 | answer_type="probability", 197 | end_words=list(TOKENS_TO_PROBS_DICT["ten_options"].keys()), 198 | ): 199 | """ 200 | A generic function to extract a prediction from a response string. 201 | 202 | Args: 203 | response (str): The response string from which the prediction is to be 204 | extracted. 205 | answer_type (str): The type of answer to extract. Can be "probability" 206 | or "tokens". 207 | end_words (list): The list of end words to search for in the response 208 | string. The first end word found in the response string will be 209 | used to extract the prediction. 210 | Only used if answer_type == "tokens". 211 | 212 | Returns: 213 | str or float: The extracted prediction. 214 | """ 215 | if answer_type == "probability": 216 | return extract_probability_with_stars(response) 217 | elif answer_type == "tokens": 218 | return find_end_word(response, end_words) 219 | else: 220 | raise ValueError(f"Invalid answer_type: {answer_type}") 221 | 222 | 223 | def extract_and_decode_title_from_wikiurl(url): 224 | """ 225 | Extract the title from a Wikipedia URL and decode it. 226 | 227 | Args: 228 | url (str): The Wikipedia URL. 229 | 230 | Returns: 231 | str: The decoded title. None if the URL is invalid. 232 | """ 233 | if "wikipedia.org" in url and "upload.wikimedia.org" not in url: 234 | # Extract the part after '/wiki/' and before the first '#?' if present 235 | match = re.search(r"/wiki/([^#?]+)", url) 236 | if match: 237 | # Replace underscores with spaces and decode percent encoding 238 | return urllib.parse.unquote(re.sub(r"_", " ", match.group(1))) 239 | return None 240 | 241 | 242 | def concat_summaries_from_fields(summary_texts, titles, publish_dates): 243 | """ 244 | Concatenate summaries from a list of summary texts. Fill in the titles and 245 | publish dates for each summary. 246 | 247 | The length of the summary_texts, titles, and publish_dates lists should be 248 | the same. 249 | 250 | Args: 251 | summary_texts (list of str): A list of summary texts. 252 | titles (list of str): A list of titles. 
253 | publish_dates (list str): A list of publish dates. 254 | 255 | Returns: 256 | str: The concatenated summaries with titles and publish dates. 257 | """ 258 | logger.info( 259 | f"Concatenating summaries from {len(summary_texts)} articles ({len(titles)} titles and {len(publish_dates)} dates)." 260 | ) 261 | if (len(summary_texts) != len(titles)) or ( 262 | len(summary_texts) != len(publish_dates) 263 | ): 264 | logger.error( 265 | f"Lengths of summary_texts, titles, and publish_dates should be the same. Got {len(summary_texts)}, {len(titles)}, and {len(publish_dates)}." 266 | ) 267 | return "Not available." 268 | article_summaries = [ 269 | f"[{i+1}] {titles[i]} (published on {(publish_dates[i] if publish_dates[i] else 'unknown date')})\nSummary: {article_summary}\n" 270 | for i, article_summary in enumerate(summary_texts) 271 | ] 272 | concatenated_summaries_str = ( 273 | "---\nARTICLES\n" + "\n".join(article_summaries) + "----" 274 | ) 275 | return concatenated_summaries_str -------------------------------------------------------------------------------- /llm_forecasting/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import concurrent.futures 3 | import logging 4 | import re 5 | 6 | # Local application/library-specific imports 7 | from config.constants import S3, S3_BUCKET_NAME 8 | import model_eval 9 | from prompts.prompts import PROMPT_DICT 10 | from utils import db_utils, time_utils, string_utils 11 | 12 | # Set up logging 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | def get_formatted_data( 18 | s3_path, 19 | retrieval_index=0, 20 | num_retrievals=5, 21 | questions_after="2015", 22 | return_raw_question_data=False, 23 | data=None, 24 | ): 25 | """ 26 | Retrieve and process training data from an S3 path. 27 | 28 | This function reads data from S3, processes it, and structures it for training purposes. 29 | It calculates retrieval dates and filters out data based on these dates. The function can 30 | optionally return raw question data. 31 | 32 | Also, the function can optionally take in the |data| directly. 33 | 34 | Parameters: 35 | s3_path (str): Path to the data file in S3. 36 | retrieval_index (int, optional): Index for calculating the retrieval date. Defaults to 0. 37 | num_retrievals (int, optional): Total number of retrievals to consider. Defaults to 5. 38 | return_raw_question_data (bool, optional): Flag to return raw question data. Defaults to False. 39 | data (list): List of forecasting questions. 40 | 41 | Returns: 42 | dict or tuple: A dictionary containing structured training data, or a tuple with the dictionary 43 | and raw data if return_raw_question_data is True. 44 | 45 | Note: 46 | This function expects specific keys in the data (e.g., 'date_close', 'date_resolve_at', etc.), 47 | and logs an error if reading from S3 fails. 
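    Example usage (a sketch; the S3 path is a placeholder):
        >> question_dict = get_formatted_data(
        >>     "path/to/questions.pickle",
        >>     retrieval_index=0,
        >>     num_retrievals=5,
        >> )
        >> single_question = format_single_question(question_dict, index=0)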
48 | """ 49 | if not data: 50 | try: 51 | data = db_utils.read_pickle_from_s3(S3, S3_BUCKET_NAME, s3_path) 52 | except Exception as e: 53 | logger.error(f"Error reading data from S3: {e}") 54 | return {} 55 | 56 | question_dict = { 57 | "question_list": [], 58 | "background_list": [], 59 | "resolution_criteria_list": [], 60 | "question_dates_list": [], 61 | "resolve_dates_list": [], 62 | "retrieval_dates_list": [], 63 | "answer_list": [], 64 | "data_source_list": [], 65 | "community_pred_at_retrieval_list": [], 66 | "urls_in_background_list": [], 67 | "category_list": [], 68 | } 69 | raw_data = [] 70 | for q in data: 71 | q["date_close"] = q["date_close"] or q["date_resolve_at"] 72 | retrieval_date = time_utils.get_retrieval_date( 73 | retrieval_index, 74 | num_retrievals, 75 | q["date_begin"], 76 | q["date_close"], 77 | q["date_resolve_at"], 78 | ) 79 | 80 | if retrieval_date is None: 81 | continue 82 | elif retrieval_date == q["date_resolve_at"]: 83 | continue 84 | elif not time_utils.is_more_recent( 85 | f"{questions_after}-01-01", q["date_begin"], or_equal_to=True 86 | ): 87 | continue 88 | 89 | raw_data.append(q) 90 | for key, value in { 91 | "question_list": q["question"], 92 | "background_list": q["background"], 93 | "resolution_criteria_list": q["resolution_criteria"], 94 | "question_dates_list": ( 95 | time_utils.extract_date(q["date_begin"]), 96 | time_utils.extract_date(q["date_close"]), 97 | ), 98 | "resolve_dates_list": q["date_resolve_at"], 99 | "retrieval_dates_list": ( 100 | time_utils.extract_date(q["date_begin"]), 101 | retrieval_date, 102 | ), 103 | "answer_list": int(q["resolution"]), 104 | "data_source_list": q["data_source"], 105 | "community_pred_at_retrieval_list": time_utils.find_pred_with_closest_date( 106 | retrieval_date, q["community_predictions"] 107 | )[1], 108 | "urls_in_background_list": q["urls_in_background"], 109 | "category_list": q["gpt_3p5_category"], 110 | }.items(): 111 | question_dict[key].append(value) 112 | 113 | return (question_dict, raw_data) if return_raw_question_data else question_dict 114 | 115 | 116 | def format_single_question(data_dict, index): 117 | """ 118 | Format a single question, located by |index|, from a dictionary of all 119 | questions. 120 | 121 | Args: 122 | data_dict (dict): Dictionary containing all question data. 123 | index (int): Index of question. 124 | 125 | Returns: 126 | dict: Formatted question 127 | """ 128 | return { 129 | "question": data_dict["question_list"][index], 130 | "background": data_dict["background_list"][index], 131 | "resolution_criteria": data_dict["resolution_criteria_list"][index], 132 | "answer": data_dict["answer_list"][index], 133 | "question_dates": data_dict["question_dates_list"][index], 134 | "retrieval_dates": data_dict["retrieval_dates_list"][index], 135 | "data_source": data_dict["data_source_list"][index], 136 | "resolve_date": data_dict["resolve_dates_list"][index], 137 | "community_pred_at_retrieval": data_dict["community_pred_at_retrieval_list"][ 138 | index 139 | ], 140 | "urls_in_background": data_dict["urls_in_background_list"][index], 141 | "category": data_dict["category_list"][index], 142 | } 143 | 144 | 145 | def is_question_ill_defined(question, model_name): 146 | """ 147 | Determine if a given question is ill-defined using a specified model. 148 | Returns True if ill-defined, False if not, and None if the determination cannot be made. 
149 | """ 150 | prompt = string_utils.get_prompt( 151 | PROMPT_DICT["data_wrangling"]["is_bad_title"][0], 152 | PROMPT_DICT["data_wrangling"]["is_bad_title"][1], 153 | question=question, 154 | ) 155 | response = model_eval.get_response_from_model( 156 | model_name, prompt, max_tokens=500, temperature=0.1 157 | ) 158 | 159 | if "Classification:" not in response: 160 | logger.error( 161 | f"'Classification:' is not in the response for question: {question}" 162 | ) 163 | return None 164 | 165 | end_resp = response.split("Classification:")[1] 166 | if "ok" in end_resp: 167 | return False 168 | elif "flag" in end_resp: 169 | logger.info(f"The following question is ill-defined: {question}") 170 | return True 171 | 172 | logger.error(f"Ambiguous response for question: {question}") 173 | return True 174 | 175 | 176 | def assign_ill_defined_questions(data_list, model_name="gpt-3.5-turbo-1106"): 177 | """ 178 | Evaluate each question in data_list to determine if it's ill-defined using the specified model. 179 | Modifies data_list in place by adding a key 'ill-defined' with a Boolean value. 180 | """ 181 | number_of_workers = 50 182 | 183 | with concurrent.futures.ThreadPoolExecutor( 184 | max_workers=number_of_workers 185 | ) as executor: 186 | future_to_question = { 187 | executor.submit(is_question_ill_defined, item["question"], model_name): item 188 | for item in data_list 189 | if "is_ill_defined" not in item 190 | } 191 | 192 | for future in concurrent.futures.as_completed(future_to_question): 193 | question_item = future_to_question[future] 194 | try: 195 | result = future.result() 196 | if result is not None: 197 | question_item["is_ill_defined"] = result 198 | else: 199 | logger.warning( 200 | f"Could not determine if question is ill-defined: {question_item['question']}" 201 | ) 202 | except Exception as exc: 203 | logger.error( 204 | f"Error processing question {question_item['question']}: {exc}" 205 | ) 206 | return None 207 | 208 | 209 | def assign_category(question, background, model_name): 210 | try: 211 | prompt = string_utils.get_prompt( 212 | PROMPT_DICT["data_wrangling"]["assign_category"][0], 213 | PROMPT_DICT["data_wrangling"]["assign_category"][1], 214 | question=question, 215 | background=background, 216 | ) 217 | response = model_eval.get_response_from_model( 218 | model_name, prompt, max_tokens=500, temperature=0.1 219 | ) 220 | return response.strip('"').strip("'").strip(" ").strip(".") 221 | except Exception as e: 222 | logger.error(f"Error in assign_category: {e}") 223 | return None 224 | 225 | 226 | def assign_categories(data_list, model_name="gpt-3.5-turbo-1106"): 227 | number_of_workers = 100 228 | updated_items = [] 229 | 230 | with concurrent.futures.ThreadPoolExecutor( 231 | max_workers=number_of_workers 232 | ) as executor: 233 | future_to_question = { 234 | executor.submit( 235 | assign_category, item["question"], item["background"], model_name 236 | ): item 237 | for item in data_list 238 | if "gpt_3p5_category" not in item 239 | } 240 | 241 | for future in concurrent.futures.as_completed(future_to_question): 242 | question_item = future_to_question[future] 243 | try: 244 | result = future.result() 245 | if result is not None: 246 | question_item["gpt_3p5_category"] = result 247 | else: 248 | logger.warning( 249 | f"Could not assign category: {question_item['question']}" 250 | ) 251 | updated_items.append(question_item) 252 | except Exception as exc: 253 | logger.error( 254 | f"Error processing question {question_item['question']}: {exc}" 255 | ) 256 | 257 | return 
None
258 | 
259 | 
260 | def reformat_metaculus_questions(
261 |     data,
262 |     model_name="gpt-3.5-turbo-1106",
263 |     prompt=PROMPT_DICT["data_wrangling"]["reformat"],
264 | ):
265 |     """
266 |     Reformat questions from Metaculus to be more readable.
267 | 
268 |     In particular, some questions have a title that ends with a parenthesis,
269 |     containing the actual subject.
270 |     This function rephrases it to be a Yes/No question.
271 | 
272 |     For example,
273 |     >>> "Who will win the 2020 US presidential election? (Biden)"
274 |     will be reformatted by the language model to
275 |     >>> "Will Biden win the 2020 US presidential election?"
276 | 
277 |     Args:
278 |         data (list of dict): List of questions in dictionary format.
279 |         model_name (str, optional): Language model name, default is
280 |             "gpt-3.5-turbo-1106".
281 |         prompt (tuple of str, optional): Prompt to use for model evaluation.
282 |             Default is PROMPT_DICT["data_wrangling"]["reformat"].
283 | 
284 |     Returns:
285 |         Modifies the input data in-place, and returns None.
286 |     """
287 | 
288 |     def find_text_between_stars(text):
289 |         match = re.search(r"\*([^*]+)\*", text)
290 |         return match.group(1) if match else None
291 | 
292 |     for d in data:
293 |         if "? (" in d["title"]:
294 |             full_prompt = string_utils.get_prompt(
295 |                 prompt[0],
296 |                 prompt[1],
297 |                 question=d["title"],
298 |             )
299 |             response = model_eval.get_response_from_model(
300 |                 model_name=model_name, prompt=full_prompt
301 |             )
302 |             transformed_title = find_text_between_stars(response)
303 |             if transformed_title:
304 |                 d["title"] = transformed_title
305 | 
306 |     return None
307 | 
--------------------------------------------------------------------------------
/scripts/data_scraping/manifold.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import certifi
3 | import concurrent.futures
4 | import logging
5 | import requests
6 | from datetime import datetime, timedelta
7 | from tqdm import tqdm
8 | 
9 | import data_scraping
10 | import information_retrieval
11 | from config.keys import keys
12 | from utils import time_utils
13 | 
14 | # Logger configuration
15 | logger = logging.getLogger(__name__)
16 | 
17 | # Manifold API configuration
18 | MANIFOLD_API = keys["MANIFOLD_KEY"]
19 | BASE_URL = "https://api.manifold.markets/v0/markets?"
20 | 
21 | # Headers for API requests
22 | headers = {"Authorization": f"Key {MANIFOLD_API}"}
23 | 
24 | 
25 | def fetch_all_manifold_questions(base_url, headers, limit=1000):
26 |     """
27 |     Fetch all questions from the Manifold API.
28 | 
29 |     Retrieve a list of questions from the Manifold API, paginating through the results
30 |     based on the specified limit until all questions are fetched.
31 | 
32 |     Args:
33 |         base_url (str): The base URL for the Manifold API.
34 |         headers (dict): The headers for the API request, including authorization.
35 |         limit (int): The maximum number of questions to fetch in each request.
36 | 
37 |     Returns:
38 |         list: A list of all questions fetched from the API.
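    Example usage (illustrative; BASE_URL and headers are the module-level
    values defined above):
        >> questions = fetch_all_manifold_questions(BASE_URL, headers, limit=1000)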
39 | """ 40 | all_questions = [] 41 | last_id = None 42 | 43 | while True: 44 | params = {"limit": limit, "sort": "created-time", "order": "desc"} 45 | if last_id: 46 | params["before"] = last_id 47 | 48 | response = requests.get( 49 | base_url, headers=headers, params=params, verify=certifi.where() 50 | ) 51 | if response.status_code != 200: 52 | print(f"Error: {response.status_code}") 53 | break 54 | 55 | data = response.json() 56 | if not data: 57 | break 58 | 59 | all_questions.extend(data) 60 | last_id = data[-1]["id"] 61 | 62 | return all_questions 63 | 64 | 65 | def fetch_market_details(market_id, headers): 66 | """ 67 | Fetch detailed information for a specific market from the Manifold API. 68 | 69 | Retrieve detailed information about a specific market, identified by its market ID, 70 | from the Manifold API. 71 | 72 | Args: 73 | market_id (str): The unique identifier of the market. 74 | headers (dict): Headers for the API request, including authorization. 75 | 76 | Returns: 77 | dict or None: The detailed market information or None if an error occurs. 78 | """ 79 | try: 80 | url = f"https://api.manifold.markets/v0/market/{market_id}" 81 | response = requests.get(url, headers=headers, verify=certifi.where()) 82 | response.raise_for_status() # Raises an HTTPError for non-200 responses 83 | return response.json() 84 | except requests.exceptions.HTTPError as e: 85 | print(f"HTTP error fetching market details for ID {market_id}: {e}") 86 | except requests.exceptions.RequestException as e: 87 | print(f"Error fetching market details for ID {market_id}: {e}") 88 | return None 89 | 90 | 91 | def fetch_bets_for_market(market_id, headers): 92 | """ 93 | Fetch a list of bets for a specific market from the Manifold API. 94 | 95 | Retrieve all bets placed in a specific market, identified by its market ID, from 96 | the Manifold API. Handle any HTTP errors encountered during the request and return 97 | an empty list if an error occurs. 98 | 99 | Args: 100 | market_id (str): The unique identifier of the market. 101 | headers (dict): Headers for the API request, including authorization. 102 | 103 | Returns: 104 | list: A list of bets for the specified market or an empty list if an error occurs. 105 | """ 106 | try: 107 | url = "https://api.manifold.markets/v0/bets" 108 | params = {"contractId": market_id, "limit": 1000} 109 | response = requests.get( 110 | url, headers=headers, params=params, verify=certifi.where() 111 | ) 112 | response.raise_for_status() 113 | return response.json() 114 | except requests.exceptions.HTTPError as e: 115 | print(f"HTTP error fetching bets for market ID {market_id}: {e}") 116 | except requests.exceptions.RequestException as e: 117 | print(f"Error fetching bets for market ID {market_id}: {e}") 118 | return [] 119 | 120 | 121 | def fetch_comments_for_market(market_id, headers): 122 | """ 123 | Fetch a list of comments for a specific market from the Manifold API. 124 | 125 | Retrieve all comments made in a specific market, identified by its market ID, from 126 | the Manifold API. Handle any HTTP errors encountered during the request and return 127 | an empty list if an error occurs. 128 | 129 | Args: 130 | market_id (str): The unique identifier of the market. 131 | headers (dict): Headers for the API request, including authorization. 132 | 133 | Returns: 134 | list: A list of comments for the specified market or an empty list if an error occurs. 
135 | """ 136 | try: 137 | url = "https://api.manifold.markets/v0/comments" 138 | params = {"contractId": market_id, "limit": 1000} 139 | response = requests.get( 140 | url, headers=headers, params=params, verify=certifi.where() 141 | ) 142 | response.raise_for_status() 143 | return response.json() 144 | except requests.exceptions.HTTPError as e: 145 | print(f"HTTP error fetching comments for market ID {market_id}: {e}") 146 | except requests.exceptions.RequestException as e: 147 | print(f"Error fetching comments for market ID {market_id}: {e}") 148 | return [] 149 | 150 | 151 | def process_market(market, headers): 152 | """ 153 | Process a single market by fetching its details, bets, comments, and transforming the data. 154 | 155 | Fetch and add market descriptions, and map bets and comments for a given market. 156 | Convert timestamps and restructure market data for consistency and clarity. 157 | 158 | Args: 159 | market (dict): The market data to process. 160 | headers (dict): Headers for the API requests. 161 | 162 | Returns: 163 | dict: The processed market data. 164 | """ 165 | market_id = market["id"] 166 | 167 | # Fetch and add market descriptions 168 | try: 169 | market_details = fetch_market_details(market_id, headers) 170 | if market_details: 171 | market["background"] = market_details.get("description") 172 | 173 | background_content = market.get("background", {}).get("content", []) 174 | background_text = " ".join( 175 | [ 176 | item["content"][0].get("text", "") 177 | for item in background_content 178 | if item.get("type") == "paragraph" and item.get("content") 179 | ] 180 | ) 181 | market["background"] = background_text 182 | except Exception as exc: 183 | logger.error( 184 | f"Market id {market_id} got processing error when fetching description: {exc}" 185 | ) 186 | try: 187 | market["resolution"] = 1 if market["resolution"] == "YES" else 0 188 | except Exception as exc: 189 | logger.error( 190 | f"Market id {market_id} got processing error when reformatting resolution: {exc}" 191 | ) 192 | 193 | # Fetch and map bets and comments 194 | market["community_predictions"] = fetch_bets_for_market(market_id, headers) 195 | market["comments"] = fetch_comments_for_market(market_id, headers) 196 | 197 | for comment in market.get("comments", []): 198 | if "createdTime" in comment: 199 | comment["createdTime"] = time_utils.convert_timestamp( 200 | comment["createdTime"] 201 | ) 202 | 203 | for bet in market.get("community_predictions", []): 204 | if "createdTime" in bet: 205 | bet["createdTime"] = time_utils.convert_timestamp(bet["createdTime"]) 206 | 207 | try: 208 | market["date_close"] = market.pop("closeTime") 209 | except Exception as exc: 210 | logger.error( 211 | f"Market id {market_id} got processing error when popping closeTime: {exc}" 212 | ) 213 | try: 214 | market["date_begin"] = market.pop("createdTime") 215 | except Exception as exc: 216 | logger.error( 217 | f"Market id {market_id} got processing error when popping createdTime: {exc}" 218 | ) 219 | try: 220 | market["is_resolved"] = market.pop("isResolved") 221 | except Exception as exc: 222 | logger.error( 223 | f"Market id {market_id} got processing error when popping isResolved: {exc}" 224 | ) 225 | try: 226 | market["question_type"] = market.pop("outcomeType") 227 | except Exception as exc: 228 | logger.error( 229 | f"Market id {market_id} got processing error when popping outcomeType: {exc}" 230 | ) 231 | 232 | try: 233 | market["resolved_time"] = market.pop("resolutionTime") 234 | except Exception as exc: 235 | 
        logger.error(
236 |             f"Market id {market_id} got processing error when reformatting resolution time: {exc}"
237 |         )
238 | 
239 |     if not market.get("background"):
240 |         market["background"] = "Not applicable/available for this question."
241 | 
242 |     market["resolution_criteria"] = "Not applicable/available for this question."
243 |     market["data_source"] = "manifold"
244 | 
245 |     return market
246 | 
247 | 
248 | def main(n_days):
249 |     """
250 |     Process and upload Manifold market data.
251 | 
252 |     Fetch market data from Manifold, process it, and upload the processed data
253 |     to an AWS S3 bucket.
254 |     Limit the fetched data to markets created within the last N days, if specified.
255 | 
256 |     Args:
257 |         n_days (int or None): Number of days to look back for markets.
258 |             If None, fetches all markets.
259 | 
260 |     Returns:
261 |         list: A list of processed market data.
262 |     """
263 |     logger.info("Starting the manifold script...")
264 | 
265 |     all_markets = fetch_all_manifold_questions(BASE_URL, headers)
266 | 
267 |     # Transform time format
268 |     ids_with_date_errors = []
269 |     for key in [
270 |         "createdTime",
271 |         "closeTime",
272 |         "resolutionTime",
273 |         "lastUpdatedTime",
274 |         "lastCommentTime",
275 |     ]:
276 |         for market in all_markets:
277 |             if key in market:
278 |                 try:
279 |                     market[key] = time_utils.convert_timestamp(market[key])
280 |                 except Exception:
281 |                     ids_with_date_errors.append(market["id"])
282 |                     continue
283 | 
284 |     logger.info(
285 |         f"Number of manifold questions with date errors: {len(set(ids_with_date_errors))}"
286 |     )
287 | 
288 |     if n_days is not None:
289 |         date_limit = datetime.now() - timedelta(days=n_days)
290 |         all_markets = [
291 |             q
292 |             for q in all_markets
293 |             if datetime.fromisoformat(q["createdTime"]) >= date_limit
294 |         ]
295 | 
296 |     logger.info(f"Number of manifold questions fetched: {len(all_markets)}")
297 | 
298 |     processed_markets = []
299 |     number_of_workers = 50
300 | 
301 |     with concurrent.futures.ThreadPoolExecutor(
302 |         max_workers=number_of_workers
303 |     ) as executor:
304 |         # Create a list to hold the futures
305 |         futures = [
306 |             executor.submit(process_market, market, headers) for market in all_markets
307 |         ]
308 | 
309 |         # Use tqdm to create a progress bar. Wrap futures with tqdm.
310 |         for future in tqdm(
311 |             concurrent.futures.as_completed(futures),
312 |             total=len(all_markets),
313 |             desc="Processing",
314 |         ):
315 |             result = future.result()
316 |             processed_markets.append(result)
317 | 
318 |     processed_markets = [
319 |         q
320 |         for q in processed_markets
321 |         if "question_type" in q and q["community_predictions"]
322 |     ]
323 | 
324 |     # Save intermediate data files
325 |     logger.info("Uploading the intermediate data files to s3...")
326 |     question_types = list(set([q["question_type"] for q in processed_markets]))
327 |     data_scraping.upload_scraped_data(processed_markets, "manifold", question_types)
328 | 
329 |     for q in processed_markets:
330 |         q["date_resolve_at"] = q["community_predictions"][0]["createdTime"]
331 | 
332 |     logger.info("Start extracting articles from links...")
333 | 
334 |     for question in processed_markets:
335 |         question["extracted_urls"] = information_retrieval.get_urls_from_text(
336 |             question["background"]
337 |         )
338 |         question["extracted_articles_urls"] = []
339 |         for comment in question.get("comments", []):
340 |             try:
341 |                 # Pull the plain text out of Manifold's rich-text comment JSON.
342 |                 comment_text = comment["content"]["content"][0]["content"][-1]["text"]
343 |                 question["extracted_articles_urls"].extend(
344 |                     information_retrieval.retrieve_webpage_from_background(
345 |                         comment_text, question["date_close"]
346 |                     )
347 |                 )
348 |             except Exception:
349 |                 continue
350 | 
351 |     logger.info("Uploading to s3...")
352 |     question_types = list(set([q["question_type"] for q in processed_markets]))
353 |     data_scraping.upload_scraped_data(processed_markets, "manifold", question_types)
354 | 
355 |     return processed_markets
356 | 
357 | 
358 | if __name__ == "__main__":
359 |     parser = argparse.ArgumentParser(description="Fetch Manifold data.")
360 |     parser.add_argument(
361 |         "--n_days",
362 |         type=int,
363 |         help="Fetch questions created in the last N days",
364 |         default=None,
365 |     )
366 |     args = parser.parse_args()
367 |     main(args.n_days)
368 | 
--------------------------------------------------------------------------------
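
The `__main__` block above wires the scraper to a single `--n_days` flag. A minimal way to exercise it, sketched under the assumption that the script is importable from its location in the repository and that the credentials in config/keys.py have been filled in (an illustrative sketch, not a supported entry point):

    # Illustrative sketch -- assumes the module can be imported as below and that
    # AWS/Manifold credentials are configured in config/keys.py.
    # CLI equivalent: python scripts/data_scraping/manifold.py --n_days 7
    from scripts.data_scraping import manifold

    # Scrape markets created in the last 7 days, enrich each one with details,
    # bets, and comments, and upload the processed questions to S3.
    recent_markets = manifold.main(n_days=7)
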
/llm_forecasting/model_eval.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import asyncio
3 | import logging
4 | import time
5 | 
6 | # Related third-party imports
7 | import openai
8 | import together
9 | import anthropic
10 | import google.generativeai as google_ai
11 | 
12 | # Local application/library-specific imports
13 | from config.constants import (
14 |     OAI_SOURCE,
15 |     ANTHROPIC_SOURCE,
16 |     TOGETHER_AI_SOURCE,
17 |     GOOGLE_SOURCE,
18 | )
19 | from config.keys import (
20 |     ANTHROPIC_KEY,
21 |     OPENAI_KEY,
22 |     TOGETHER_KEY,
23 |     GOOGLE_AI_KEY,
24 | )
25 | from utils import model_utils, string_utils
26 | 
27 | # Setup code
28 | if ANTHROPIC_KEY:
29 |     anthropic_console = anthropic.Anthropic(api_key=ANTHROPIC_KEY)
30 |     anthropic_async_client = anthropic.AsyncAnthropic(api_key=ANTHROPIC_KEY)
31 | 
32 | if OPENAI_KEY:
33 |     oai_async_client = openai.AsyncOpenAI(api_key=OPENAI_KEY)
34 |     oai = openai.OpenAI(api_key=OPENAI_KEY)
35 | 
36 | if TOGETHER_KEY:
37 |     together.api_key = TOGETHER_KEY
38 |     client = openai.OpenAI(
39 |         api_key=TOGETHER_KEY,
40 |         base_url="https://api.together.xyz/v1",
41 |     )
42 | 
43 | if GOOGLE_AI_KEY:
44 |     google_ai.configure(api_key=GOOGLE_AI_KEY)
45 | 
46 | # Set up logging
47 | logging.basicConfig(level=logging.INFO)
48 | logger = logging.getLogger(__name__)
49 | 
50 | 
51 | def get_response_with_retry(api_call, wait_time, error_msg):
52 |     """
53 |     Make an API call and retry on failure after a specified wait time.
54 | 
55 |     Args:
56 |         api_call (function): API call to make.
57 |         wait_time (int): Time to wait before retrying, in seconds.
58 |         error_msg (str): Error message to print on failure.
59 |     """
60 |     while True:
61 |         try:
62 |             return api_call()
63 |         except Exception as e:
64 |             logger.info(f"{error_msg}: {e}")
65 |             logger.info(f"Waiting for {wait_time} seconds before retrying...")
66 |             time.sleep(wait_time)
67 | 
68 | 
69 | def get_response_from_oai_model(
70 |     model_name, prompt, system_prompt, max_tokens, temperature, wait_time
71 | ):
72 |     """
73 |     Make an API call to the OpenAI API and retry on failure after a specified
74 |     wait time.
75 | 
76 |     Args:
77 |         model_name (str): Name of the model to use (such as "gpt-4").
78 |         prompt (str): Fully specified prompt to use for the API call.
79 |         system_prompt (str): Prompt to use for system prompt.
80 |         max_tokens (int): Maximum number of tokens to sample.
81 |         temperature (float): Sampling temperature.
82 |         wait_time (int): Time to wait before retrying, in seconds.
83 | 
84 |     Returns:
85 |         str: Response string from the API call.
86 |     """
87 | 
88 |     def api_call():
89 |         """
90 |         Make an API call to the OpenAI API, without retrying on failure.
91 | 
92 |         Returns:
93 |             str: Response string from the API call.
94 |         """
95 |         model_input = (
96 |             [{"role": "system", "content": system_prompt}] if system_prompt else []
97 |         )
98 |         model_input.append({"role": "user", "content": prompt})
99 |         response = oai.chat.completions.create(
100 |             model=model_name,
101 |             messages=model_input,
102 |             max_tokens=max_tokens,
103 |             temperature=temperature,
104 |         )
105 |         # logger.info(f"full prompt: {prompt}")
106 |         return response.choices[0].message.content
107 | 
108 |     return get_response_with_retry(
109 |         api_call, wait_time, "OpenAI API request exceeded rate limit."
110 |     )
111 | 
112 | 
113 | def get_response_from_anthropic_model(
114 |     model_name, prompt, max_tokens, temperature, wait_time
115 | ):
116 |     """
117 |     Make an API call to the Anthropic API and retry on failure after a
118 |     specified wait time.
119 | 
120 |     Args:
121 |         model_name (str): Name of the model to use (such as "claude-2").
122 |         prompt (str): Fully specified prompt to use for the API call.
123 |         max_tokens (int): Maximum number of tokens to sample.
124 |         temperature (float): Sampling temperature.
125 |         wait_time (int): Time to wait before retrying, in seconds.
126 | 
127 |     Returns:
128 |         str: Response string from the API call.
129 |     """
130 |     if max_tokens > 4096:
131 |         max_tokens = 4096
132 | 
133 |     def api_call():
134 |         completion = anthropic_console.messages.create(
135 |             model=model_name,
136 |             messages=[{"role": "user", "content": prompt}],
137 |             temperature=temperature,
138 |             max_tokens=max_tokens,
139 |         )
140 |         return completion.content[0].text
141 | 
142 |     return get_response_with_retry(
143 |         api_call, wait_time, "Anthropic API request exceeded rate limit."
144 |     )
145 | 
146 | 
147 | def get_response_from_together_ai_model(
148 |     model_name, prompt, max_tokens, temperature, wait_time
149 | ):
150 |     """
151 |     Make an API call to the Together AI API and retry on failure after a
152 |     specified wait time.
153 | 
154 |     Args:
155 |         model_name (str): Name of the model to use (such as "togethercomputer/
156 |             llama-2-13b-chat").
157 |         prompt (str): Fully specified prompt to use for the API call.
158 |         max_tokens (int): Maximum number of tokens to sample.
159 |         temperature (float): Sampling temperature.
160 |         wait_time (int): Time to wait before retrying, in seconds.
161 | 
162 |     Returns:
163 |         str: Response string from the API call.
164 | """ 165 | 166 | def api_call(): 167 | chat_completion = client.chat.completions.create( 168 | model=model_name, 169 | messages=[ 170 | {"role": "user", "content": prompt}, 171 | ], 172 | temperature=temperature, 173 | max_tokens=max_tokens, 174 | ) 175 | response = chat_completion.choices[0].message.content 176 | 177 | return response 178 | 179 | return get_response_with_retry( 180 | api_call, wait_time, "Together AI API request exceeded rate limit." 181 | ) 182 | 183 | 184 | def get_response_from_google_model( 185 | model_name, prompt, max_tokens, temperature, wait_time 186 | ): 187 | """ 188 | Make an API call to the Together AI API and retry on failure after a specified wait time. 189 | 190 | Args: 191 | model (str): Name of the model to use (such as "gemini-pro"). 192 | prompt (str): Initial prompt for the API call. 193 | max_tokens (int): Maximum number of tokens to sample. 194 | temperature (float): Sampling temperature. 195 | wait_time (int): Time to wait before retrying, in seconds. 196 | 197 | Returns: 198 | str: Response string from the API call. 199 | """ 200 | model = google_ai.GenerativeModel(model_name) 201 | 202 | response = model.generate_content( 203 | prompt, 204 | generation_config=google_ai.types.GenerationConfig( 205 | candidate_count=1, 206 | max_output_tokens=max_tokens, 207 | temperature=temperature, 208 | ), 209 | ) 210 | return response.text 211 | 212 | 213 | def get_response_from_model( 214 | model_name, 215 | prompt, 216 | system_prompt="", 217 | max_tokens=2000, 218 | temperature=0.8, 219 | wait_time=30, 220 | ): 221 | """ 222 | Make an API call to the specified model and retry on failure after a 223 | specified wait time. 224 | 225 | Args: 226 | model_name (str): Name of the model to use (such as "gpt-4"). 227 | prompt (str): Fully specififed prompt to use for the API call. 228 | system_prompt (str, optional): Prompt to use for system prompt. 229 | max_tokens (int, optional): Maximum number of tokens to generate. 230 | temperature (float, optional): Sampling temperature. 231 | wait_time (int, optional): Time to wait before retrying, in seconds. 232 | """ 233 | model_source = model_utils.infer_model_source(model_name) 234 | if model_source == OAI_SOURCE: 235 | return get_response_from_oai_model( 236 | model_name, prompt, system_prompt, max_tokens, temperature, wait_time 237 | ) 238 | elif model_source == ANTHROPIC_SOURCE: 239 | return get_response_from_anthropic_model( 240 | model_name, prompt, max_tokens, temperature, wait_time 241 | ) 242 | elif model_source == TOGETHER_AI_SOURCE: 243 | return get_response_from_together_ai_model( 244 | model_name, prompt, max_tokens, temperature, wait_time 245 | ) 246 | elif model_source == GOOGLE_SOURCE: 247 | return get_response_from_google_model( 248 | model_name, prompt, max_tokens, temperature, wait_time 249 | ) 250 | else: 251 | return "Not a valid model source." 252 | 253 | 254 | async def get_async_response( 255 | prompt, 256 | model_name="gpt-3.5-turbo-1106", 257 | temperature=0.0, 258 | max_tokens=8000, 259 | ): 260 | """ 261 | Asynchronously get a response from the OpenAI API. 262 | 263 | Args: 264 | prompt (str): Fully specififed prompt to use for the API call. 265 | model_name (str, optional): Name of the model to use (such as "gpt-3.5-turbo"). 266 | temperature (float, optional): Sampling temperature. 267 | max_tokens (int, optional): Maximum number of tokens to sample. 268 | 269 | Returns: 270 | str: Response string from the API call (not the dictionary). 
271 | """ 272 | model_source = model_utils.infer_model_source(model_name) 273 | while True: 274 | try: 275 | if model_source == OAI_SOURCE: 276 | response = await oai_async_client.chat.completions.create( 277 | model=model_name, 278 | messages=[{"role": "user", "content": prompt}], 279 | temperature=temperature, 280 | ) 281 | return response.choices[0].message.content 282 | elif model_source == ANTHROPIC_SOURCE: 283 | response = await anthropic_async_client.messages.create( 284 | model=model_name, 285 | messages=[{"role": "user", "content": prompt}], 286 | temperature=temperature, 287 | max_tokens=4096, 288 | ) 289 | return response.content[0].text 290 | elif model_source == GOOGLE_SOURCE: 291 | model = google_ai.GenerativeModel(model_name) 292 | response = await model.generate_content_async( 293 | prompt, 294 | generation_config=google_ai.types.GenerationConfig( 295 | candidate_count=1, 296 | max_output_tokens=max_tokens, 297 | temperature=temperature, 298 | ), 299 | ) 300 | return response.text 301 | elif model_source == TOGETHER_AI_SOURCE: 302 | chat_completion = await asyncio.to_thread( 303 | client.chat.completions.create, 304 | model=model_name, 305 | messages=[ 306 | {"role": "user", "content": prompt}, 307 | ], 308 | temperature=temperature, 309 | max_tokens=max_tokens, 310 | ) 311 | return chat_completion.choices[0].message.content 312 | else: 313 | logger.debug("Not a valid model source: {model_source}") 314 | return "" 315 | except (Exception, BaseException) as e: 316 | logger.info(f"Exception, erorr message: {e}") 317 | logger.info("Waiting for 30 seconds before retrying...") 318 | time.sleep(30) 319 | continue 320 | 321 | 322 | def get_openai_embedding(texts, model="text-embedding-3-large"): 323 | """ 324 | Query OpenAI's text embedding model to get the embedding of the given text. 325 | 326 | Args: 327 | texts (list of str): List of texts to embed. 328 | 329 | Returns: 330 | list of Embedding objects: List of embeddings, where embedding[i].embedding is a list of floats. 331 | """ 332 | texts = [text.replace("\n", " ") for text in texts] 333 | while True: 334 | try: 335 | embedding = oai.embeddings.create(input=texts, model=model) 336 | return embedding.data 337 | except Exception as e: 338 | logger.info(f"erorr message: {e}") 339 | logger.info("Waiting for 30 seconds before retrying...") 340 | time.sleep(30) 341 | continue 342 | 343 | 344 | async def async_make_forecast( 345 | question, 346 | background_info, 347 | resolution_criteria, 348 | dates, 349 | retrieved_info, 350 | reasoning_prompt_templates, 351 | model_name="gpt-4-1106-preview", 352 | temperature=1.0, 353 | return_prompt=False, 354 | ): 355 | """ 356 | Asynchronously make forecasts using the given information. 357 | 358 | Args: 359 | question (str): Question to ask the model. 360 | background_info (str): Background information to provide to the model. 361 | resolution_criteria (str): Resolution criteria to provide to the model. 362 | dates (str): Dates to provide to the model. 363 | retrieved_info (str): Retrieved information to provide to the model. 364 | reasoning_prompt_templates (list of str): List of reasoning prompt templates to use. 365 | model_name (str, optional): Name of the model to use (such as "gpt-4-1106-preview"). 366 | temperature (float, optional): Sampling temperature. 367 | return_prompt (bool, optional): Whether to return the full prompt or not. 368 | 369 | Returns: 370 | list of str: List of forecasts and reasonings from the model. 
371 | """ 372 | assert ( 373 | len(reasoning_prompt_templates) > 0 374 | ), "No reasoning prompt templates provided." 375 | reasoning_full_prompts = [] 376 | for reasoning_prompt_template in reasoning_prompt_templates: 377 | template, fields = reasoning_prompt_template 378 | reasoning_full_prompts.append( 379 | string_utils.get_prompt( 380 | template, 381 | fields, 382 | question=question, 383 | retrieved_info=retrieved_info, 384 | background=background_info, 385 | resolution_criteria=resolution_criteria, 386 | dates=dates, 387 | ) 388 | ) 389 | # Get all reasonings from the model 390 | reasoning_tasks = [ 391 | get_async_response( 392 | prompt, 393 | model_name=model_name, 394 | temperature=temperature, 395 | ) 396 | for prompt in reasoning_full_prompts 397 | ] 398 | # a list of strings 399 | all_reasonings = await asyncio.gather(*reasoning_tasks) 400 | logger.info( 401 | "Finished {} base reasonings generated by {}".format( 402 | len(reasoning_full_prompts), model_name 403 | ) 404 | ) 405 | if return_prompt: 406 | return all_reasonings, reasoning_full_prompts 407 | return all_reasonings 408 | -------------------------------------------------------------------------------- /llm_forecasting/ensemble.py: -------------------------------------------------------------------------------- 1 | # Standard library imports 2 | import logging 3 | 4 | # Related third-party imports 5 | import numpy as np 6 | 7 | # Local application/library-specific imports 8 | from config.constants import TOKENS_TO_PROBS_DICT 9 | import model_eval 10 | from prompts.prompts import PROMPT_DICT 11 | from utils import string_utils, utils 12 | 13 | # Set up logging 14 | logging.basicConfig(level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def concatenate_reasonings(reasonings): 19 | """ 20 | Concatenate a list of reasonings into a single string. 21 | 22 | Each reasoning is separated by a newline, a separator (---) and a number 23 | (Response 1, 2, 3, ...). 24 | 25 | Args: 26 | reasonings (list[str]): A list of reasonings. 27 | 28 | Returns: 29 | str: A single string containing all reasonings. 30 | """ 31 | concat_reasonings = [] 32 | for i, reasoning in enumerate(reasonings): 33 | reason_str = f"Response from forecaster {i + 1}:\n{reasoning}" 34 | concat_reasonings.append(reason_str) 35 | return "---\n" + "\n\n-\n".join(concat_reasonings) + "\n---" 36 | 37 | 38 | async def meta_reason( 39 | question, 40 | background_info, 41 | resolution_criteria, 42 | today_to_close_date_range, 43 | retrieved_info, 44 | reasoning_prompt_templates, 45 | base_model_names=["gpt-4-1106-preview", "claude-2.1"], 46 | base_temperature=1.0, # temperature for the base reasonings 47 | aggregation_method="meta", 48 | answer_type="probability", 49 | weights=None, 50 | end_words=list(TOKENS_TO_PROBS_DICT["ten_options"].keys()), 51 | meta_model_name="gpt-4-1106-preview", 52 | meta_prompt_template=PROMPT_DICT["meta_reasoning"]["0"], 53 | meta_temperature=0.2, 54 | ): 55 | """ 56 | Given a question and its retrieved articles, elicit model reasonings via 57 | reasoning_prompts, aggregate the reasonings and return the answer. 58 | 59 | Args: 60 | question (str): Forecast question to be answered. 61 | background_info (str): Background information of the question. 62 | resolution_criteria (str): Resolution criteria for the question. 63 | retrieved_info (str): Retrieved articles from our news retrieval system 64 | (a concatenation of the article titles and summaries). 
65 |         today_to_close_date_range (str): A string containing today's date
66 |             and the close date.
67 |         retrieved_info (str): Retrieved articles from our news retrieval system.
68 |         reasoning_prompt_templates (list[list[str]]): A list of reasoning prompts; string templates
69 |             that must have two fields {question} and {retrieved_info}.
70 |             There should be a list of reasoning prompts for each base model.
71 |         base_model_names (list[str], optional): A list of base model names.
72 |         base_temperature (float, optional): Sampling temperature for the base reasonings.
73 |         aggregation_method (str, optional): The method to aggregate the reasonings.
74 |             Must be one of 'vote-or-median', 'mean', 'weighted-mean', or 'meta'.
75 |         answer_type (str, optional): The type of the answer to return. Must be either 'probability' or 'tokens'.
76 |         weights (np.array, optional): A numpy array of weights for the reasonings.
77 |             It will only be used if aggregation_method is 'weighted-mean'.
78 |             It must have the same length as reasoning_prompt_templates: shape[0] ==
79 |             len(reasoning_prompt_templates).
80 |         end_words (list, optional): A list of words like "Very Unlikely" and "Very Likely" that represent the answer.
81 |             It will only be used if answer_type is 'tokens' and aggregation_method
82 |             is 'vote-or-median'.
83 |         meta_model_name (str, optional): The name of the meta model.
84 |         meta_prompt_template (tuple, optional): A meta-reasoning prompt template, given as a (string template, fields) tuple.
85 |         meta_temperature (float, optional): Sampling temperature for the meta-reasoning.
86 | 
87 |     Returns:
88 |         dict: A dictionary containing the final prediction, all base reasonings and their full prompts,
89 |             and the meta-reasoning (if aggregation_method is 'meta').
90 | 
91 |         For the final prediction:
92 |             If answer_type is 'probability' and aggregation_method is 'vote-or-median',
93 |             the median of all base predictions is used.
94 |             If answer_type is 'tokens' and aggregation_method is 'vote-or-median',
95 |             the most frequent answer is used.
96 |             If aggregation_method is 'meta', the final prediction is obtained
97 |             by eliciting another meta-reasoning using the meta_prompt_template.
98 | """ 99 | assert answer_type in [ 100 | "probability", 101 | "tokens", 102 | ], "answer_type must be either 'probability' or 'tokens'" 103 | assert aggregation_method in [ 104 | "vote-or-median", 105 | "meta", 106 | "mean", 107 | "weighted-mean", 108 | ], "aggregation_method must be either 'vote-or-median', 'meta', 'mean', or 'weighted-mean'" 109 | if aggregation_method == "weighted-mean": 110 | assert ( 111 | weights is not None 112 | ), "weights must be provided if aggregation_method is 'weighted-mean'" 113 | assert weights.shape[0] == len( 114 | reasoning_prompt_templates 115 | ), "weights must have the same length as reasoning_prompt_templates" 116 | all_base_reasonings = [] 117 | all_base_reasoning_full_prompts = [] 118 | for i, base_model_name in enumerate(base_model_names): 119 | ( 120 | base_reasonings, 121 | base_reasoning_full_prompts, 122 | ) = await model_eval.async_make_forecast( 123 | question=question, 124 | background_info=background_info, 125 | resolution_criteria=resolution_criteria, 126 | dates=today_to_close_date_range, 127 | retrieved_info=retrieved_info, 128 | reasoning_prompt_templates=reasoning_prompt_templates[i], 129 | model_name=base_model_name, 130 | temperature=base_temperature, 131 | return_prompt=True, 132 | ) 133 | # list of lists (not flattened) 134 | all_base_reasonings.append(base_reasonings) 135 | all_base_reasoning_full_prompts.append(base_reasoning_full_prompts) 136 | aggregation_dict = aggregate_base_reasonings( 137 | base_reasonings=all_base_reasonings, 138 | question=question, 139 | background_info=background_info, 140 | today_to_close_date_range=today_to_close_date_range, 141 | resolution_criteria=resolution_criteria, 142 | retrieved_info=retrieved_info, 143 | aggregation_method=aggregation_method, 144 | answer_type=answer_type, 145 | weights=weights, 146 | end_words=end_words, 147 | model_name=meta_model_name, # meta model name 148 | meta_prompt_template=meta_prompt_template, 149 | meta_temperature=meta_temperature, 150 | ) 151 | aggregation_dict["base_reasoning_full_prompts"] = all_base_reasoning_full_prompts 152 | return aggregation_dict 153 | 154 | 155 | def aggregate_base_reasonings( 156 | base_reasonings, 157 | question, 158 | background_info, 159 | today_to_close_date_range, 160 | resolution_criteria, 161 | retrieved_info, 162 | aggregation_method="meta", 163 | answer_type="probability", 164 | weights=None, 165 | end_words=list(TOKENS_TO_PROBS_DICT["ten_options"].keys()), 166 | model_name="gpt-4-1106-preview", # meta model name 167 | meta_prompt_template=PROMPT_DICT["meta_reasoning"]["0"], 168 | meta_temperature=0.2, 169 | ): 170 | """ 171 | Aggregate a list of lists of base reasonings via ensembling method. 172 | 173 | Args: 174 | base_reasonings (list[list[str]]): A list of lists of base reasonings. 175 | question (str): Forecast question to be answered. 176 | background_info (str): Background information of the question. 177 | today_to_close_date_range (str): A string containing the today's date and the close date. 178 | resolution_criteria (str): Resolution criteria for the question. 179 | retrieved_info (str): Retrieved articles from our news retrieval system 180 | (a concatenation of the article titles and summaries). 181 | aggregation_method (str, optional): The method to aggregate the reasonings. 182 | Must be either 'vote-or-median','mean', 'weighted-mean', 'meta'. 183 | answer_type (str, optional): The type of the answer to return. Must be either 'probability' or 'tokens'. 
184 | weights (np.array, optional): A numpy array of weights for the reasonings. 185 | It will only be used if aggregation_method is 'weighted-mean'. 186 | It must have the same length as reasoning_prompt_templates: shape[0] == 187 | len(reasoning_prompt_templates). 188 | end_words (list, optional): A list of words like "Very Unlikely" and "Very Likely" that represent the answer. 189 | It will only be used if answer_type is 'tokens' and aggregation_method 190 | is 'vote-or-median'. 191 | model_name (str, optional): The name of the meta model. 192 | meta_prompt_template (str, optional): A string that represents the meta reasoning prompt. 193 | meta_temperature (float, optional): Sampling temperature for the meta-reasoning. 194 | 195 | Returns: 196 | dict: A dictionary containing the final answer, all base reasonings, and the meta-reasoning 197 | (if aggregation_method is 'meta'). 198 | """ 199 | assert len(base_reasonings) > 0, "base_reasonings must be a non-empty list" 200 | # Extract final prediction from each reasoning 201 | all_base_predictions = [] # list of lists of floats 202 | for base_reasonings_list in base_reasonings: 203 | base_predictions = [ # for one model; list of floats 204 | string_utils.extract_prediction( 205 | reasoning, answer_type=answer_type, end_words=end_words 206 | ) 207 | for reasoning in base_reasonings_list 208 | ] 209 | all_base_predictions.append(base_predictions) 210 | flattened_all_base_predictions = [ 211 | item for sublist in all_base_predictions for item in sublist 212 | ] 213 | if len(flattened_all_base_predictions) == 1: # no aggregation needed 214 | return { 215 | "base_reasonings": base_reasonings, 216 | "base_predictions": all_base_predictions, 217 | "meta_prediction": flattened_all_base_predictions[0], 218 | "meta_prompt": None, 219 | "meta_reasoning": None, 220 | } 221 | if answer_type == "probability" and aggregation_method != "meta": 222 | if aggregation_method == "mean" or aggregation_method is None: 223 | meta_prediction = np.mean(flattened_all_base_predictions) # default to mean 224 | if aggregation_method == "vote-or-median": 225 | meta_prediction = np.median(flattened_all_base_predictions) 226 | if aggregation_method == "weighted-mean": 227 | meta_prediction = np.average( 228 | flattened_all_base_predictions, weights=weights 229 | ) 230 | if meta_prediction is None or meta_prediction < 0.0 or meta_prediction > 1.0: 231 | logger.debug( 232 | "final_answer {} is not between 0 and 1".format(meta_prediction) 233 | ) 234 | meta_prediction = 0.5 # default to 0.5 235 | return { 236 | "base_reasonings": base_reasonings, 237 | "base_predictions": all_base_predictions, 238 | "meta_prediction": meta_prediction, 239 | "meta_prompt": None, 240 | "meta_reasoning": None, 241 | } 242 | elif answer_type == "tokens" and aggregation_method == "vote-or-median": 243 | meta_prediction = utils.most_frequent_item( 244 | flattened_all_base_predictions 245 | ) # majority vote 246 | if meta_prediction is None or not string_utils.is_string_in_list( 247 | meta_prediction, end_words 248 | ): 249 | logger.debug("final_answer {} is not valid".format(meta_prediction)) 250 | meta_prediction = "Slightly Unlikely" # default to "Slightly Unlikely" 251 | return { 252 | "base_reasonings": base_reasonings, 253 | "base_predictions": all_base_predictions, 254 | "meta_prediction": meta_prediction, 255 | "meta_prompt": None, 256 | "meta_reasoning": None, 257 | } 258 | 259 | # If aggregation_method is 'meta', elicit a meta-reasoning using the 260 | # meta_prompt_template 261 | prompt, 
fields = meta_prompt_template 262 | flattened_base_reasonings = [ 263 | item for sublist in base_reasonings for item in sublist 264 | ] 265 | meta_full_prompt = string_utils.get_prompt( 266 | prompt, 267 | fields, 268 | question=question, 269 | background=background_info, 270 | dates=today_to_close_date_range, 271 | retrieved_info=retrieved_info, 272 | reasoning=concatenate_reasonings(flattened_base_reasonings), 273 | resolution_criteria=resolution_criteria, 274 | ) 275 | meta_reasoning = model_eval.get_response_from_model( 276 | model_name=model_name, 277 | prompt=meta_full_prompt, 278 | temperature=meta_temperature, 279 | ) # raw response 280 | # Extract final prediction from raw response 281 | if answer_type == "probability": 282 | # Get the probability from the meta-reasoning 283 | meta_prediction = string_utils.extract_probability_with_stars(meta_reasoning) 284 | if meta_prediction is None or meta_prediction < 0.0 or meta_prediction > 1.0: 285 | logger.debug( 286 | "final_answer {} is not between 0 and 1".format(meta_prediction) 287 | ) 288 | meta_prediction = 0.5 289 | elif answer_type == "tokens": 290 | # Get the final token answer from the meta-reasoning 291 | meta_prediction = string_utils.find_end_word(meta_reasoning, end_words) 292 | if meta_prediction is None or not string_utils.is_string_in_list( 293 | meta_prediction, end_words 294 | ): 295 | logger.debug("final_answer {} is not valid".format(meta_prediction)) 296 | meta_prediction = "Slightly Unlikely" 297 | 298 | return { 299 | "base_reasonings": base_reasonings, 300 | "base_predictions": all_base_predictions, 301 | "meta_prompt": meta_full_prompt, 302 | "meta_reasoning": meta_reasoning, 303 | "meta_prediction": meta_prediction, 304 | } 305 | 306 | 307 | def calculate_normalized_weighted_trimmed_mean(predictions): 308 | """ 309 | Calculate the normalized weighted trimmed mean of a set of predictions. 310 | 311 | This function performs the following steps: 312 | 1. Compute the median of the predictions to serve as a reference point. 313 | 2. Identify the prediction that is furthest from the median value. 314 | 3. Reduce the weight of the furthest prediction by half, acknowledging its potential outlier status. 315 | 4. Equally distribute the reduced weight from the furthest prediction among the remaining predictions, ensuring the total weight remains constant. 316 | 5. Compute and return the weighted mean of the predictions, using the adjusted weights to account for outliers and variance. 317 | 318 | Parameters: 319 | - predictions (np.ndarray): An array of numerical prediction values. 320 | 321 | Returns: 322 | - float: The normalized weighted trimmed mean of the predictions. 323 | """ 324 | # Step 1: Find the median 325 | median_prediction = np.median(predictions) 326 | 327 | # Step 2: Determine the prediction farthest from the median 328 | distances = np.abs(predictions - median_prediction) 329 | max_distance = np.max(distances) 330 | 331 | # Step 3: Down-weight the furthest prediction by half 332 | weights = np.ones(len(predictions)) 333 | weights[distances == max_distance] *= 0.5 334 | 335 | # Step 4: Distribute the saved weight among other predictions 336 | saved_weight = (1.0 - 0.5) / (len(predictions) - 1) 337 | weights[distances != max_distance] += saved_weight 338 | 339 | # Step 5: Calculate the weighted mean 340 | weighted_mean = np.average(predictions, weights=weights) 341 | 342 | return weighted_mean 343 | --------------------------------------------------------------------------------
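
To make the weighting scheme above concrete, here is a small worked example with illustrative numbers. For predictions [0.2, 0.3, 0.4, 0.9] the median is 0.35, so 0.9 is the furthest value: its weight is halved to 0.5 and the removed 0.5 is spread as 0.5/3 across the other three, giving weights of roughly [1.167, 1.167, 1.167, 0.5] (still summing to 4). The resulting weighted mean is about 0.375, versus a plain mean of 0.45.

    # Illustrative check of the weighting scheme; the import assumes
    # llm_forecasting/ is on PYTHONPATH, matching the package's flat imports.
    import numpy as np
    from ensemble import calculate_normalized_weighted_trimmed_mean

    preds = np.array([0.2, 0.3, 0.4, 0.9])
    # Median is 0.35, so 0.9 is furthest; weights become ~[1.167, 1.167, 1.167, 0.5].
    print(calculate_normalized_weighted_trimmed_mean(preds))  # ~0.375 (plain mean: 0.45)

Note that when several predictions tie for the largest distance, each of them is halved while only a single half-weight is redistributed, so the weight total is preserved exactly only when the furthest prediction is unique; np.average still normalizes by the weight sum, so the result remains a valid weighted mean.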