├── assets
│   └── forecasting_system.png
├── notebooks
│   ├── results
│   │   ├── data
│   │   │   ├── answers.pickle
│   │   │   ├── questions_data.pickle
│   │   │   ├── base_predictions.pickle
│   │   │   ├── crowd_predictions.pickle
│   │   │   ├── finetuned_predictions.pickle
│   │   │   └── finetuned_other_predictions.pickle
│   │   └── results.ipynb
│   └── demo
│       └── sample_questions.pickle
├── llm_forecasting
│   ├── prompts
│   │   ├── system.py
│   │   ├── alignment.py
│   │   ├── relevance.py
│   │   ├── ensemble_reasoning.py
│   │   ├── base_eval.py
│   │   ├── prompts.py
│   │   ├── data_wrangling.py
│   │   ├── summarization.py
│   │   └── search_query.py
│   ├── utils
│   │   ├── logging_utils.py
│   │   ├── model_utils.py
│   │   ├── utils.py
│   │   ├── article_utils.py
│   │   ├── validation_utils.py
│   │   ├── api_utils.py
│   │   ├── metrics_utils.py
│   │   ├── db_utils.py
│   │   ├── time_utils.py
│   │   ├── string_utils.py
│   │   └── data_utils.py
│   ├── config
│   │   ├── keys.py
│   │   └── constants.py
│   ├── alignment.py
│   ├── data_scraping.py
│   ├── evaluation.py
│   ├── summarize.py
│   ├── model_eval.py
│   └── ensemble.py
├── pyproject.toml
├── scripts
│   ├── fine_tune
│   │   └── fine_tune.py
│   ├── data_scraping
│   │   ├── cset.py
│   │   ├── gjopen.py
│   │   ├── polymarket.py
│   │   └── manifold.py
│   └── training_data
│       └── training_point_generation.py
└── README.md
/assets/forecasting_system.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/assets/forecasting_system.png
--------------------------------------------------------------------------------
/notebooks/results/data/answers.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/answers.pickle
--------------------------------------------------------------------------------
/notebooks/demo/sample_questions.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/demo/sample_questions.pickle
--------------------------------------------------------------------------------
/notebooks/results/data/questions_data.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/questions_data.pickle
--------------------------------------------------------------------------------
/notebooks/results/data/base_predictions.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/base_predictions.pickle
--------------------------------------------------------------------------------
/notebooks/results/data/crowd_predictions.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/crowd_predictions.pickle
--------------------------------------------------------------------------------
/notebooks/results/data/finetuned_predictions.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/finetuned_predictions.pickle
--------------------------------------------------------------------------------
/notebooks/results/data/finetuned_other_predictions.pickle:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dannyallover/llm_forecasting/HEAD/notebooks/results/data/finetuned_other_predictions.pickle
--------------------------------------------------------------------------------
/llm_forecasting/prompts/system.py:
--------------------------------------------------------------------------------
1 | SYSTEM_SUPERFORECASTER_0 = """You are an expert superforecaster, familiar with the work of Tetlock and others.
2 | Your mission is to generate accurate predictions for forecasting questions.
3 | Aggregate the information provided by the user. Make sure to give detailed reasonings."""
4 |
--------------------------------------------------------------------------------
/llm_forecasting/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def setup_file_logger(logger, file_name):
5 | """
6 | Set up a custom logger that writes to a file.
7 |
8 | Args:
9 |         logger (logging.Logger): Logger instance to configure.
10 |         file_name (str): Name of the file to write to.
11 | """
12 | logger.setLevel(logging.INFO)
13 | file_handler = logging.FileHandler(file_name)
14 | formatter = logging.Formatter(
15 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
16 | )
17 | file_handler.setFormatter(formatter)
18 | logger.addHandler(file_handler)
19 |
--------------------------------------------------------------------------------
/llm_forecasting/config/keys.py:
--------------------------------------------------------------------------------
1 | AWS_ACCESS_KEY = ""
2 | AWS_SECRET_KEY = ""
3 |
4 | METACULUS_KEY = ""
5 | MANIFOLD_KEY = ""
6 | POLYMARKET_KEY = ""
7 | CRYPTO_PRIVATE_KEY = ""
8 | EMAIL = ""
9 | GJOPEN_CSET_PASSWORD = ""
10 |
11 |
12 | NEWSCASTCHER_KEY = ""
13 |
14 | OPENAI_KEY = ""
15 | ANTHROPIC_KEY = ""
16 | TOGETHER_KEY = ""
17 | GOOGLE_AI_KEY = ""
18 | HF_ACCESS_TOKEN = ""
19 |
20 | keys = {
21 | "AWS_ACCESS_KEY": AWS_ACCESS_KEY,
22 | "AWS_SECRET_KEY": AWS_SECRET_KEY,
23 |
24 | "METACULUS_KEY": METACULUS_KEY,
25 | "MANIFOLD_KEY": MANIFOLD_KEY,
26 | "POLYMARKET_KEY": POLYMARKET_KEY,
27 | "CRYPTO_PRIVATE_KEY": CRYPTO_PRIVATE_KEY,
28 | "EMAIL": EMAIL,
29 | "GJOPEN_CSET_PASSWORD": GJOPEN_CSET_PASSWORD,
30 |
31 | "NEWSCASTCHER_KEY": NEWSCASTCHER_KEY,
32 |
33 | "OPENAI_KEY": OPENAI_KEY,
34 | "ANTHROPIC_KEY": ANTHROPIC_KEY,
35 | "TOGETHER_KEY": TOGETHER_KEY,
36 | "GOOGLE_AI_KEY": GOOGLE_AI_KEY,
37 | "HF_ACCESS_TOKEN": HF_ACCESS_TOKEN,
38 | }
39 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "llm_forecasting"
7 | version = "0.0.1"
8 | dependencies = [
9 | "requests==2.26.0",
10 | "pandas==1.5.3",
11 | "numpy==1.24.3",
12 | "scipy",
13 | "matplotlib",
14 | "openai==1.6.1",
15 | "tqdm",
16 | "together",
17 | "torch",
18 | "bardapi",
19 | "tiktoken",
20 | "langchain",
21 | "transformers",
22 | "boto3",
23 | "anthropic",
24 | "gnews==0.3.6",
25 | "newspaper4k",
26 | "newscatcherapi",
27 | "sentence_transformers",
28 | "markdown2",
29 | "google-generativeai",
30 | "jsonlines",
31 | "selenium",
32 | "aws-wsgi==0.2.7",
33 | "python-dotenv",
34 | "google-cloud-secret-manager",
35 | "black",
36 | "black[jupyter]",
37 | "autopep8",
38 | "flake8",
39 | "pytest",
40 | "pytest-asyncio",
41 | "rich",
42 | ]
43 |
44 | [tool.setuptools.packages.find]
45 | where = ["llm_forecasting"]
46 |
--------------------------------------------------------------------------------
/llm_forecasting/prompts/alignment.py:
--------------------------------------------------------------------------------
1 | ALIGNMENT_PROMPT = (
2 | """Question:
3 | {question}
4 |
5 | Background:
6 | {background}
7 |
8 | Resolution Criteria:
9 | {resolution_criteria}
10 |
11 | Model’s Thinking:
12 | {reasoning}
13 |
14 | Task:
15 | Evaluate the alignment between the model's thinking and its prediction. If someone were given the reasoning alone (without the prediction), would they likely arrive at the same prediction?
16 |
17 | Alignment Ratings:
18 | 1 — Very Not Aligned
19 | 2 — Not Aligned
20 | 3 — Slightly Not Aligned
21 | 4 — Slightly Aligned
22 | 5 — Aligned
23 | 6 — Very Aligned
24 |
25 | Please use these ratings to indicate the degree of alignment between the model's reasoning and its prediction.
26 |
27 | Note: If the response indicates that this question is old or it's already been resolved, give it an alignment rating of 1.
28 |
29 | I want your answer to follow this format:
30 |
31 | Thinking: {{ insert your thinking here }}
32 | Rating: {{ insert your alignment rating here (a number between 1 and 6) }}""",
33 | ("QUESTION", "BACKGROUND", "RESOLUTION_CRITERIA", "REASONING"),
34 | )
35 |
--------------------------------------------------------------------------------
/llm_forecasting/prompts/relevance.py:
--------------------------------------------------------------------------------
1 | RELEVANCE_PROMPT_0 = (
2 | """Please consider the following forecasting question and its background information.
3 | After that, I will give you a news article and ask you to rate its relevance with respect to the forecasting question.
4 |
5 | Question:
6 | {question}
7 |
8 | Question Background:
9 | {background}
10 |
11 | Question Resolution Criteria:
12 | {resolution_criteria}
13 |
14 | Article:
15 | {article}
16 |
17 | Please rate the relevance of the article to the question on a scale of 1-6:
18 | 1 -- irrelevant
19 | 2 -- slightly relevant
20 | 3 -- somewhat relevant
21 | 4 -- relevant
22 | 5 -- highly relevant
23 | 6 -- most relevant
24 |
25 | Guidelines:
26 | - You don't need to access any external sources. Just consider the information provided.
27 | - Focus on the content of the article, not the title.
28 | - If the text content is an error message about JavaScript, paywall, cookies or other technical issues, output a score of 1.
29 |
30 | Your response should look like the following:
31 | Thoughts: {{ insert your thinking }}
32 | Rating: {{ insert your rating }}""",
33 | ("QUESTION", "BACKGROUND", "RESOLUTION_CRITERIA", "ARTICLE"),
34 | )
35 |
--------------------------------------------------------------------------------
/llm_forecasting/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | # Related third-party imports
2 | import tiktoken
3 |
4 | # Local application/library specific imports
5 | from config.constants import OAI_SOURCE, MODEL_NAME_TO_SOURCE
6 |
7 |
8 | def count_tokens(text, model_name):
9 | """
10 | Count the number of tokens for a given text.
11 |
12 | Args:
13 | - text (str): The input text whose tokens need to be counted.
14 | - model_name (str): Name of the OpenAI model to be used for token counting.
15 |
16 | Returns:
17 | - int: Number of tokens in the text for the specified model.
18 | """
19 | model_source = infer_model_source(model_name)
20 | if model_source == OAI_SOURCE:
21 | enc = tiktoken.encoding_for_model(model_name)
22 | token_length = len(enc.encode(text))
23 | else:
24 |         token_length = len(text) // 3  # rough approximation for non-OpenAI models
25 |
26 | return token_length
27 |
28 |
29 | def infer_model_source(model_name):
30 | """
31 | Infer the model source from the model name.
32 |
33 | Args:
34 | - model_name (str): The name of the model.
35 | """
36 | if "ft:gpt" in model_name: # fine-tuned GPT-3 or 4
37 | return OAI_SOURCE
38 | if model_name not in MODEL_NAME_TO_SOURCE:
39 | raise ValueError(f"Invalid model name: {model_name}")
40 | return MODEL_NAME_TO_SOURCE[model_name]
41 |
--------------------------------------------------------------------------------
/llm_forecasting/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | from collections import Counter
3 |
4 |
5 | def flatten_list(nested_list):
6 | flat_list = [item for sublist in nested_list for item in sublist]
7 | return flat_list
8 |
9 |
10 | def most_frequent_item(lst):
11 | """
12 | Return the most frequent item in the given list.
13 |
14 | If there are multiple items with the same highest frequency, one of them is
15 | returned.
16 |
17 | Args:
18 | lst (list): The list from which to find the most frequent item.
19 |
20 | Returns:
21 | The most frequent item in the list.
22 | """
23 | if not lst:
24 | return None # Return None if the list is empty
25 | # Count the frequency of each item in the list
26 | count = Counter(lst)
27 | # Find the item with the highest frequency
28 | most_common = count.most_common(1)
29 | return most_common[0][0] # Return the item (not its count)
30 |
31 |
32 | def indices_of_N_largest_numbers(list_of_numbers, N=3):
33 | """
34 | Return the indices of the N largest numbers in the given list of numbers.
35 |
36 | Args:
37 | list_of_numbers (list): The list of numbers from which to find the N
38 | largest numbers.
39 | N (int, optional): The number of largest numbers to find. Defaults to 3.
40 |
41 | Returns:
42 | list: The indices of the N largest numbers in the given list of numbers.
43 | """
44 | # Get the indices of the N largest numbers
45 | indices = sorted(range(len(list_of_numbers)), key=lambda i: list_of_numbers[i])[-N:]
46 | return indices
47 |
--------------------------------------------------------------------------------
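A couple of quick, hand-checkable calls to the helpers above (the inputs are made up):

```
from utils.utils import most_frequent_item, indices_of_N_largest_numbers

print(most_frequent_item(["yes", "no", "yes"]))                 # "yes"
print(indices_of_N_largest_numbers([0.2, 0.9, 0.5, 0.7], N=2))  # [3, 1]
```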
/llm_forecasting/utils/article_utils.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | from datetime import datetime
3 |
4 | # Local application/library-specific imports
5 | from utils import db_utils
6 | from config.keys import keys
7 | from config.constants import S3_BUCKET_NAME, S3
8 |
9 | # Set up constants
10 | AWS_ACCESS_KEY = keys["AWS_ACCESS_KEY"]
11 | AWS_SECRET_KEY = keys["AWS_SECRET_KEY"]
12 |
13 |
14 | def article_object_to_dict(article):
15 | """
16 | Convert an article object to a dictionary
17 |
18 | Args:
19 | article (Article): An article object (such as NewscatcherArticle)
20 |
21 | Returns:
22 | article_dict (dict): A dictionary containing the article's attributes
23 | such as title, text, authors, etc.
24 | """
25 | article_dict = {}
26 | for attribute in article.__dict__:
27 | field = getattr(article, attribute)
28 | if (
29 | isinstance(field, str) # title, text, etc
30 | or isinstance(field, int)
31 | or isinstance(field, float) # relevance ratings
32 | or isinstance(field, list) # authors list
33 | ):
34 | article_dict[attribute] = field
35 | if isinstance(field, datetime): # datetime, etc
36 | article_dict[attribute] = field.strftime("%Y-%m-%d")
37 | return article_dict
38 |
39 |
40 | def article_object_list_to_dict(article_list):
41 | """
42 | Convert a list of article objects to a list of dictionaries
43 | """
44 | return [article_object_to_dict(article) for article in article_list]
45 |
46 |
47 | def upload_articles_to_s3(article_list, s3_path="system/info-hp"):
48 | """
49 | Upload a list of articles to S3
50 |
51 | Args:
52 | article_list (list): A list of article objects (such as NewscatcherArticle)
53 | s3_path (str): The path to save the articles to in S3
54 |
55 | Returns:
56 | None
57 | """
58 | articles_dict = article_object_list_to_dict(article_list)
59 | s3_filename = f"{s3_path}/articles.pickle"
60 | db_utils.upload_data_structure_to_s3(
61 | s3=S3, data_structure=articles_dict, bucket=S3_BUCKET_NAME, s3_path=s3_filename
62 | )
63 |
--------------------------------------------------------------------------------
/llm_forecasting/utils/validation_utils.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import logging
3 | import sys
4 |
5 | # Related third-party imports
6 | import openai
7 |
8 | # Set up logging
9 | logging.basicConfig(level=logging.INFO)
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def is_valid_openai_key(api_key):
14 | """
15 | Check if the given OpenAI API key is valid.
16 |
17 | Args:
18 | api_key (str): OpenAI API key to be validated.
19 |
20 | Returns:
21 | bool: Whether the given API key is valid.
22 | """
23 |     try:
24 |         # Create a client with the given API key (openai>=1.0 interface)
25 |         client = openai.OpenAI(api_key=api_key)
26 |         # Make a test request (e.g., listing available models)
27 |         client.models.list()
28 |         # If the above line didn't raise an exception, the key is valid
29 |         return True
30 |     except openai.AuthenticationError:
31 |         logger.error("Invalid API key.")
32 |         return False
33 |     except openai.OpenAIError as e:
34 |         logger.error(f"An error occurred in validating OpenAI API key: {str(e)}")
35 |         return False
36 |
37 |
38 | def is_valid_openai_model(model_name, api_key):
39 | """
40 | Check if the model name is valid, given a valid OpenAI API key.
41 |
42 | Args:
43 | - model_name (str): Name of the model to be validated, such as "gpt-4"
44 | - api_key (str): OpenAI API key, assumed to be valid.
45 |
46 | Returns:
47 | - bool: Whether the given model name is valid.
48 | """
49 |     try:
50 |         client = openai.OpenAI(api_key=api_key)
51 |         # Attempt to retrieve information about the model
52 |         _ = client.models.retrieve(model_name)
53 |         return True
54 |     except openai.OpenAIError as e:
55 |         logger.error(f"An error occurred in validating the model name: {str(e)}")
56 |         return False
57 |
58 |
59 | def validate_key_and_model(key, model):
60 | """
61 | Check if the given OpenAI API key and model name are valid.
62 |
63 | Args:
64 | - key (str): OpenAI API key to be validated.
65 | - model (str): Name of the model to be validated, such as "gpt-4"
66 |
67 | If either the key or model is invalid, exit the program.
68 | """
69 | if not is_valid_openai_key(key) or not is_valid_openai_model(model, key):
70 | sys.exit(1)
71 |
--------------------------------------------------------------------------------
/llm_forecasting/prompts/ensemble_reasoning.py:
--------------------------------------------------------------------------------
1 | ENSEMBLE_PROMPT_0 = (
2 | """I need your assistance with making a forecast. Here is the question and its metadata.
3 | Question: {question}
4 |
5 | Background: {background}
6 |
7 | Resolution criteria: {resolution_criteria}
8 |
9 | Today's date: {date_begin}
10 | Question close date: {date_end}
11 |
12 | I have retrieved the following information about this question.
13 | Retrieved Info:
14 | {retrieved_info}
15 |
16 | In addition, I have generated a collection of other responses and reasonings from other forecasters:
17 | {base_reasonings}
18 |
19 | Your goal is to aggregate the information and make a final prediction.
20 |
21 | Instructions:
22 | 1. Provide reasons why the answer might be no.
23 | {{ Insert your thoughts here }}
24 |
25 | 2. Provide reasons why the answer might be yes.
26 | {{ Insert your thoughts here }}
27 |
28 | 3. Aggregate your considerations.
29 | {{ Insert your aggregated considerations here }}
30 |
31 | 4. Output your prediction (a number between 0 and 1) with an asterisk at the beginning and end of the decimal.
32 | {{ Insert the probability here }}""",
33 | (
34 | "QUESTION",
35 | "BACKGROUND",
36 | "RESOLUTION_CRITERIA",
37 | "RETRIEVED_INFO",
38 | "DATES",
39 | "BASE_REASONINGS",
40 | ),
41 | )
42 |
43 |
44 | ENSEMBLE_PROMPT_1 = (
45 | """I need your assistance with making a forecast. Here is the question and its metadata.
46 | Question: {question}
47 |
48 | Background: {background}
49 |
50 | Resolution criteria: {resolution_criteria}
51 |
52 | Today's date: {date_begin}
53 | Question close date: {date_end}
54 |
55 | I have retrieved the following information about this question.
56 | Retrieved Info:
57 | {retrieved_info}
58 |
59 | In addition, I have generated a collection of other responses and reasonings from other forecasters:
60 | {base_reasonings}
61 |
62 | Your goal is to aggregate the information and make a final prediction.
63 |
64 | Instructions:
65 | 1. Think step by step: {{ Insert your step by step consideration }}
66 |
67 | 2. Aggregate your considerations: {{ Aggregate your considerations }}
68 |
69 | 3. Output your prediction (a number between 0 and 1) with an asterisk at the beginning and end of the decimal.
70 | {{ Insert the probability here }}""",
71 | (
72 | "QUESTION",
73 | "BACKGROUND",
74 | "RESOLUTION_CRITERIA",
75 | "RETRIEVED_INFO",
76 | "DATES",
77 | "BASE_REASONINGS",
78 | ),
79 | )
80 |
--------------------------------------------------------------------------------
/llm_forecasting/alignment.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import logging
3 |
4 | # Local application/library-specific imports
5 | import model_eval
6 | import ranking
7 | from utils import string_utils
8 | from prompts.prompts import PROMPT_DICT
9 |
10 | # Set up logging
11 | logging.basicConfig(level=logging.INFO)
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | def get_alignment_scores(
16 | reasonings,
17 | alignment_prompt=PROMPT_DICT["alignment"]["0"],
18 | model_name="gpt-3.5-turbo-1106",
19 | temperature=0,
20 | question=None,
21 | background=None,
22 | resolution_criteria=None,
23 | ):
24 | """
25 | Compute the alignment score of each reasoning for each model.
26 |
27 | The alignment score assesses if the reasoning is consistent with the
28 | model's prediction, i.e. if one were given the reasoning alone, would she
29 | also predict a similar probability.
30 |
31 | Args:
32 | reasonings (list[list[str]]): A list containing a list of reasonings.
33 | alignment_prompt(dict, optional): Alignment prompt to use.
34 | model_name (str, optional): Model used to compute score.
35 | question (str, optional): Forecasting question.
36 | background (str, optional): Background of question.
37 | resolution_criteria(str, optional): Resolution criteria of question.
38 |
39 | Returns:
40 | list[list[int]]: A list containing a list of scores.
41 | """
42 | alignment_scores = []
43 | for model_reasonings in reasonings:
44 | alignment_scores_ = []
45 | for reasoning in model_reasonings:
46 | prompt = string_utils.get_prompt(
47 | alignment_prompt[0],
48 | alignment_prompt[1],
49 | question=question,
50 | background=background,
51 | resolution_criteria=resolution_criteria,
52 | reasoning=reasoning,
53 | )
54 | try:
55 | alignment_response = model_eval.get_response_from_model(
56 | model_name=model_name,
57 | prompt=prompt,
58 | max_tokens=2000,
59 | temperature=temperature,
60 | )
61 | alignment_score = ranking.extract_rating_from_response(
62 | alignment_response
63 | )
64 | alignment_scores_.append(alignment_score)
65 | except Exception as e:
66 | logger.error(f"Error message: {e}")
67 | logger.info("Failed to calculate alignment score")
68 | continue
69 | alignment_scores.append(alignment_scores_)
70 | return alignment_scores
71 |
--------------------------------------------------------------------------------
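For orientation, a minimal usage sketch of `get_alignment_scores` defined above (the question, background, and reasonings are made up; the prompt and model come from the function's keyword defaults):

```
from alignment import get_alignment_scores

# One inner list of reasonings per base model (toy strings for illustration).
reasonings = [
    [
        "The incumbent's approval rating and historical base rates suggest ...",
        "Recent polling averages and news coverage indicate ...",
    ],
]

scores = get_alignment_scores(
    reasonings=reasonings,
    question="Will candidate X win the 2024 election?",
    background="Candidate X is the incumbent and currently leads the polls.",
    resolution_criteria="Resolves YES if candidate X is declared the winner.",
)
# `scores` mirrors the shape of `reasonings`: one list of 1-6 ratings per model.
print(scores)
```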
/llm_forecasting/prompts/base_eval.py:
--------------------------------------------------------------------------------
1 | KNOWLEDGE_PROMPT = (
2 | """Question: {question}
3 |
4 | The question was posed on {date_begin} and closed on {date_end}.
5 | Instructions:
6 | - Please output "1" if the answer is "Yes", "0" if the answer is "No", or "IDK" if you don't know the answer. Do not return anything else.
7 | - Do not guess.
8 | Answer: {{ Insert answer here }}""",
9 | ("QUESTION", "DATES"),
10 | )
11 |
12 | GEMINI_KNOWLEDGE_PROMPT = (
13 | """Here was a forecasting question posed in {date_begin} and resolved in {date_end}.
14 | Question: {question} (Only Output: 1 (Yes), 0 (No), or IDK if you don't know the answer.)
15 | Answer: {{ Insert answer here }} """,
16 | ("QUESTION", "DATES"),
17 | )
18 |
19 | ZERO_SHOT_PROMPT = (
20 | """You are an expert superforecaster, familiar with the work of Tetlock and others. Make a prediction of the probability that the question will be resolved as true. You MUST give a probability estimate between 0 and 1 UNDER ALL CIRCUMSTANCES. If for some reason you can't answer, pick the base rate, but return a number between 0 and 1.
21 |
22 | Question:
23 | {question}
24 |
25 | Question Background:
26 | {background}
27 |
28 | Resolution Criteria:
29 | {resolution_criteria}
30 |
31 | Today's date: {date_begin}
32 | Question close date: {date_end}
33 |
34 | Output your answer (a number between 0 and 1) with an asterisk at the beginning and end of the decimal. Do not output anything else.
35 | Answer: {{ Insert answer here }}""",
36 | ("QUESTION", "BACKGROUND", "RESOLUTION_CRITERIA", "DATES"),
37 | )
38 |
39 | BINARY_SCRATCH_PAD_PROMPT = (
40 | """Question:
41 | {question}
42 |
43 | Question Background:
44 | {background}
45 |
46 | Resolution Criteria:
47 | {resolution_criteria}
48 |
49 | Today's date: {date_begin}
50 | Question close date: {date_end}
51 |
52 |
53 | Instructions:
54 | 1. Provide reasons why the answer might be no.
55 | {{ Insert your thoughts }}
56 |
57 | 2. Provide reasons why the answer might be yes.
58 | {{ Insert your thoughts }}
59 |
60 | 3. Aggregate your considerations.
61 | {{ Insert your aggregated considerations }}
62 |
63 | 4. Output your answer (a number between 0 and 1) with an asterisk at the beginning and end of the decimal.
64 | {{ Insert your answer }}""",
65 | ("QUESTION", "BACKGROUND", "RESOLUTION_CRITERIA", "DATES"),
66 | )
67 |
68 | GEMINI_BINARY_SCRATCH_PAD_PROMPT = (
69 | """Question:
70 | {question}
71 |
72 | Question Background:
73 | {background}
74 |
75 | Resolution Criteria:
76 | {resolution_criteria}
77 |
78 | Today's date: {date_begin}
79 | Question close date: {date_end}
80 |
81 | Output why the answer might be no, why the answer might be yes, aggregate your considerations, and then your answer (a number between 0 and 1) with an asterisk at the beginning and end of the decimal.""",
82 | ("QUESTION", "BACKGROUND", "RESOLUTION_CRITERIA", "DATES"),
83 | )
84 |
--------------------------------------------------------------------------------
/scripts/fine_tune/fine_tune.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import logging
3 | import json
4 | import openai
5 | from openai import OpenAI
6 |
7 | # Set up logging
8 | logging.basicConfig(level=logging.INFO)
9 | logger = logging.getLogger(__name__)
10 |
11 | client = OpenAI(
12 | api_key="",
13 | organization="",
14 | )
15 |
16 |
17 | def create_jsonl_for_finetuning(training_data, file_path):
18 | """
19 | Writes training data to a JSONL file for fine-tuning purposes.
20 |
21 | Args:
22 | training_data (list): A list of tuples containing user and assistant messages.
23 | file_path (str): Path to save the JSONL file.
24 |
25 | Returns:
26 | None: Logs the completion of file writing.
27 | """
28 | with open(file_path, "w") as jsonl_file:
29 | for user, assistant in training_data:
30 | message = {
31 | "messages": [
32 | {"role": "user", "content": user},
33 | {"role": "assistant", "content": assistant},
34 | ]
35 | }
36 | example = json.dumps(message)
37 | jsonl_file.write(example + "\n")
38 | logger.info(f"|training_data| saved to {file_path} as jsonl")
39 | return None
40 |
41 |
42 | def upload_oai_file(file_path):
43 | """
44 | Uploads a file to OpenAI API for fine-tuning.
45 |
46 | Args:
47 | file_path (str): Path of the file to be uploaded.
48 |
49 | Returns:
50 | object: Returns a file object as a response from OpenAI API.
51 | """
52 | with open(file_path, "rb") as file_data:
53 | file_obj = client.files.create(file=file_data, purpose="fine-tune")
54 |
55 | logging.info(f"Uploaded dataset {file_path} to OpenAI API.")
56 | return file_obj
57 |
58 |
59 | def create_oai_finetuning_job(model_name, train_file_id, model_suffix):
60 | """
61 | Creates a fine-tuning job with OpenAI.
62 |
63 | Args:
64 | model_name (str): Name of the base model to fine-tune.
65 | train_file_id (str): ID of the training file uploaded to OpenAI.
66 | model_suffix (str): Suffix to be added to the fine-tuned model.
67 |
68 | Returns:
69 | object: Returns a fine-tuning job object.
70 | """
71 | model = client.fine_tuning.jobs.create(
72 | model=model_name,
73 | training_file=train_file_id,
74 | suffix=model_suffix,
75 | )
76 | return model
77 |
78 |
79 | def check_on_finetuning_job(ft_id):
80 | """
81 | Checks the status of a fine-tuning job.
82 |
83 | Args:
84 | ft_id (str): ID of the fine-tuning job.
85 |
86 | Returns:
87 | None: Fetches and logs the latest events of the fine-tuning job.
88 | """
89 | # Use this to check on the output of your job
90 |     logger.info(client.fine_tuning.jobs.list_events(fine_tuning_job_id=ft_id, limit=5))
91 | return None
92 |
--------------------------------------------------------------------------------
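A minimal end-to-end sketch using the helpers above (the training pair, file name, base model, and suffix are illustrative; the `client` defined at the top of the script needs a real API key):

```
# Hypothetical (user, assistant) training pairs.
training_data = [
    ("Question: Will X happen by 2025?\n...", "Reasoning: ...\nAnswer: *0.35*"),
]

create_jsonl_for_finetuning(training_data, "train.jsonl")
file_obj = upload_oai_file("train.jsonl")
job = create_oai_finetuning_job("gpt-3.5-turbo-1106", file_obj.id, "forecasting")
check_on_finetuning_job(job.id)
```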
/llm_forecasting/prompts/prompts.py:
--------------------------------------------------------------------------------
1 | from prompts.relevance import *
2 | from prompts.base_reasoning import *
3 | from prompts.search_query import *
4 | from prompts.summarization import *
5 | from prompts.ensemble_reasoning import *
6 | from prompts.alignment import *
7 | from prompts.system import *
8 | from prompts.data_wrangling import *
9 | from prompts.base_eval import *
10 |
11 | PROMPT_DICT = {
12 | "binary": {
13 | "scratch_pad": {
14 | "0": BINARY_SCRATCH_PAD_PROMPT_0,
15 | "1": BINARY_SCRATCH_PAD_PROMPT_1,
16 | "2": BINARY_SCRATCH_PAD_PROMPT_2,
17 | "2_tok": BINARY_SCRATCH_PAD_PROMPT_2_TOKENS,
18 | "3": BINARY_SCRATCH_PAD_PROMPT_3,
19 | "new_0": BINARY_SCRATCH_PAD_PROMPT_NEW_0,
20 | "new_1": BINARY_SCRATCH_PAD_PROMPT_NEW_1,
21 | "new_2": BINARY_SCRATCH_PAD_PROMPT_NEW_2,
22 | "new_3": BINARY_SCRATCH_PAD_PROMPT_NEW_3,
23 | "new_4": BINARY_SCRATCH_PAD_PROMPT_NEW_4,
24 | "new_5": BINARY_SCRATCH_PAD_PROMPT_NEW_5,
25 | "new_6": BINARY_SCRATCH_PAD_PROMPT_NEW_6,
26 | "new_7": BINARY_SCRATCH_PAD_PROMPT_NEW_7,
27 | },
28 | "rar": {
29 | "0": RAR_PROMPT_0,
30 | "1": BINARY_SCRATCH_PAD_PROMPT_1_RAR,
31 | "2": BINARY_SCRATCH_PAD_PROMPT_2_RAR,
32 | },
33 | "emotion": {
34 | "0": EMOTION_PROMPT_0,
35 | },
36 | },
37 | "ranking": {
38 | "0": RELEVANCE_PROMPT_0,
39 | },
40 | "alignment": {
41 | "0": ALIGNMENT_PROMPT,
42 | },
43 | "search_query": {
44 | "0": SEARCH_QUERY_PROMPT_0,
45 | "1": SEARCH_QUERY_PROMPT_1,
46 | "2": SEARCH_QUERY_PROMPT_2,
47 | "3": SEARCH_QUERY_PROMPT_3,
48 | "4": SEARCH_QUERY_PROMPT_4,
49 | "5": SEARCH_QUERY_PROMPT_5,
50 | "6": SEARCH_QUERY_PROMPT_6,
51 | "7": SEARCH_QUERY_PROMPT_7,
52 | "8": SEARCH_QUERY_PROMPT_8,
53 | },
54 | "summarization": {
55 | "0": SUMMARIZATION_PROMPT_0,
56 | "1": SUMMARIZATION_PROMPT_1,
57 | "2": SUMMARIZATION_PROMPT_2,
58 | "3": SUMMARIZATION_PROMPT_3,
59 | "4": SUMMARIZATION_PROMPT_4,
60 | "5": SUMMARIZATION_PROMPT_5,
61 | "6": SUMMARIZATION_PROMPT_6,
62 | "7": SUMMARIZATION_PROMPT_7,
63 | "8": SUMMARIZATION_PROMPT_8,
64 | "9": SUMMARIZATION_PROMPT_9,
65 | "10": SUMMARIZATION_PROMPT_10,
66 | "11": SUMMARIZATION_PROMPT_11,
67 | },
68 | "meta_reasoning": {
69 | "0": ENSEMBLE_PROMPT_0,
70 | "1": ENSEMBLE_PROMPT_1,
71 | },
72 | "system": {
73 | "0": SYSTEM_SUPERFORECASTER_0,
74 | },
75 | "data_wrangling": {
76 | "is_bad_title": IS_BAD_TITLE_PROMPT,
77 | "reformat": REFORMAT_PROMPT,
78 | "assign_category": ASSIGN_CATEGORY_PROMPT,
79 | },
80 | "base_eval": {
81 | "knowledge": KNOWLEDGE_PROMPT,
82 |         "gemini_knowledge": GEMINI_KNOWLEDGE_PROMPT,
83 | "zero_shot": ZERO_SHOT_PROMPT,
84 | "scratch_pad": BINARY_SCRATCH_PAD_PROMPT,
85 | "gemini_scratch_pad": GEMINI_BINARY_SCRATCH_PAD_PROMPT,
86 | },
87 | }
88 |
--------------------------------------------------------------------------------
/llm_forecasting/prompts/data_wrangling.py:
--------------------------------------------------------------------------------
1 | ASSIGN_CATEGORY_PROMPT = (
2 | """Question: {question}
3 |
4 | Background: {background}
5 |
6 | Options:
7 | ['Science & Tech',
8 | 'Healthcare & Biology',
9 | 'Economics & Business',
10 | 'Environment & Energy',
11 | 'Politics & Governance',
12 | 'Education & Research',
13 | 'Arts & Recreation',
14 | 'Security & Defense',
15 | 'Social Sciences',
16 | 'Sports',
17 | 'Other']
18 |
19 | Instruction: Assign a category for the given question.
20 |
21 | Rules:
22 | 1. Make sure you only return one of the options from the option list.
23 | 2. Only output the category, and do not output any other words in your response.
24 | 3. You have to pick a string from the above categories.
25 |
26 | Answer:""",
27 | ("QUESTION", "BACKGROUND"),
28 | )
29 |
30 | IS_BAD_TITLE_PROMPT = (
31 | """I'm trying to assess the quality of an old forecasting dataset.
32 |
33 | Here is a forecasting question from the dataset: {question}.
34 |
35 | Please flag questions that don't sound like binary forecasting questions by outputting "flag". If it sounds like a reasonable question, output "ok".
36 |
37 | Examples of strings that should be flagged:
38 | "Will I finish my homework tonight?"
39 | "Metaculus party 2023"
40 | "Will Hell freeze over?"
41 | "Heads or tails"
42 | "Will this video reach 100k views by the EOD?"
43 | Examples of strings that should not be flagged:
44 | "Will Megan Markle and Prince Harry have a baby by the end of the year?"
45 | "Will the Brain Preservation Foundation's Large Mammal preservation prize be won by Feb 9th, 2017?"
46 | "Will there be more novel new drugs approved by the FDA in 2016 than in 2015?"
47 |
48 | If a question is already resolved, that doesn't mean it should be flagged. When in doubt, mark it as "ok".
49 |
50 | Your response should take the following structure:
51 | Insert thinking:
52 | {{ insert your concise thoughts here }}
53 | Classification:
54 | {{ insert "flag" or "ok"}}""",
55 | ("QUESTION", "BACKGROUND"),
56 | )
57 |
58 | REFORMAT_PROMPT = (
59 | """I have questions that need to be transformed for clarity.
60 |
61 | Here are some examples:
62 | Example 1:
63 | Before: Who will win the 2022-2023 Premier League? (Leicester City)
64 | After: *Will Leicester City win the 2022-2023 Premier League?*
65 |
66 | Example 2:
67 | Before: What coalition will govern Berlin after the 2023 repeat state election? (SPD+Greens)
68 | After: *Will SPD+Greens govern Berlin after the 2023 repeat state election?*
69 |
70 | Example 3:
71 | Before: If Republicans win control of the House of Representatives in the 2022 election, who will be the next Majority Whip of the U.S. House of Representatives? (Rep. Jim Banks)
72 | After: *If Republicans win control of the House of Representatives in the 2022 election, will Jim Banks be the next Majority Whip of the U.S. House of Representatives?*
73 |
74 | Example 4:
75 | Before: Economic Trouble: Will a country’s currency depreciate 15% or more in the second half of 2022? (Thai Baht ฿)
76 | After: *Economic Trouble: Will the Thai Baht ฿ currency depreciate 15% or more in the second half of 2022?*
77 |
78 | Example 5:
79 | Before: How many of the claims from study 3 of "Behavioral nudges reduce failure to appear for court" (Science, 2020) replicate?
80 | After: *Will exactly 2 claims from study 3 of "Behavioral nudges reduce failure to appear for court" (Science, 2020) replicate?*
81 | 
82 | Can you now transform this question for clarity: {question}
83 | 
84 | Please place stars around the transformed question.
85 | 
86 | Your output should take the following structure:
87 | Before: {insert the original question}
88 | After: *{insert the transformed question}*""",
89 |     ("QUESTION",),
90 | )
91 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LLM Forecasting
2 |
3 | This repository hosts the code for the paper: [Approaching Human-Level Forecasting with Language Models](https://arxiv.org/abs/2402.18563).
4 |
5 |
6 | ![Forecasting system overview](assets/forecasting_system.png)
7 |
8 |
9 | Our system is designed to make automated, _simulated_ forecasts by following these steps:
10 | 1. **Search Query Generation**: A language model (LM) is prompted to create search queries to retrieve articles published before a certain date from a news API.
11 | 2. **Assess Article Relevancy**: An LM rates the relevance of the retrieved articles and filters out non-relevant ones.
12 | 3. **Summarize Articles**: An LM is prompted to retain the salient information relevant to the question from the filtered articles.
13 | 4. **Reason and Predict**: An LM (base or fine-tuned) is prompted multiple times to produce reasoning and predictions based on the article summaries.
14 | 5. **Forecast Aggregation**: An aggregation method is applied to all the predictions to obtain a final forecast.
15 |
16 | We've designed our system to be easily scalable to other news APIs and language models.
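
As a small, concrete illustration of step 5, the snippet below aggregates a handful of hypothetical step-4 predictions with the same trimmed weighted mean used in `notebooks/results/results.ipynb` (the prediction values are made up):
```
import numpy as np

predictions = np.array([0.55, 0.60, 0.62, 0.90])  # hypothetical per-prompt predictions

# Down-weight the prediction farthest from the median by half, spread the saved
# weight over the rest, then take the weighted mean (mirrors
# calculate_normalized_weighted_trimmed_mean in results.ipynb).
median = np.median(predictions)
distances = np.abs(predictions - median)
weights = np.ones(len(predictions))
weights[distances == distances.max()] *= 0.5
weights[distances != distances.max()] += 0.5 / (len(predictions) - 1)
final_forecast = np.average(predictions, weights=weights)
print(round(final_forecast, 3))
```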
17 |
18 | ## Development Guide
19 | This guide will help you set up your development environment for the project. Make sure you run these steps from the root directory of the project (where **pyproject.toml** is located).
20 |
21 | ### Prerequisites
22 | **Miniconda or Anaconda**:
23 | Ensure you have Miniconda or Anaconda installed as they include the Conda package manager.
24 |
25 | ### Setting Up Your Environment
26 | 1. **Create and Activate a Conda Environment**
27 | Start by creating a Conda environment specifically for this project.
28 | ```
29 | conda create -n myenv python=3.11
30 | conda activate myenv
31 | ```
32 | Replace **myenv** with your preferred environment name and adjust **python=3.11** if needed.
33 |
34 |
35 | 2. **Install Setuptools and Upgrade Build**
36 | Within your Conda environment:
37 | ```
38 | conda install setuptools
39 | pip install --upgrade build
40 | ```
41 |
42 | 3. **Install Package in Editable Mode**
43 | Install the package in editable mode to allow changes in your code to be immediately reflected in the installed package. This also installs all dependencies listed in `pyproject.toml`.
44 | ```
45 | pip install --editable .
46 | ```
47 |
48 | ### Usage
49 | Once the setup is complete, you can start using the package. For example, you can import modules from the `llm_forecasting` package as follows:
50 | ```
51 | import information_retrieval
52 | import utils
53 | from prompts.prompts import PROMPT_DICT
54 | ```
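
For instance, a prompt template from `PROMPT_DICT` can be filled in and sent to a model roughly as follows. This is a sketch, not a prescribed API: it assumes `string_utils.get_prompt` maps each uppercase field name to the matching lowercase keyword argument and that `model_eval.get_response_from_model` takes the arguments shown, as in `llm_forecasting/alignment.py`; the question and article text are made up.
```
from prompts.prompts import PROMPT_DICT
from utils import string_utils
import model_eval

# Fill in the article-relevance template from prompts/relevance.py.
template, fields = PROMPT_DICT["ranking"]["0"]
prompt = string_utils.get_prompt(
    template,
    fields,
    question="Will the S&P 500 close above 5,000 before July 1, 2024?",
    background="The index closed at 4,770 on December 29, 2023.",
    resolution_criteria="Resolves YES on the first daily close above 5,000.",
    article="Stocks rallied on Friday as ...",
)
response = model_eval.get_response_from_model(
    model_name="gpt-3.5-turbo-1106",
    prompt=prompt,
    max_tokens=1000,
    temperature=0,
)
print(response)
```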
55 |
56 | ### Deactivating Your Environment
57 | When you're finished working, you can deactivate the Conda environment:
58 | ```
59 | conda deactivate
60 | ```
61 |
62 | ## Demo
63 | See the [system_demo](https://github.com/dannyallover/llm_forecasting/blob/main/notebooks/demo/system_demo.ipynb) for an example of how to run our system.
64 |
65 | ## Dataset
66 | The cleaned and formatted dataset is available on [huggingface](https://huggingface.co/datasets/YuehHanChen/forecasting), along with the [raw dataset](https://huggingface.co/datasets/YuehHanChen/forecasting_raw).
67 |
68 | ## Contributing
69 | We welcome contributions to this repository. If you'd like to contribute, please follow these steps:
70 |
71 | 1. Fork the repository.
72 | 2. Create a new branch for your feature or bug fix: `git checkout -b my-new-feature`
73 | 3. Make your changes and commit them: `git commit -am 'Add some feature'`
74 | 4. Push your changes to your forked repository: `git push origin my-new-feature`
75 | 5. Open a pull request against the main repository.
76 |
77 | Please ensure that your code adheres to the project's coding conventions and that you include tests for any new functionality or bug fixes.
78 |
79 | ## Citation
80 | If our codebase, dataset, or paper proves fruitful for your work, please consider citing us.
81 | ```
82 | @misc{halawi2024approaching,
83 | title={Approaching Human-Level Forecasting with Language Models},
84 | author={Danny Halawi and Fred Zhang and Chen Yueh-Han and Jacob Steinhardt},
85 | year={2024},
86 | eprint={2402.18563},
87 | archivePrefix={arXiv},
88 | primaryClass={cs.LG}
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/llm_forecasting/utils/api_utils.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import logging
3 | import time
4 |
5 | # Related third-party imports
6 | import requests
7 |
8 | # Set up logging
9 | logging.basicConfig(level=logging.INFO)
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | def request_with_retries(
14 | method, url, headers, params=None, data=None, max_retries=5, delay=30
15 | ):
16 | """
17 | Make an API request (GET or POST) with retries in case of rate-limiting
18 | (HTTP 429) or other defined conditions and return the JSON content or log
19 | an error and return None.
20 |
21 | Args:
22 | method (str): HTTP method ('GET' or 'POST').
23 | url (str): The API endpoint.
24 | headers (dict): Headers for the API request.
25 | params (dict, optional): Parameters for the API request.
26 | data (dict, optional): JSON data for the API request (used for POST).
27 | max_retries (int, optional): Maximum number of retries. Defaults to 5.
28 | delay (int, optional): Delay (in seconds) between retries. Defaults to
29 | 30.
30 |
31 | Returns:
32 | dict or None: The JSON response content as a dictionary or None if an
33 | error occurred.
34 | """
35 | for _ in range(max_retries):
36 | try:
37 | if method == "GET":
38 | response = requests.get(url, headers=headers, params=params)
39 | elif method == "POST":
40 | response = requests.post(url, headers=headers, json=data)
41 | else:
42 | logging.error(f"Unsupported method: {method}")
43 | return None
44 |
45 | if response.status_code == 429:
46 | time.sleep(delay)
47 | continue
48 |
49 | response.raise_for_status()
50 | return response.json()
51 |
52 | except requests.RequestException as e:
53 | logging.error(f"Request error: {e}")
54 | return None
55 |
56 | logging.error(f"Exceeded max retries for URL: {url}")
57 | return None
58 |
59 |
60 | def get_response_content(url, headers, params=None, max_retries=5, delay=30):
61 | """
62 | Create a wrapper function that issues a GET API request, utilizing a
63 | generic retry mechanism.
64 | """
65 | return request_with_retries(
66 | "GET", url, headers, params=params, max_retries=max_retries, delay=delay
67 | )
68 |
69 |
70 | def post_request_with_retries(endpoint, headers, payload, retries=5):
71 | """
72 | Create a wrapper function that makes a POST API request using the generic
73 | retry mechanism.
74 | """
75 | response = request_with_retries(
76 | "POST", endpoint, headers, data=payload, max_retries=retries
77 | )
78 | if (
79 | response
80 | and "detail" in response
81 | and "Expected available in" in response["detail"]
82 | ):
83 | wait_seconds = int(response["detail"].split(" ")[-2]) + 1
84 | time.sleep(wait_seconds)
85 | return request_with_retries(
86 | "POST", endpoint, headers, data=payload, max_retries=retries
87 | )
88 | return response
89 |
90 |
91 | def fetch_all_questions(base_url, headers, params):
92 | """
93 | Fetch all questions from the API using pagination.
94 |
95 | Args:
96 | - base_url (str): The base URL of the API.
97 | - headers (dict): The headers to use for the requests.
98 | - params (dict): The parameters to use for the requests.
99 |
100 | Returns:
101 | - list: List of all questions fetched from the API.
102 | """
103 | all_questions = []
104 |
105 | current_url = base_url
106 | while current_url:
107 | logging.info(f"Fetching data from {current_url} with params: {params}")
108 |
109 | data = get_response_content(current_url, headers, params)
110 | if not data:
111 | break
112 |
113 | if "results" in data:
114 | all_questions.extend(data["results"])
115 | else:
116 | all_questions.append(data)
117 |
118 | current_url = data.get("next")
119 |
120 | return all_questions
121 |
--------------------------------------------------------------------------------
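A short usage sketch of `fetch_all_questions` above (the endpoint, headers, and parameters are placeholders, not a real API):

```
from utils import api_utils

headers = {"Authorization": "Token YOUR_API_KEY"}
params = {"status": "resolved", "limit": 100}

# Follows the paginated "next" links until the API stops returning one.
questions = api_utils.fetch_all_questions(
    "https://example.com/api/questions", headers, params
)
print(len(questions))
```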
/llm_forecasting/utils/metrics_utils.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import logging
3 |
4 | # Related third-party imports
5 | import numpy as np
6 | from numpy.linalg import norm
7 | import torch
8 |
9 | # Local application/library-specific imports
10 | from utils import time_utils
11 |
12 | # Set up logging
13 | logging.basicConfig(level=logging.INFO)
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def brier_score(probabilities, answer_idx):
18 | """
19 | Calculate the Brier score for a set of probabilities and the correct answer
20 | index.
21 |
22 | Args:
23 | - probabilities (numpy array): The predicted probabilities for each class.
24 | - answer_idx (int): Index of the correct answer.
25 |
26 | Returns:
27 | - float: The Brier score.
28 | """
29 | answer = np.zeros_like(probabilities)
30 | answer[answer_idx] = 1
31 | return ((probabilities - answer) ** 2).sum() / 2
32 |
33 |
34 | def calculate_cosine_similarity_bert(text_list, tokenizer, model):
35 | """
36 | Calculate the average cosine similarity between texts in a given list using
37 | embeddings.
38 |
39 | Define Bert outside the function:
40 | from transformers import BertTokenizer, BertModel
41 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
42 | model = BertModel.from_pretrained('bert-base-uncased')
43 |
44 | Parameters:
45 | text_list (List[str]): A list of strings where each string is a text
46 | document.
47 |
48 | Returns:
49 | float: The average cosine similarity between each pair of texts in the list.
50 | Returns 0 if the list contains less than two text documents.
51 | """
52 | if len(text_list) < 2:
53 | return 0
54 |
55 | # Function to get embeddings
56 | def get_embedding(text):
57 | inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
58 | outputs = model(**inputs)
59 | return torch.mean(outputs.last_hidden_state, dim=1)
60 |
61 | # Generating embeddings for each text
62 | embeddings = [get_embedding(text) for text in text_list]
63 |
64 | # Calculating cosine similarity between each pair of embeddings
65 | similarity_scores = []
66 | for i in range(len(embeddings)):
67 | for j in range(i + 1, len(embeddings)):
68 | similarity = torch.nn.functional.cosine_similarity(
69 | embeddings[i], embeddings[j]
70 | )
71 | similarity_scores.append(similarity.item())
72 |
73 | # Calculating average similarity
74 | average_similarity = np.mean(similarity_scores)
75 |
76 | return average_similarity
77 |
78 |
79 | def cosine_similarity(u, v):
80 | """
81 | Compute the cosine similarity between two vectors.
82 | """
83 | return np.dot(u, v) / (norm(u) * norm(v))
84 |
85 |
86 | def get_average_forecast(date_pred_list):
87 | """
88 | Retrieve the average forecast value from the list of predictions.
89 |
90 | Args:
91 |     - date_pred_list (list of tuples): list containing tuples of (date str, pred).
92 |
93 | Returns:
94 | - float: The average prediction.
95 | """
96 | if not date_pred_list or len(date_pred_list) == 0:
97 | return 0.5 # Return a default value of 0.5 if there is no history
98 | return sum(tup[1] for tup in date_pred_list) / len(date_pred_list)
99 |
100 |
101 | def compute_bs_and_crowd_bs(pred, date_pred_list, retrieve_date, answer):
102 | """
103 | Computes Brier scores for individual prediction and community prediction.
104 |
105 | Parameters:
106 | - pred (float): The individual's probability prediction for an event.
107 | - date_pred_list (list of tuples): A list of tuples containing dates
108 | and community predictions. Each tuple is in the format (date, prediction).
109 | - retrieve_date (date): The date for which the community prediction is to be retrieved.
110 | - answer (int): The actual outcome of the event, where 0 indicates the event
111 | did not happen, and 1 indicates it did.
112 |
113 | Returns:
114 | - bs (float): The Brier score for the individual prediction.
115 | - bs_comm (float): The Brier score for the community prediction closest to the specified retrieve_date.
116 | """
117 | pred_comm = time_utils.find_closest_date(retrieve_date, date_pred_list)[-1]
118 | bs = brier_score([1 - pred, pred], answer)
119 | bs_comm = brier_score([1 - pred_comm, pred_comm], answer)
120 |
121 | return bs, bs_comm
122 |
--------------------------------------------------------------------------------
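A quick hand-check of `brier_score` above for the binary case (values computed by hand):

```
import numpy as np
from utils import metrics_utils

# Forecast of 0.7 for YES when the outcome is YES (answer index 1):
# ((0.3 - 0)^2 + (0.7 - 1)^2) / 2 = (0.09 + 0.09) / 2 = 0.09
print(metrics_utils.brier_score(np.array([0.3, 0.7]), 1))  # 0.09
```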
/llm_forecasting/utils/db_utils.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | from io import StringIO
3 | import logging
4 | import os
5 | import pickle
6 | import random
7 | import string
8 |
9 | # Related third-party imports
10 | import boto3
11 | import pandas as pd
12 |
13 | # Set up logging
14 | logging.basicConfig(level=logging.INFO)
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def initialize_s3_client(aws_access_key_id, aws_secret_access_key):
19 | """
20 | Initialize an Amazon S3 client using provided AWS credentials.
21 |
22 | Args:
23 | - aws_access_key_id (str): AWS access key for authentication.
24 | - aws_secret_access_key (str): AWS secret access key for authentication.
25 |
26 | Returns:
27 | - boto3.client: Initialized S3 client.
28 | """
29 | return boto3.client(
30 | "s3",
31 | aws_access_key_id=aws_access_key_id,
32 | aws_secret_access_key=aws_secret_access_key,
33 | )
34 |
35 |
36 | def upload_data_structure_to_s3(s3, data_structure, bucket, s3_path):
37 | """
38 |     Serialize a data structure and upload it to a specified path in an Amazon S3 bucket.
39 |
40 | Args:
41 | s3 (boto3.client): An initialized S3 client instance.
42 |         data_structure (object): Data structure to save.
43 | bucket (str): Name of the target S3 bucket.
44 | s3_path (str): Desired filename within the S3 bucket.
45 | """
46 | try:
47 | extension = s3_path.split(".")[-1]
48 | hash = "".join(random.choices(string.ascii_letters + string.digits, k=5))
49 | temp_file = f"temp{hash}.{extension}"
50 | with open(temp_file, "wb") as f:
51 | pickle.dump(data_structure, f)
52 |
53 | s3.upload_file(temp_file, bucket, s3_path)
54 | os.remove(temp_file)
55 | logging.info(f"Successfully uploaded data to {bucket}/{s3_path}")
56 | except Exception as e:
57 | logging.error(f"Error uploading data to {bucket}/{s3_path}. Error: {e}")
58 |
59 |
60 | def upload_file_to_s3(s3, local_file, bucket, s3_path):
61 | """
62 | Upload a local file to a specified path in an Amazon S3 bucket.
63 |
64 | Args:
65 | - s3 (boto3.client): An initialized S3 client instance.
66 | - local_file (str): Path of the local file to upload.
67 | - bucket (str): Name of the target S3 bucket.
68 | - s3_path (str): Desired path within the S3 bucket.
69 |
70 | """
71 | try:
72 | s3.upload_file(local_file, bucket, s3_path)
73 | logging.info(f"Successfully uploaded {local_file} to {bucket}/{s3_path}")
74 | except Exception as e:
75 | logging.error(f"Error uploading {local_file} to {bucket}/{s3_path}. Error: {e}")
76 |
77 |
78 | def read_pickle_from_s3(s3, bucket, s3_path):
79 | """
80 | Fetch and deserialize a pickle file from an S3 bucket.
81 |
82 | This can be used in conjunction with `upload_data_structure_to_s3` to
83 | upload and download arbitrary data structures to/from S3.
84 |
85 | Args:
86 | - s3 (boto3.client): An initialized S3 client.
87 | - bucket (str): Name of the S3 bucket containing the pickle file.
88 | - s3_path (str): Path of the pickle file within the S3 bucket.
89 |
90 | Returns:
91 | - object: Deserialized object from the pickle file.
92 | """
93 | obj = s3.get_object(Bucket=bucket, Key=s3_path)
94 | data = pickle.loads(obj["Body"].read())
95 | return data
96 |
97 |
98 | def read_pickle_files_from_s3_folder(s3, bucket, s3_folder_path):
99 | # List objects in the specified S3 folder
100 | objects = s3.list_objects(Bucket=bucket, Prefix=s3_folder_path)
101 |
102 | pickle_files = []
103 | # Loop through the objects and filter for pickle files
104 | for obj in objects.get("Contents", []):
105 | key = obj["Key"]
106 | if key.endswith(".pickle"):
107 | # Download the pickle file
108 | response = s3.get_object(Bucket=bucket, Key=key)
109 | # Read the pickle file from the response
110 | pickle_data = pickle.loads(response["Body"].read())
111 | pickle_files.append(pickle_data)
112 |
113 | return pickle_files
114 |
115 |
116 | def read_csv_from_s3(s3, bucket, s3_path):
117 | """
118 | Fetch and read a CSV file from an S3 bucket into a pandas DataFrame.
119 |
120 | Args:
121 | - s3 (boto3.client): An initialized S3 client.
122 | - bucket (str): Name of the S3 bucket containing the CSV file.
123 | - s3_path (str): Path of the CSV file within the S3 bucket.
124 |
125 | Returns:
126 | - pd.DataFrame: DataFrame populated with the CSV data.
127 | """
128 | obj = s3.get_object(Bucket=bucket, Key=s3_path)
129 | df = pd.read_csv(StringIO(obj["Body"].read().decode("utf-8")))
130 | return df
131 |
--------------------------------------------------------------------------------
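A small round-trip sketch through the helpers above (the credentials, bucket, and path are placeholders; elsewhere in the codebase they come from `config/keys.py` and `config/constants.py`):

```
from utils import db_utils

s3 = db_utils.initialize_s3_client("YOUR_ACCESS_KEY", "YOUR_SECRET_KEY")
data = {"question": "Will X happen?", "prediction": 0.42}

db_utils.upload_data_structure_to_s3(
    s3, data, bucket="my-bucket", s3_path="forecasts/example.pickle"
)
restored = db_utils.read_pickle_from_s3(s3, "my-bucket", "forecasts/example.pickle")
assert restored == data
```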
/notebooks/results/results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "aab054a6",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import pickle\n",
11 | "import numpy as np"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 2,
17 | "id": "340d2fa4",
18 | "metadata": {},
19 | "outputs": [],
20 | "source": [
21 | "def calculate_normalized_weighted_trimmed_mean(predictions):\n",
22 | " # Step 1: Find the median\n",
23 | " median_prediction = np.median(predictions)\n",
24 | "\n",
25 | " # Step 2: Determine the prediction farthest from the median\n",
26 | " distances = np.abs(predictions - median_prediction)\n",
27 | " max_distance = np.max(distances)\n",
28 | "\n",
29 | " # Step 3: Down-weight the furthest prediction by half\n",
30 | " weights = np.ones(len(predictions))\n",
31 | " weights[distances == max_distance] *= 0.5\n",
32 | "\n",
33 | " # Step 4: Distribute the saved weight among other predictions\n",
34 | " saved_weight = (1.0 - 0.5) / (len(predictions) - 1)\n",
35 | " weights[distances != max_distance] += saved_weight\n",
36 | "\n",
37 | " # Step 5: Calculate the weighted mean\n",
38 | " weighted_mean = np.average(predictions, weights=weights)\n",
39 | "\n",
40 | " return weighted_mean"
41 | ]
42 | },
43 | {
44 | "cell_type": "code",
45 | "execution_count": 3,
46 | "id": "dc954c5e",
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "with open(\"data/answers.pickle\", \"rb\") as file:\n",
51 | " answers = pickle.load(file)\n",
52 | "with open(\"data/base_predictions.pickle\", \"rb\") as file:\n",
53 | " base_predictions = pickle.load(file)\n",
54 | "with open(\"data/finetuned_predictions.pickle\", \"rb\") as file:\n",
55 | " finetuned_predictions = pickle.load(file)\n",
56 | "with open(\"data/finetuned_other_predictions.pickle\", \"rb\") as file:\n",
57 | " finetuned_other_predictions = pickle.load(file)\n",
58 | "with open(\"data/crowd_predictions.pickle\", \"rb\") as file:\n",
59 | " community_predictions = pickle.load(file)"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": 4,
65 | "id": "be79c2e2",
66 | "metadata": {},
67 | "outputs": [
68 | {
69 | "name": "stdout",
70 | "output_type": "stream",
71 | "text": [
72 | "Base Brier Score: 0.1863574497732625\n",
73 | "Finetuned Brier Score: 0.18005945446102836\n",
74 | "Finetuned and Base Brier Score: 0.17988713531620448\n",
75 | "Crowd Brier Score: 0.1486199294280867\n"
76 | ]
77 | }
78 | ],
79 | "source": [
80 | "base_brier_score = 0\n",
81 | "finetuned_brier_score = 0\n",
82 | "finetuned_and_base_brier_score = 0\n",
83 | "crowd_brier_score = 0\n",
84 | "n = 0\n",
85 | "for i in range(5): # num retrieval dates\n",
86 | " for j in range(len(finetuned_predictions[i])):\n",
87 | " answer = answers[i][j]\n",
88 | " \n",
89 | " base_preds = base_predictions[i][j]\n",
90 | " finetuned_preds = finetuned_predictions[i][j]\n",
91 | " finetuned_other_preds = finetuned_other_predictions[i][j]\n",
92 | " crowd_pred = community_predictions[i][j]\n",
93 | " \n",
94 | " base_pred = calculate_normalized_weighted_trimmed_mean(base_preds)\n",
95 | " finetuned_pred = np.mean(finetuned_preds + finetuned_other_preds)\n",
96 | " finetuned_and_base_pred = calculate_normalized_weighted_trimmed_mean(base_preds + finetuned_preds + finetuned_other_preds)\n",
97 | " \n",
98 | " base_brier_score += (base_pred - answer) ** 2\n",
99 | " finetuned_brier_score += (finetuned_pred - answer) ** 2\n",
100 | " finetuned_and_base_brier_score += (finetuned_and_base_pred - answer) ** 2\n",
101 | " crowd_brier_score += (crowd_pred - answer) ** 2\n",
102 | " n += 1\n",
103 | "\n",
104 | "print(\"Base Brier Score:\", base_brier_score/n)\n",
105 | "print(\"Finetuned Brier Score:\", finetuned_brier_score/n)\n",
106 | "print(\"Finetuned and Base Brier Score:\", finetuned_and_base_brier_score/n)\n",
107 | "print(\"Crowd Brier Score:\", crowd_brier_score/n)"
108 | ]
109 | }
110 | ],
111 | "metadata": {
112 | "kernelspec": {
113 | "display_name": "Python 3 (ipykernel)",
114 | "language": "python",
115 | "name": "python3"
116 | },
117 | "language_info": {
118 | "codemirror_mode": {
119 | "name": "ipython",
120 | "version": 3
121 | },
122 | "file_extension": ".py",
123 | "mimetype": "text/x-python",
124 | "name": "python",
125 | "nbconvert_exporter": "python",
126 | "pygments_lexer": "ipython3",
127 | "version": "3.11.0"
128 | }
129 | },
130 | "nbformat": 4,
131 | "nbformat_minor": 5
132 | }
133 |
--------------------------------------------------------------------------------
/scripts/data_scraping/cset.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import argparse
3 | import datetime
4 | import json
5 | import logging
6 | import os
7 | import random
8 | import time
9 |
10 | # Related third-party imports
11 | import jsonlines
12 | from selenium.webdriver.common.by import By
13 |
14 | # Local application/library-specific imports
15 | import data_scraping
16 | from config.keys import keys
17 |
18 | logger = logging.getLogger(__name__)
19 | MAX_CONSECUTIVE_NOT_FOUND = 1000
20 | # Writing to file for debugging purposes. It will be deleted once the script is done.
21 | FILE_PATH = "cset_dump.jsonl"
22 | # Use your own executable_path (download from https://chromedriver.chromium.org/).
23 | CHROMEDRIVER_PATH = "/Users/apple/Downloads/chromedriver-mac-x64/chromedriver"
24 | CSET_EMAIL = keys["EMAIL_JOHN"]
25 | CSET_PASSWORD = keys["GJOPEN_CSET_PASSWORD_JOHN"]
26 |
27 |
28 | def main(n_days):
29 | """
30 | Scrape, process, and upload question data from CSET (https://www.infer-pub.com/)
31 |
32 | Args:
33 | n_days (int): Number of days to look back for questions.
34 |
35 | Returns:
36 |         None. The processed questions are uploaded to S3 rather than returned.
37 | """
38 | driver = data_scraping.initialize_and_login(
39 | signin_page="https://www.infer-pub.com/users/sign_in",
40 | email=CSET_EMAIL,
41 | password=CSET_PASSWORD,
42 | executable_path=CHROMEDRIVER_PATH,
43 | )
44 |
45 | question_counter = 0
46 | consecutive_not_found_or_skipped = 0
47 | while True:
48 | question_counter += 1
49 | url = f"https://www.infer-pub.com/questions/{question_counter}"
50 |
51 | try:
52 | driver.get(url)
53 | trend_graph_element = driver.find_element(
54 | By.CSS_SELECTOR,
55 | "div[data-react-class='FOF.Forecast.QuestionTrendGraph']",
56 | )
57 | props = json.loads(trend_graph_element.get_attribute("data-react-props"))
58 | props["extracted_articles_urls"] = data_scraping.get_source_links(
59 | driver, url
60 | )
61 |
62 | with jsonlines.open(FILE_PATH, mode="a") as writer:
63 | writer.write(props)
64 | consecutive_not_found_or_skipped = 0
65 | except BaseException:
66 | if data_scraping.question_not_found(driver):
67 | logger.info(f"Question {question_counter} not found")
68 | else:
69 | logger.info(f"Skipping question {question_counter}")
70 | consecutive_not_found_or_skipped += 1
71 | if consecutive_not_found_or_skipped > MAX_CONSECUTIVE_NOT_FOUND:
72 | logger.info("Reached maximum consecutive not found.")
73 | break
74 |
75 | time.sleep(random.uniform(0, 2)) # random delay between requests
76 |
77 | data = []
78 | with open(FILE_PATH, "r") as file:
79 | for line in file:
80 | json_line = json.loads(line)
81 | data.append(json_line)
82 |
83 | # Remove duplicated dicts
84 | unique_tuples = {data_scraping.make_hashable(d) for d in data}
85 | all_questions = [data_scraping.unhashable_to_dict(t) for t in unique_tuples]
86 |
87 | if n_days is not None:
88 | date_limit = datetime.datetime.now() - datetime.timedelta(days=n_days)
89 | all_questions = [
90 | q
91 | for q in all_questions
92 | if datetime.datetime.fromisoformat(q["question"]["created_at"][:-1])
93 | >= date_limit
94 | ]
95 |
96 | logger.info(f"Number of cset questions fetched: {len(all_questions)}")
97 |
98 | for i in range(len(all_questions)):
99 | try:
100 | all_questions[i]["community_prediction"] = all_questions[i].pop(
101 | "trend_graph_probabilities"
102 | )
103 | except BaseException:
104 | all_questions[i]["community_prediction"] = []
105 | all_questions[i]["url"] = "https://www.infer-pub.com/questions/" + str(
106 | all_questions[i]["question"]["id"]
107 | )
108 | all_questions[i]["title"] = all_questions[i]["question"]["name"]
109 | all_questions[i]["close_time"] = all_questions[i]["question"].pop("closed_at")
110 | all_questions[i]["created_time"] = all_questions[i]["question"].pop(
111 | "created_at"
112 | )
113 | all_questions[i]["background"] = all_questions[i]["question"]["description"]
114 | all_questions[i]["data_source"] = "cset"
115 |
116 | if all_questions[i]["question"]["state"] != "resolved":
117 | all_questions[i]["resolution"] = "Not resolved."
118 | all_questions[i]["is_resolved"] = False
119 | else:
120 | all_questions[i]["resolution"] = all_questions[i]["question"]["answers"][0][
121 | "probability"
122 | ]
123 | all_questions[i]["is_resolved"] = True
124 |
125 | all_questions[i]["question_type"] = "binary"
126 | answer_set = set(
127 | [answer["name"] for answer in all_questions[i]["question"]["answers"]]
128 | )
129 | if answer_set == {"Yes", "No"} or answer_set == {"Yes"}:
130 | all_questions[i]["question_type"] = "binary"
131 | else:
132 | all_questions[i]["question_type"] = "multiple_choice"
133 |
134 | logger.info("Uploading to s3...")
135 | question_types = ["binary", "multiple_choice"]
136 | data_scraping.upload_scraped_data(all_questions, "cset", question_types, n_days)
137 |
138 | # Delete the file after script completion
139 | if os.path.exists(FILE_PATH):
140 | os.remove(FILE_PATH)
141 | logger.info(f"Deleted the file: {FILE_PATH}")
142 | else:
143 | logger.info(f"The file {FILE_PATH} does not exist")
144 |
145 |
146 | if __name__ == "__main__":
147 | parser = argparse.ArgumentParser(description="Fetch cset data.")
148 | parser.add_argument(
149 | "--n_days",
150 | type=int,
151 | help="Fetch markets created in the last N days",
152 | default=None,
153 | )
154 | args = parser.parse_args()
155 | main(args.n_days)
156 |
--------------------------------------------------------------------------------
/scripts/data_scraping/gjopen.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import argparse
3 | import datetime
4 | import json
5 | import logging
6 | import os
7 | import random
8 | import time
9 |
10 | # Related third-party imports
11 | import jsonlines
12 | from selenium.webdriver.common.by import By
13 |
14 | # Local application/library-specific imports
15 | import data_scraping
16 | from config.keys import keys
17 |
18 | logger = logging.getLogger(__name__)
19 | MAX_CONSECUTIVE_NOT_FOUND = 1000
20 | # Writing to file for debugging purposes. It will be deleted once the script is done.
21 | FILE_PATH = "gjopen_dump.jsonl"
22 | # Use your own executable_path (download from https://chromedriver.chromium.org/).
23 | CHROMEDRIVER_PATH = "/Users/apple/Downloads/chromedriver-mac-x64/chromedriver"
24 | GJOPEN_EMAIL = keys["EMAIL"]
25 | GJOPEN_PASSWORD = keys["GJOPEN_CSET_PASSWORD"]
26 |
27 |
28 | def main(n_days):
29 | """
30 | Scrape, process, and upload question data from gjopen (https://www.gjopen.com/)
31 |
32 | Args:
33 | n_days (int): Number of days to look back for questions.
34 |
35 | Returns:
36 |         None. The processed questions are uploaded to S3 rather than returned.
37 | """
38 | driver = data_scraping.initialize_and_login(
39 | signin_page="https://www.gjopen.com/users/sign_in",
40 | email=GJOPEN_EMAIL,
41 | password=GJOPEN_PASSWORD,
42 | executable_path=CHROMEDRIVER_PATH,
43 | )
44 |
45 | question_counter = 0
46 | consecutive_not_found_or_skipped = 0
47 | while True:
48 | question_counter += 1
49 | url = f"https://www.gjopen.com/questions/{question_counter}"
50 |
51 | try:
52 | driver.get(url)
53 | trend_graph_element = driver.find_element(
54 | By.CSS_SELECTOR,
55 | "div[data-react-class='FOF.Forecast.QuestionTrendGraph']",
56 | )
57 | props = json.loads(trend_graph_element.get_attribute("data-react-props"))
58 | props["extracted_articles_urls"] = data_scraping.get_source_links(
59 | driver, url
60 | )
61 |
62 | with jsonlines.open(FILE_PATH, mode="a") as writer:
63 | writer.write(props)
64 | consecutive_not_found_or_skipped = 0
65 | except BaseException:
66 | if data_scraping.question_not_found(driver):
67 | logger.info(f"Question {question_counter} not found")
68 | else:
69 | logger.info(f"Skipping question {question_counter}")
70 | consecutive_not_found_or_skipped += 1
71 | if consecutive_not_found_or_skipped > MAX_CONSECUTIVE_NOT_FOUND:
72 | logger.info("Reached maximum consecutive not found.")
73 | break
74 |
75 | time.sleep(random.uniform(0, 2)) # random delay between requests
76 |
77 | data = []
78 | with open(FILE_PATH, "r") as file:
79 | for line in file:
80 | json_line = json.loads(line)
81 | data.append(json_line)
82 |
83 | # Remove duplicated dicts
84 | unique_tuples = {data_scraping.make_hashable(d) for d in data}
85 | all_questions = [data_scraping.unhashable_to_dict(t) for t in unique_tuples]
86 |
87 | if n_days is not None:
88 | date_limit = datetime.datetime.now() - datetime.timedelta(days=n_days)
89 | all_questions = [
90 | q
91 | for q in all_questions
92 | if datetime.datetime.fromisoformat(q["question"]["created_at"][:-1])
93 | >= date_limit
94 | ]
95 |
96 | logger.info(f"Number of gjopen questions fetched: {len(all_questions)}")
97 |
98 | for i in range(len(all_questions)):
99 | try:
100 | all_questions[i]["community_prediction"] = all_questions[i].pop(
101 | "trend_graph_probabilities"
102 | )
103 | except BaseException:
104 | all_questions[i]["community_prediction"] = []
105 | all_questions[i]["url"] = "https://www.gjopen.com/questions/" + str(
106 | all_questions[i]["question"]["id"]
107 | )
108 | all_questions[i]["title"] = all_questions[i]["question"]["name"]
109 | all_questions[i]["close_time"] = all_questions[i]["question"].pop("closed_at")
110 | all_questions[i]["created_time"] = all_questions[i]["question"].pop(
111 | "created_at"
112 | )
113 | all_questions[i]["background"] = all_questions[i]["question"]["description"]
114 | all_questions[i]["data_source"] = "gjopen"
115 |
116 | if all_questions[i]["question"]["state"] != "resolved":
117 | all_questions[i]["resolution"] = "Not resolved."
118 | all_questions[i]["is_resolved"] = False
119 | else:
120 | all_questions[i]["resolution"] = all_questions[i]["question"]["answers"][0][
121 | "probability"
122 | ]
123 | all_questions[i]["is_resolved"] = True
124 |
125 | all_questions[i]["question_type"] = "binary"
126 | answer_set = set(
127 | [answer["name"] for answer in all_questions[i]["question"]["answers"]]
128 | )
129 | if answer_set == {"Yes", "No"} or answer_set == {"Yes"}:
130 | all_questions[i]["question_type"] = "binary"
131 | else:
132 | all_questions[i]["question_type"] = "multiple_choice"
133 |
134 | logger.info("Uploading to s3...")
135 | question_types = ["binary", "multiple_choice"]
136 | data_scraping.upload_scraped_data(all_questions, "gjopen", question_types, n_days)
137 |
138 | # Delete the file after script completion
139 | if os.path.exists(FILE_PATH):
140 | os.remove(FILE_PATH)
141 | logger.info(f"Deleted the file: {FILE_PATH}")
142 | else:
143 | logger.info(f"The file {FILE_PATH} does not exist")
144 |
145 |
146 | if __name__ == "__main__":
147 | parser = argparse.ArgumentParser(description="Fetch gjopen data.")
148 | parser.add_argument(
149 | "--n_days",
150 | type=int,
151 | help="Fetch markets created in the last N days",
152 | default=None,
153 | )
154 | args = parser.parse_args()
155 | main(args.n_days)
156 |
--------------------------------------------------------------------------------
/llm_forecasting/prompts/summarization.py:
--------------------------------------------------------------------------------
1 | SUMMARIZATION_PROMPT_0 = (
2 | """Summarize the article below, ensuring to include details pertinent to the subsequent question.
3 |
4 | Question: {question}
5 | Question Background: {background}
6 |
7 | Article:
8 | ---
9 | {article}
10 | ---""",
11 | ("QUESTION", "BACKGROUND"),
12 | )
13 |
14 | SUMMARIZATION_PROMPT_1 = (
15 | """I will present a forecasting question and a related article.
16 |
17 | Question: {question}
18 | Question Background: {background}
19 |
20 | Article:
21 | ---
22 | {article}
23 | ---
24 |
25 | A forecaster prefers a list of bullet points containing facts, observations, details, analysis, etc., over reading a full article.
26 |
27 | Your task is to distill the article as a list of bullet points that would help a forecaster in their deliberation.""",
28 | ("QUESTION", "BACKGROUND"),
29 | )
30 |
31 | SUMMARIZATION_PROMPT_2 = (
32 | """I want to make the following article shorter (condense it to no more than 500 words).
33 |
34 | Article:
35 | ---
36 | {article}
37 | ---
38 |
39 | When doing this task for me, please do not remove any details that would be helpful for making considerations about the following forecasting question.
40 |
41 | Forecasting Question: {question}
42 | Question Background: {background}""",
43 | ("QUESTION", "BACKGROUND"),
44 | )
45 |
46 | SUMMARIZATION_PROMPT_3 = (
47 | """I will present a forecasting question and a related article.
48 |
49 | Forecasting Question: {question}
50 | Question Background: {background}
51 |
52 | Article:
53 | ---
54 | {article}
55 | ---
56 |
57 | Use the article to write a list of bullet points that help a forecaster in their deliberation.
58 |
59 | Guidelines:
60 | - Ensure each bullet point contains specific, detailed information.
61 | - Avoid vague statements; instead, focus on summarizing key observations, data, predictions, or analysis presented in the article.
62 | - Also, extract points that directly or indirectly contribute to a better understanding or prediction of the specified question.""",
63 | ("QUESTION", "BACKGROUND"),
64 | )
65 |
66 |
67 | SUMMARIZATION_PROMPT_4 = (
68 | """Summarize the below article.
69 |
70 | Article:
71 | ---
72 | {article}
73 | ---""",
74 | ("",),
75 | )
76 |
77 |
78 | SUMMARIZATION_PROMPT_5 = (
79 | """I will present a forecasting question and a related article.
80 |
81 | Question: {question}
82 | Question Background: {background}
83 |
84 | Article:
85 | ---
86 | {article}
87 | ---
88 |
89 | I want to shorten the article above (condense it to no more than 500 words). When doing this task for me, please do not remove any details that would be helpful for making considerations about the forecasting question.""",
90 | ("QUESTION", "BACKGROUND"),
91 | )
92 |
93 |
94 | SUMMARIZATION_PROMPT_6 = (
95 | """Create a summary of the following article that assists in making a prediction for the following question.
96 |
97 | Forecasting Question: {question}
98 | Question Background: {background}
99 |
100 | Article:
101 | ---
102 | {article}
103 | ---
104 |
105 | Guidelines for Summary:
106 | - Include bullet points that extract key facts, observations, and analyses directly relevant to the forecasting question.
107 | - Then include analysis that connects the article's content to the forecasting question.
108 | - Strive for a balance between brevity and completeness, aiming for a summary that is informative yet efficient for a forecaster's analysis.""",
109 | ("QUESTION", "BACKGROUND"),
110 | )
111 |
112 |
113 | SUMMARIZATION_PROMPT_7 = (
114 | """Create a summary of the following article that assists in making a prediction for the following question.
115 |
116 | Prediction Question: {question}
117 | Question Background: {background}
118 |
119 | Article:
120 | ---
121 | {article}
122 | ---
123 |
124 | Guidelines for Summary:
125 | - Include bullet points that extract key facts, observations, and analyses directly relevant to the forecasting question.
126 | - Where applicable, highlight direct or indirect connections between the article's content and the forecasting question.
127 | - Strive for a balance between brevity and completeness, aiming for a summary that is informative yet efficient for a forecaster's analysis.""",
128 | ("QUESTION", "BACKGROUND"),
129 | )
130 |
131 |
132 | SUMMARIZATION_PROMPT_8 = (
133 | """I want to make the following article shorter (condense it to no more than 100 words).
134 |
135 | Article:
136 | ---
137 | {article}
138 | ---
139 |
140 | When doing this task for me, please do not remove any details that would be helpful for making considerations about the following forecasting question.
141 |
142 | Forecasting Question: {question}
143 | Question Background: {background}""",
144 | ("QUESTION", "BACKGROUND"),
145 | )
146 |
147 | SUMMARIZATION_PROMPT_9 = (
148 | """I want to make the following article shorter (condense it to no more than 100 words).
149 |
150 | Article:
151 | ---
152 | {article}
153 | ---
154 |
155 | When doing this task for me, please do not remove any details that would be helpful for making considerations about the following forecasting question.
156 |
157 | Forecasting Question: {question}
158 | Question Background: {background}""",
159 | ("QUESTION", "BACKGROUND"),
160 | )
161 |
162 | SUMMARIZATION_PROMPT_10 = (
163 | """I will present a forecasting question and a related article.
164 |
165 | Forecasting Question: {question}
166 | Question Background: {background}
167 |
168 | Article:
169 | ---
170 | {article}
171 | ---
172 |
173 | Use the article to write a list of bullet points that help a forecaster in their deliberation.
174 |
175 | Guidelines:
176 | - Ensure each bullet point contains specific, detailed information.
177 | - Avoid vague statements; instead, focus on summarizing key observations, data, predictions, or analysis presented in the article.
178 | - Your list should never exceed 5 bullet points.""",
179 | ("QUESTION", "BACKGROUND"),
180 | )
181 |
182 |
183 | SUMMARIZATION_PROMPT_11 = (
184 | """I will present a forecasting question and a related article.
185 |
186 | Question: {question}
187 | Question Background: {background}
188 |
189 | Article:
190 | ---
191 | {article}
192 | ---
193 |
194 | A forecaster prefers a list of bullet points containing facts, observations, details, analysis, etc., over reading a full article.
195 |
196 | Your task is to distill the article as a list of bullet points that would help a forecaster in their deliberation. Ensure that you make your list as concise as possible without removing critical information.""",
197 | ("QUESTION", "BACKGROUND"),
198 | )
199 |
--------------------------------------------------------------------------------
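
Each prompt in the module above is a (template, required-fields) tuple. A minimal sketch of filling one in, assuming the llm_forecasting directory is on the import path and using placeholder field values:

from prompts.summarization import SUMMARIZATION_PROMPT_3

template, required_fields = SUMMARIZATION_PROMPT_3  # fields: ("QUESTION", "BACKGROUND")
filled_prompt = template.format(
    question="Will event X occur before 2025-01-01?",        # placeholder question
    background="One paragraph of placeholder background.",   # placeholder background
    article="Full placeholder article text.",                # placeholder article body
)
print(filled_prompt[:200])
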
/llm_forecasting/data_scraping.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import requests
3 | import logging
4 |
5 | from selenium import webdriver
6 | from selenium.webdriver.chrome.options import Options
7 | from selenium.webdriver.chrome.service import Service
8 | from selenium.webdriver.common.by import By
9 | from selenium.webdriver.support import expected_conditions as EC
10 | from selenium.webdriver.support.ui import WebDriverWait
11 |
12 | from config.constants import S3, S3_BUCKET_NAME
13 | from utils import db_utils
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def upload_scraped_data(data, source, question_types, n_days_or_not=None):
19 | """
20 | Upload data (scraped by the script) by type to S3, partitioned by
21 | the specified source and date range if specified.
22 |
23 | Args:
24 | data (list): List of question data to process.
25 | source (str): The source identifier for categorizing uploaded data.
26 | question_types (list): List of question types to filter and upload.
27 | n_days_or_not (int or None): Number of days to look back for questions,
28 | or None for all available data.
29 | """
30 | today_date = datetime.datetime.now().strftime("%Y_%m_%d")
31 |
32 | for q_type in question_types:
33 | questions = [q for q in data if q["question_type"] == q_type]
34 | q_type = q_type.lower()
35 | if n_days_or_not:
36 | file_name = f"n_days_updated_{q_type}_questions_{today_date}.pickle"
37 | s3_path = f"{source}/n_days_updated_{q_type}_questions/{file_name}"
38 | else:
39 | file_name = f"updated_{q_type}_questions_{today_date}.pickle"
40 | s3_path = f"{source}/updated_{q_type}_questions/{file_name}"
41 |
42 | db_utils.upload_data_structure_to_s3(S3, questions, S3_BUCKET_NAME, s3_path)
43 |
44 |
45 | def fetch_question_description(headers, question_id):
46 | """
47 | Fetch and return the description of a specific question from Metaculus.
48 |
49 | The function queries the Metaculus API for a given question ID and
50 | extracts the question's description from the response.
51 |
52 |     Args:
53 |         headers (dict): Headers to include in the API request.
54 |         question_id (int): The ID of the question to fetch.
55 |
56 |     Returns:
57 |         str: Description of the question, or an empty string if not found.
58 | """
59 | endpoint = f"https://www.metaculus.com/api2/questions/{question_id}"
60 | response = requests.get(endpoint, headers=headers)
61 | return response.json().get("description", "")
62 |
63 |
64 | # Use your own executable_path (download from https://chromedriver.chromium.org/).
65 | def initialize_and_login(
66 | signin_page,
67 | email,
68 | password,
69 | executable_path="/Users/apple/Downloads/chromedriver-mac-x64/chromedriver",
70 | ):
71 | """
72 | Initialize the WebDriver and log into the website.
73 |
74 | Returns:
75 | selenium.webdriver.Chrome: An instance of the Chrome WebDriver after logging in.
76 | """
77 | # Webdriver options
78 | chrome_options = Options()
79 | chrome_options.add_argument(
80 | "user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
81 | "(KHTML, like Gecko) Chrome/88.0.4324.150 Safari/537.36"
82 | )
83 | service = Service(executable_path=executable_path)
84 | driver = webdriver.Chrome(service=service, options=chrome_options)
85 |
86 | # log in
87 | driver.get(signin_page)
88 | driver.find_element(By.ID, "user_email").send_keys(email)
89 | driver.find_element(By.ID, "user_password").send_keys(password)
90 | driver.find_element(By.NAME, "commit").click()
91 |
92 | return driver
93 |
94 |
95 | def question_not_found(driver):
96 | """
97 | Check if a specific question is not found on the website.
98 |
99 | Args:
100 | driver (webdriver.Chrome): The Selenium Chrome WebDriver instance.
101 |
102 | Returns:
103 | bool: True if the question is not found, False otherwise.
104 | """
105 | try:
106 | return (
107 | driver.find_element(By.CLASS_NAME, "flash-message").text
108 | == "Could not find that question."
109 | )
110 | except BaseException:
111 | return False
112 |
113 |
114 | def get_source_links(driver, url):
115 | """
116 | Retrieve source links from a given question page.
117 |
118 | Args:
119 | driver (webdriver.Chrome): The Selenium Chrome WebDriver instance.
120 | url (str): The URL of the question page.
121 |
122 | Returns:
123 | list: A list of retrieved source links.
124 | """
125 | source_links = set()
126 | try:
127 | driver.get(url) # Make sure to navigate to the page first
128 | driver.find_element(By.XPATH, '//a[text() = "Source Links"]').click()
129 | WebDriverWait(driver, 2).until(
130 | EC.visibility_of_element_located((By.ID, "links-table"))
131 | )
132 | rows = driver.find_elements(
133 | By.XPATH, '//table[contains(@id, "links-table")]/tbody/tr'
134 | )
135 | for entry in rows:
136 | try:
137 | url_elem = entry.find_element(By.TAG_NAME, "a")
138 | source_links.add(url_elem.get_attribute("href"))
139 | except BaseException:
140 | # If no tag is found in this table data, it will skip to
141 | # the next
142 | continue
143 |     except BaseException:  # Catch any other exception
144 | logger.info("Failed to get links.")
145 | return list(source_links)
146 |
147 |
148 | def make_hashable(e):
149 | """
150 | Convert elements, including dictionaries, into a hashable form.
151 |
152 | Args:
153 | e (Any): The element to be converted.
154 |
155 | Returns:
156 | tuple: A tuple representing the hashable form of the element.
157 | """
158 | if isinstance(e, dict):
159 | return tuple((key, make_hashable(val)) for key, val in sorted(e.items()))
160 | elif isinstance(e, list):
161 | return tuple(make_hashable(x) for x in e)
162 | else:
163 | return e
164 |
165 |
166 | def unhashable_to_dict(t):
167 | """
168 | Convert tuples back into their original dictionary or list forms.
169 |
170 | Args:
171 | t (tuple): The tuple to be converted.
172 |
173 | Returns:
174 | Any: The original dictionary or list form of the tuple.
175 | """
176 | if isinstance(t, tuple):
177 | try:
178 | return dict((k, unhashable_to_dict(v)) for k, v in t)
179 | except ValueError:
180 | return [unhashable_to_dict(x) for x in t]
181 | else:
182 | return t
183 |
--------------------------------------------------------------------------------
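
A small sketch of the deduplication round-trip the scrapers rely on: make_hashable collapses duplicate dicts inside a set, and unhashable_to_dict recovers them. It assumes the llm_forecasting directory is on the import path; the records are toy data.

from data_scraping import make_hashable, unhashable_to_dict

records = [
    {"question": {"id": 1, "name": "Q1"}, "tags": ["alpha", "beta"]},
    {"question": {"id": 1, "name": "Q1"}, "tags": ["alpha", "beta"]},  # exact duplicate
    {"question": {"id": 2, "name": "Q2"}, "tags": ["gamma"]},
]
unique_tuples = {make_hashable(r) for r in records}
deduped = [unhashable_to_dict(t) for t in unique_tuples]
print(len(deduped))  # 2 unique records remain
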
/scripts/training_data/training_point_generation.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import argparse
3 | import asyncio
4 | import logging
5 | import random
6 |
7 | # Local application/library specific imports
8 | from utils import data_utils
9 | from prompts.prompts import PROMPT_DICT
10 | import evaluation
11 |
12 | # Set up logging
13 | logging.basicConfig(level=logging.INFO)
14 | logger = logging.getLogger(__name__)
15 |
16 | TRAINING_RETRIEVAL_CONFIG = {
17 | # Number of search query keywords per question.
18 | "NUM_SEARCH_QUERY_KEYWORDS": [4, 5, 6],
19 | "MAX_WORDS_NEWSCATCHER": [5],
20 | "MAX_WORDS_GNEWS": [7, 8, 9],
21 | "SEARCH_QUERY_MODEL_NAME": ["gpt-4-1106-preview"],
22 | "SEARCH_QUERY_TEMPERATURE": [0.0],
23 | "SEARCH_QUERY_PROMPT_TEMPLATES": [
24 | [PROMPT_DICT["search_query"]["0"]],
25 | [PROMPT_DICT["search_query"]["1"]],
26 | [PROMPT_DICT["search_query"]["0"], PROMPT_DICT["search_query"]["1"]],
27 | ],
28 | "NUM_ARTICLES_PER_QUERY": [5, 7, 9],
29 | "SUMMARIZATION_MODEL_NAME": ["gpt-3.5-turbo-1106"],
30 | "SUMMARIZATION_TEMPERATURE": [0.2],
31 | "SUMMARIZATION_PROMPT_TEMPLATE": [
32 | PROMPT_DICT["summarization"]["0"],
33 | PROMPT_DICT["summarization"]["1"],
34 | PROMPT_DICT["summarization"]["9"],
35 | ],
36 | "RANKING_MODEL_NAME": ["gpt-3.5-turbo-1106"],
37 | "RANKING_TEMPERATURE": [0.0],
38 | "RANKING_PROMPT_TEMPLATE": [PROMPT_DICT["ranking"]["0"]],
39 | "RANKING_RELEVANCE_THRESHOLD": [4],
40 | "SORT_BY": ["relevancy"],
41 | "RANKING_METHOD": ["llm-rating"],
42 | "RANKING_METHOD_LLM": ["title_250_tokens"],
43 | "NUM_SUMMARIES_THRESHOLD": [15, 20, 25],
44 | }
45 |
46 | TRAINING_REASONING_CONFIG = {
47 | "BASE_REASONING_MODEL_NAMES": [["claude-2.1", "gpt-4-1106-preview"]],
48 | "BASE_REASONING_TEMPERATURE": [0.2, 0.3, 0.4],
49 | "BASE_REASONING_PROMPT_TEMPLATES": [[], []],
50 | "AGGREGATION_METHOD": ["meta"],
51 | "AGGREGATION_PROMPT_TEMPLATE": [PROMPT_DICT["meta_reasoning"]["0"]],
52 | "AGGREGATION_TEMPERATURE": [0.2, 0.3, 0.4],
53 | "AGGREGATION_MODEL_NAME": ["claude-2.1", "gpt-4-1106-preview"],
54 | "AGGREGATION_WEIGTHTS": None,
55 | }
56 |
57 |
58 | def sample_retrieval_hyperparms(ir_config):
59 | """
60 | Sample hyperparameters for information retrieval configuration.
61 |
62 | Args:
63 | ir_config (dict): A dictionary containing different hyperparameters for information retrieval.
64 |
65 | Returns:
66 | dict: A dictionary with the same keys as ir_config, but each key has a single randomly
67 | sampled hyperparameter.
68 | """
69 | sampled_ir_config = ir_config.copy()
70 | for key, hyperparams in ir_config.items():
71 | sampled_ir_config[key] = random.choice(hyperparams)
72 | return sampled_ir_config
73 |
74 |
75 | def sample_reasoning_hyperparams(reasoning_config, prompts_to_sample, prompt_weights):
76 | sampled_reasoning_config = reasoning_config.copy()
77 | for key, hyperparams in reasoning_config.items():
78 | if key != "BASE_REASONING_PROMPT_TEMPLATES":
79 | sampled_reasoning_config[key] = random.choice(hyperparams)
80 | else:
81 | # For BASE_REASONING_PROMPT_TEMPLATES, sample 5 prompts for each model
82 | models = sampled_reasoning_config["BASE_REASONING_MODEL_NAMES"]
83 |             logger.info(f"Sampling prompts for base models: {models}")
84 | sampled_prompts = []
85 | for model in models:
86 | model_prompts = random.choices(
87 | prompts_to_sample, weights=prompt_weights[model], k=5
88 | )
89 | sampled_prompts.append(model_prompts)
90 | sampled_reasoning_config[key] = sampled_prompts
91 | return sampled_reasoning_config
92 |
93 |
94 | async def generate_training_points(
95 | s3_path,
96 | retrieval_index,
97 | num_retrievals,
98 | questions_after,
99 | ir_config,
100 | reasoning_config,
101 | output_dir,
102 | prompts_to_sample,
103 | prompt_weights,
104 | ):
105 | """
106 | Asynchronously generates training data points.
107 |
108 | Args:
109 | s3_path (str): The S3 path to retrieve initial data.
110 | retrieval_index (int): An index specifying the location or category for retrieval.
111 | num_retrievals (int): Number of data retrievals to perform.
112 | questions_after (str): A filtering criterion for questions.
113 |         ir_config (dict): Candidate hyperparameter values for information retrieval.
114 |         reasoning_config (dict): Candidate hyperparameter values for reasoning.
115 |         output_dir (str): The directory where output files are stored.
116 |
117 | Description:
118 | Retrieves training data, iterates through questions, evaluates them if necessary, processes them
119 |         based on the given configurations (base reasoning prompts are sampled from prompts_to_sample, weighted per model by prompt_weights), and saves the output.
120 |
121 | Returns:
122 | None: The function is used for its side effect of processing and saving data.
123 | """
124 | data_dict, raw_data = data_utils.get_data(
125 | s3_path,
126 | retrieval_index,
127 | num_retrievals,
128 | questions_after=questions_after,
129 | return_raw_question_data=True,
130 | )
131 | for q_index, question in enumerate(data_dict["question_list"]):
132 | if not evaluation.to_eval(question, retrieval_index, output_dir):
133 | logger.info(f"Already processed question, {q_index}: {question}")
134 | continue
135 |
136 | logger.info(f"Starting question, {q_index}: {question}")
137 | try:
138 | ir_config_samp = sample_retrieval_hyperparms(ir_config)
139 | reasoning_config_samp = sample_reasoning_hyperparams(
140 | reasoning_config, prompts_to_sample, prompt_weights
141 | )
142 | output, _, ranked_articles = await evaluation.retrieve_and_forecast(
143 | data_utils.format_single_question(data_dict, q_index),
144 | raw_data[q_index],
145 | ir_config=ir_config_samp,
146 | reason_config=reasoning_config_samp,
147 | return_articles=True,
148 | calculate_alignment=True,
149 | )
150 | output["ranked_articles"] = [
151 | (art.summary, art.relevance_rating) for art in ranked_articles
152 | ]
153 | evaluation.save_results(output, question, retrieval_index, output_dir)
154 | except Exception as e:
155 | logger.error(f"Error processing question {q_index}: {e}")
156 |
157 | return None
158 |
159 |
160 | async def main():
161 | """
162 | Start the training point generation job.
163 | """
164 | parser = argparse.ArgumentParser()
165 | parser.add_argument(
166 | "--s3_path",
167 | type=str,
168 | default="training_sets/forecasting_binary_training_set.pickle",
169 | help="S3 dataset path to run (use 'default' for metaculus training set).",
170 | )
171 | parser.add_argument(
172 | "--retrieval_index",
173 | type=int,
174 | default=1,
175 | help="Index for retrieval (1 to |num_retrievals|).",
176 | )
177 | parser.add_argument(
178 | "--num_retrievals",
179 | type=int,
180 | default=5,
181 | help="Number of ideal retrieval dates.",
182 | )
183 | parser.add_argument(
184 | "--questions_after",
185 | type=str,
186 | default="2015",
187 | help="The lower-bound year for questions to evaluate.",
188 | )
189 | args = parser.parse_args()
190 |
191 | await generate_training_points(
192 | args.s3_path,
193 | args.retrieval_index,
194 | args.num_retrievals,
195 | args.questions_after,
196 | ir_config=TRAINING_RETRIEVAL_CONFIG,
197 | reasoning_config=TRAINING_REASONING_CONFIG,
198 | output_dir="data_point_generation",
199 | )
200 |
201 |
202 | if __name__ == "__main__":
203 | asyncio.run(main())
204 |
--------------------------------------------------------------------------------
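
A minimal sketch of the per-key sampling that sample_retrieval_hyperparms performs over TRAINING_RETRIEVAL_CONFIG, shown here on a toy config (the real candidate values include prompt templates and model names):

import random

def sample_one_value_per_key(config):
    # Same idea as sample_retrieval_hyperparms: pick one candidate value per key.
    return {key: random.choice(candidates) for key, candidates in config.items()}

toy_ir_config = {
    "NUM_SEARCH_QUERY_KEYWORDS": [4, 5, 6],
    "NUM_ARTICLES_PER_QUERY": [5, 7, 9],
    "SUMMARIZATION_TEMPERATURE": [0.2],
}
print(sample_one_value_per_key(toy_ir_config))
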
/llm_forecasting/config/constants.py:
--------------------------------------------------------------------------------
1 | # Local application/library specific imports
2 | from config.keys import keys
3 | from prompts.prompts import PROMPT_DICT
4 | from utils import db_utils
5 |
6 | OAI_SOURCE = "OAI"
7 | ANTHROPIC_SOURCE = "ANTHROPIC"
8 | TOGETHER_AI_SOURCE = "TOGETHER"
9 | GOOGLE_SOURCE = "GOOGLE"
10 | HUGGINGFACE_SOURCE = "HUGGINGFACE"
11 |
12 | DEFAULT_RETRIEVAL_CONFIG = {
13 | "NUM_SEARCH_QUERY_KEYWORDS": 3,
14 | "MAX_WORDS_NEWSCATCHER": 5,
15 | "MAX_WORDS_GNEWS": 8,
16 | "SEARCH_QUERY_MODEL_NAME": "gpt-4-1106-preview",
17 | "SEARCH_QUERY_TEMPERATURE": 0.0,
18 | "SEARCH_QUERY_PROMPT_TEMPLATES": [
19 | PROMPT_DICT["search_query"]["0"],
20 | PROMPT_DICT["search_query"]["1"],
21 | ],
22 | "NUM_ARTICLES_PER_QUERY": 5,
23 | "SUMMARIZATION_MODEL_NAME": "gpt-3.5-turbo-1106",
24 | "SUMMARIZATION_TEMPERATURE": 0.2,
25 | "SUMMARIZATION_PROMPT_TEMPLATE": PROMPT_DICT["summarization"]["9"],
26 | "NUM_SUMMARIES_THRESHOLD": 10,
27 | "PRE_FILTER_WITH_EMBEDDING": True,
28 | "PRE_FILTER_WITH_EMBEDDING_THRESHOLD": 0.32,
29 | "RANKING_MODEL_NAME": "gpt-3.5-turbo-1106",
30 | "RANKING_TEMPERATURE": 0.0,
31 | "RANKING_PROMPT_TEMPLATE": PROMPT_DICT["ranking"]["0"],
32 | "RANKING_RELEVANCE_THRESHOLD": 4,
33 | "RANKING_COSINE_SIMILARITY_THRESHOLD": 0.5,
34 | "SORT_BY": "date",
35 | "RANKING_METHOD": "llm-rating",
36 | "RANKING_METHOD_LLM": "title_250_tokens",
37 | "NUM_SUMMARIES_THRESHOLD": 20,
38 | "EXTRACT_BACKGROUND_URLS": True,
39 | }
40 |
41 | DEFAULT_REASONING_CONFIG = {
42 | "BASE_REASONING_MODEL_NAMES": ["gpt-4-1106-preview"],
43 | "BASE_REASONING_TEMPERATURE": 1.0,
44 | "BASE_REASONING_PROMPT_TEMPLATES": [
45 | [
46 | PROMPT_DICT["binary"]["scratch_pad"]["1"],
47 | PROMPT_DICT["binary"]["scratch_pad"]["2"],
48 | ],
49 | ],
50 | "ALIGNMENT_MODEL_NAME": "gpt-3.5-turbo-1106",
51 | "ALIGNMENT_TEMPERATURE": 0,
52 | "ALIGNMENT_PROMPT": PROMPT_DICT["alignment"]["0"],
53 | "AGGREGATION_METHOD": "meta",
54 | "AGGREGATION_PROMPT_TEMPLATE": PROMPT_DICT["meta_reasoning"]["0"],
55 | "AGGREGATION_TEMPERATURE": 0.2,
56 | "AGGREGATION_MODEL_NAME": "gpt-4",
57 | "AGGREGATION_WEIGTHTS": None,
58 | }
59 |
60 | CHARS_PER_TOKEN = 4
61 |
62 | S3 = db_utils.initialize_s3_client(keys["AWS_ACCESS_KEY"], keys["AWS_SECRET_KEY"])
63 | S3_BUCKET_NAME = "my-forecasting-bucket"
64 |
65 | MODEL_TOKEN_LIMITS = {
66 | "claude-2.1": 200000,
67 | "claude-2": 100000,
68 | "claude-3-opus-20240229": 200000,
69 | "claude-3-sonnet-20240229": 200000,
70 | "gpt-4": 8000,
71 | "gpt-3.5-turbo-1106": 16000,
72 | "gpt-3.5-turbo-16k": 16000,
73 | "gpt-3.5-turbo": 8000,
74 | "gpt-4-1106-preview": 128000,
75 | "gemini-pro": 30720,
76 | "togethercomputer/llama-2-7b-chat": 4096,
77 | "togethercomputer/llama-2-13b-chat": 4096,
78 | "togethercomputer/llama-2-70b-chat": 4096,
79 | "togethercomputer/StripedHyena-Hessian-7B": 32768,
80 | "togethercomputer/LLaMA-2-7B-32K": 32768,
81 | "mistralai/Mistral-7B-Instruct-v0.2": 32768,
82 | "mistralai/Mixtral-8x7B-Instruct-v0.1": 32768,
83 | "zero-one-ai/Yi-34B-Chat": 4096,
84 | "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768,
85 | "NousResearch/Nous-Hermes-2-Yi-34B": 32768,
86 | }
87 |
88 | MODEL_NAME_TO_SOURCE = {
89 | "claude-2.1": ANTHROPIC_SOURCE,
90 | "claude-2": ANTHROPIC_SOURCE,
91 | "claude-3-opus-20240229": ANTHROPIC_SOURCE,
92 | "claude-3-sonnet-20240229": ANTHROPIC_SOURCE,
93 | "gpt-4": OAI_SOURCE,
94 | "gpt-3.5-turbo-1106": OAI_SOURCE,
95 | "gpt-3.5-turbo-16k": OAI_SOURCE,
96 | "gpt-3.5-turbo": OAI_SOURCE,
97 | "gpt-4-1106-preview": OAI_SOURCE,
98 | "gemini-pro": GOOGLE_SOURCE,
99 | "togethercomputer/llama-2-7b-chat": TOGETHER_AI_SOURCE,
100 | "togethercomputer/llama-2-13b-chat": TOGETHER_AI_SOURCE,
101 | "togethercomputer/llama-2-70b-chat": TOGETHER_AI_SOURCE,
102 | "togethercomputer/LLaMA-2-7B-32K": TOGETHER_AI_SOURCE,
103 | "togethercomputer/StripedHyena-Hessian-7B": TOGETHER_AI_SOURCE,
104 | "mistralai/Mistral-7B-Instruct-v0.2": TOGETHER_AI_SOURCE,
105 | "mistralai/Mixtral-8x7B-Instruct-v0.1": TOGETHER_AI_SOURCE,
106 | "zero-one-ai/Yi-34B-Chat": TOGETHER_AI_SOURCE,
107 | "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": TOGETHER_AI_SOURCE,
108 | "NousResearch/Nous-Hermes-2-Yi-34B": TOGETHER_AI_SOURCE,
109 | }
110 |
111 | ANTHROPIC_RATE_LIMIT = 5
112 |
113 | IRRETRIEVABLE_SITES = [
114 | "wsj.com",
115 | "english.alarabiya.net",
116 | "consilium.europa.eu",
117 | "abc.net.au",
118 | "thehill.com",
119 | "democracynow.org",
120 | "fifa.com",
121 | "si.com",
122 | "aa.com.tr",
123 | "thestreet.com",
124 | "newsweek.com",
125 | "spokesman.com",
126 | "aninews.in",
127 | "commonslibrary.parliament.uk",
128 | "cybernews.com",
129 | "lineups.com",
130 | "expressnews.com",
131 | "news-herald.com",
132 | "c-span.org/video",
133 | "investors.com",
134 | "finance.yahoo.com", # This site has a “read more” button.
135 | "metaculus.com", # newspaper4k cannot parse metaculus pages well
136 | "houstonchronicle.com",
137 | "unrwa.org",
138 | "njspotlightnews.org",
139 | "crisisgroup.org",
140 | "vanguardngr.com", # protected by Cloudflare
141 | "ahram.org.eg", # protected by Cloudflare
142 | "reuters.com", # blocked by Javascript and CAPTCHA
143 | "carnegieendowment.org",
144 | "casino.org",
145 | "legalsportsreport.com",
146 | "thehockeynews.com",
147 | "yna.co.kr",
148 | "carrefour.com",
149 | "carnegieeurope.eu",
150 | "arabianbusiness.com",
151 | "inc.com",
152 | "joburg.org.za",
153 | "timesofindia.indiatimes.com",
154 | "seekingalpha.com",
155 | "producer.com", # protected by Cloudflare
156 | "oecd.org",
157 | "almayadeen.net", # protected by Cloudflare
158 | "manifold.markets", # prevent data contamination
159 | "goodjudgment.com", # prevent data contamination
160 | "infer-pub.com", # prevent data contamination
161 | "www.gjopen.com", # prevent data contamination
162 | "polymarket.com", # prevent data contamination
163 | "betting.betfair.com", # protected by Cloudflare
164 | "news.com.au", # blocks crawler
165 | "predictit.org", # prevent data contamination
166 | "atozsports.com",
167 | "barrons.com",
168 | "forex.com",
169 | "www.cnbc.com/quotes", # stock market data: prevent data contamination
170 | "montrealgazette.com",
171 | "bangkokpost.com",
172 | "editorandpublisher.com",
173 | "realcleardefense.com",
174 | "axios.com",
175 | "mensjournal.com",
176 | "warriormaven.com",
177 | "tapinto.net",
178 | "indianexpress.com",
179 | "science.org",
180 | "businessdesk.co.nz",
181 | "mmanews.com",
182 | "jdpower.com",
183 | "hrexchangenetwork.com",
184 | "arabnews.com",
185 | "nationalpost.com",
186 | "bizjournals.com",
187 | "thejakartapost.com",
188 | ]
189 |
190 | QUESTION_CATEGORIES = [
191 | "Science & Tech",
192 | "Healthcare & Biology",
193 | "Economics & Business",
194 | "Environment & Energy",
195 | "Politics & Governance",
196 | "Education & Research",
197 | "Arts & Recreation",
198 | "Security & Defense",
199 | "Social Sciences",
200 | "Sports",
201 | "Other",
202 | ]
203 |
204 | (
205 | METACULUS_PLATFORM,
206 | CSET_PLATFORM,
207 | GJOPEN_PLATFORM,
208 | MANIFOLD_PLATFORM,
209 | POLYMARKET_PLATFORM,
210 | ) = ("metaculus", "cset", "gjopen", "manifold", "polymarket")
211 |
212 | ALL_PLATFORMS = [
213 | METACULUS_PLATFORM,
214 | CSET_PLATFORM,
215 | GJOPEN_PLATFORM,
216 | MANIFOLD_PLATFORM,
217 | POLYMARKET_PLATFORM,
218 | ]
219 |
220 | END_WORDS_TO_PROBS_6 = {
221 | "No": 0.05,
222 | "Very Unlikely": 0.15,
223 | "Unlikely": 0.35,
224 | "Likely": 0.55,
225 | "Very Likely": 0.75,
226 | "Yes": 0.95,
227 | }
228 |
229 | END_WORDS_TO_PROBS_10 = {
230 | "No": 0.05,
231 | "Extremely Unlikely": 0.15,
232 | "Very Unlikely": 0.25,
233 | "Unlikely": 0.35,
234 | "Slightly Unlikely": 0.45,
235 | "Slightly Likely": 0.55,
236 | "Likely": 0.65,
237 | "Very Likely": 0.75,
238 | "Extremely Likely": 0.85,
239 | "Yes": 0.95,
240 | }
241 |
242 | TOKENS_TO_PROBS_DICT = {
243 | "six_options": END_WORDS_TO_PROBS_6,
244 | "ten_options": END_WORDS_TO_PROBS_10,
245 | }
--------------------------------------------------------------------------------
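
A short, self-contained sketch of how an END_WORDS_TO_PROBS map can turn a model's final answer phrase into a probability. The lookup function below is illustrative only, not the repository's actual parsing code, and the mapping is copied inline so the snippet runs without importing config.constants (which also initializes an S3 client).

# Values copied from END_WORDS_TO_PROBS_6 above.
END_WORDS_TO_PROBS_6 = {
    "No": 0.05,
    "Very Unlikely": 0.15,
    "Unlikely": 0.35,
    "Likely": 0.55,
    "Very Likely": 0.75,
    "Yes": 0.95,
}

def end_phrase_to_prob(reasoning_text, mapping=END_WORDS_TO_PROBS_6):
    # Illustrative lookup: check longer phrases first so "Very Likely" wins over "Likely".
    for phrase in sorted(mapping, key=len, reverse=True):
        if reasoning_text.strip().endswith(phrase):
            return mapping[phrase]
    return None

print(end_phrase_to_prob("All things considered, the outcome is Very Likely"))  # 0.75
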
/llm_forecasting/utils/time_utils.py:
--------------------------------------------------------------------------------
1 | # Related third-party imports
2 | from datetime import datetime, timedelta
3 | import pandas as pd
4 | import math
5 |
6 |
7 | def extract_date(dt):
8 |     """
9 |     Extract a date string from a datetime object or raw string.
10 |
11 |     Args:
12 |         dt (datetime or str): A datetime object or string.
13 |             If a string in the format 'YYYY-MM-DDTHH:MM:SSZ', the date part
14 |             is extracted.
15 |
16 |     Returns:
17 |         str: A date string in the format 'YYYY-MM-DD'.
18 |     """
19 |     if isinstance(dt, str):
20 |         if "T" in dt:
21 |             return dt.split("T")[0]
22 |         else:
23 |             return dt
24 |     else:
25 |         return str(dt.date())
26 |
27 |
28 | def convert_date_string_to_tuple(date_string):
29 | """
30 | Convert a date string of the form 'year-month-day' to a tuple (year, month, day).
31 |
32 | Args:
33 | date_string (str): A string representing the date in 'year-month-day' format.
34 |
35 | Returns:
36 | tuple: A tuple containing the year, month, and day as integers.
37 | """
38 | # Split the date string by '-'
39 | parts = date_string.split("-")
40 | # Check that the date string is in the correct format
41 | assert len(parts) == 3, "Date string must be in 'year-month-day' format."
42 | # Convert the parts to integers and return as a tuple
43 | return tuple(map(int, parts))
44 |
45 |
46 | def safe_to_datetime(date_str):
47 | """
48 | Safely convert a date string to a datetime object.
49 |
50 | Args:
51 | date_str (str): Date string to be converted.
52 |
53 | Returns:
54 | datetime: Converted datetime object or None if conversion fails.
55 | """
56 | try:
57 | return pd.to_datetime(date_str.replace("Z", "+00:00"))
58 | except pd.errors.OutOfBoundsDatetime:
59 | return None
60 |
61 |
62 | def move_date_by_percentage(date_str1, date_str2, percentage):
63 | """
64 | Compute a date that is a specified percentage between two dates.
65 |
66 | Returns the date before |date_str2| if the computed date is equal
67 | to |date_str2|.
68 |
69 | Args:
70 | date_str1 (str): Start date in "YYYY-MM-DD" format.
71 | date_str2 (str): End date in "YYYY-MM-DD" format.
72 | percentage (float): Percentage to move from start date towards end date.
73 |
74 | Returns:
75 | str: The new date in "YYYY-MM-DD" format.
76 | """
77 | # Parse dates
78 | date1 = datetime.strptime(date_str1, "%Y-%m-%d")
79 | date2 = datetime.strptime(date_str2, "%Y-%m-%d")
80 |
81 | # Ensure date1 is earlier than date2
82 | if date1 > date2:
83 | date1, date2 = date2, date1
84 |
85 | # Calculate new date at the given percentage between them
86 | target_date = date1 + (date2 - date1) * (percentage / 100.0)
87 |
88 | # Check if target date is the same as date_str2, if so, subtract one day
89 | if target_date.strftime("%Y-%m-%d") == date_str2:
90 | target_date -= timedelta(days=1)
91 |
92 | return target_date.strftime("%Y-%m-%d")
93 |
94 |
95 | def adjust_date_by_days(date_str, days_to_adjust):
96 | """
97 | Adjust a date string by a specified number of days.
98 |
99 | Args:
100 | date_str (str): A date string in the format 'YYYY-MM-DD'.
101 | days_to_adjust (int): The number of days to adjust the date by. Can be
102 | positive or negative.
103 |
104 | Returns:
105 | str: A new date string in the format 'YYYY-MM-DD' adjusted by the
106 | specified number of days.
107 | """
108 | # Parse the date string into a datetime object
109 | date_obj = datetime.strptime(date_str, "%Y-%m-%d")
110 |
111 | # Adjust the date by the given number of days
112 |     adjusted_date = date_obj + timedelta(days=days_to_adjust)
113 |
114 | # Convert the adjusted datetime object back into a string
115 | new_date_str = adjusted_date.strftime("%Y-%m-%d")
116 |
117 | return new_date_str
118 |
119 |
120 | def convert_timestamp(timestamp):
121 | """
122 | Convert a numeric timestamp into a formatted date string.
123 |
124 | This function checks if the given timestamp is in milliseconds or seconds,
125 | and converts it to a date string in the 'YYYY-MM-DD' format. It assumes that
126 | timestamps are in milliseconds if they are greater than 1e10, which
127 | typically corresponds to dates after the year 2001.
128 |
129 | Args:
130 | timestamp (float or int): The timestamp to convert. Can be in seconds
131 | or milliseconds.
132 |
133 | Returns:
134 | str: The converted timestamp as a date string in 'YYYY-MM-DD' format.
135 | """
136 | # Identify if the timestamp is in milliseconds or seconds
137 | timestamp = float(timestamp)
138 | is_millisecond = timestamp > 1e10 # assuming data is from after 2001
139 | if is_millisecond:
140 | timestamp = timestamp / 1000 # Convert from milliseconds to seconds
141 |
142 | # Convert to formatted date string
143 | return datetime.utcfromtimestamp(timestamp).strftime("%Y-%m-%d")
144 |
145 |
146 | def is_more_recent(first_date_str, second_date_str, or_equal_to=False):
147 | """
148 | Determine if |second_date_str| is more recent than |first_date_str|.
149 |
150 | Args:
151 | first_date_str (str): A string representing the first date to compare against. Expected format: 'YYYY-MM-DD'.
152 | second_date_str (str): A string representing the second date. Expected format: 'YYYY-MM-DD'.
153 |         or_equal_to (bool): If True, also return True when the two dates are equal.
154 | Returns:
155 | bool: True if the second date is more recent than the first date, False otherwise.
156 | """
157 | first_date_obj = datetime.strptime(first_date_str, "%Y-%m-%d")
158 | second_date_obj = datetime.strptime(second_date_str, "%Y-%m-%d")
159 | if or_equal_to:
160 | return second_date_obj >= first_date_obj
161 | return second_date_obj > first_date_obj
162 |
163 |
164 | def is_less_than_N_days_apart(date1_str, date2_str, N=3):
165 | """
166 | Check if the difference between two dates is less than N days.
167 |
168 | :param date1_str: First date as a string in 'YYYY-MM-DD' format.
169 | :param date2_str: Second date as a string in 'YYYY-MM-DD' format.
170 | :param N: Number of days for comparison.
171 | :return: True if the difference is less than N days, otherwise False.
172 | """
173 | date_obj1 = datetime.strptime(date1_str, "%Y-%m-%d")
174 | date_obj2 = datetime.strptime(date2_str, "%Y-%m-%d")
175 | return (date_obj2 - date_obj1) < timedelta(days=N)
176 |
177 |
178 | def find_pred_with_closest_date(date_str, date_pred_list):
179 | """
180 | Finds the tuple with the date closest to the given reference date from a list of (date, prediction) tuples.
181 |
182 | Parameters:
183 | - date_str (str): Reference date in 'YYYY-MM-DD' format.
184 | - date_pred_list (list of tuples): Each tuple contains a date string in 'YYYY-MM-DD' format and an associated prediction.
185 |
186 | Returns:
187 | - tuple: The tuple with the closest date to date_str. Returns None if date_pred_list is empty.
188 |
189 | Raises:
190 | - ValueError: If date_str or dates in date_pred_list are not in the correct format.
191 | """
192 | # Convert the reference date string to a datetime object
193 | ref_date = datetime.strptime(date_str, "%Y-%m-%d")
194 |
195 | # Initialize variables to store the closest date and its difference
196 | closest_date = None
197 | min_diff = float("inf")
198 |
199 | # Iterate through the list of tuples
200 | for date_tuple in date_pred_list:
201 | # Convert the date string in the tuple to a datetime object
202 | current_date = datetime.strptime(date_tuple[0], "%Y-%m-%d")
203 |
204 | # Calculate the absolute difference in days
205 | diff = abs((current_date - ref_date).days)
206 |
207 | # Update the closest date and min_diff if this date is closer
208 | if diff < min_diff:
209 | min_diff = diff
210 | closest_date = date_tuple
211 |
212 | return closest_date
213 |
214 |
215 | def get_retrieval_date(
216 | retrieval_index, num_retrievals, date_begin, date_close, resolve_date
217 | ):
218 | """
219 | Calculate a specific retrieval date within a given time range.
220 |
221 | The retrieval date is determined using an exponential distribution based on the total number of retrievals.
222 | If the calculated date is after the resolve date, None is returned.
223 |
224 | Args:
225 |         retrieval_index (int): Index of the current retrieval (starting from 1).
226 | num_retrievals (int): Total number of retrievals planned within the date range.
227 | date_begin (str): Start date of the range in 'YYYY-MM-DD' format.
228 | date_close (str): End date of the range in 'YYYY-MM-DD' format.
229 | resolve_date (str): Date by which the retrieval should be resolved in 'YYYY-MM-DD' format.
230 |
231 | Returns:
232 | str or None: The calculated retrieval date in 'YYYY-MM-DD' format, or None if it falls after the resolve date.
233 | """
234 | date_begin_obj = datetime.strptime(date_begin, "%Y-%m-%d")
235 | date_close_obj = datetime.strptime(date_close, "%Y-%m-%d")
236 | resolve_date_obj = datetime.strptime(resolve_date, "%Y-%m-%d")
237 |
238 | # Early return if date range is invalid or reversed
239 | if date_begin_obj >= date_close_obj or date_begin_obj > resolve_date_obj:
240 | return None
241 |
242 | # Calculate the total number of days in the range
243 | total_days = (date_close_obj - date_begin_obj).days
244 |
245 | # Calculate the retrieval date
246 | retrieval_days = math.exp((math.log(total_days) / num_retrievals) * retrieval_index)
247 | retrieval_date_obj = date_begin_obj + timedelta(days=retrieval_days)
248 |
249 | if retrieval_date_obj >= date_close_obj:
250 | retrieval_date_obj = date_close_obj - timedelta(days=1)
251 | if retrieval_date_obj >= resolve_date_obj:
252 | return None
253 |
254 | # Check against the previous retrieval date
255 | if retrieval_index > 1:
256 | previous_days = math.exp(
257 | (math.log(total_days) / num_retrievals) * (retrieval_index - 1)
258 | )
259 | previous_date_obj = date_begin_obj + timedelta(days=previous_days)
260 | if retrieval_date_obj <= previous_date_obj:
261 | return None
262 |
263 | return retrieval_date_obj.strftime("%Y-%m-%d")
264 |
--------------------------------------------------------------------------------
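
A few self-contained calls illustrating the date helpers above, assuming the llm_forecasting directory is on the import path (all dates are toy values):

from utils.time_utils import (
    find_pred_with_closest_date,
    is_more_recent,
    move_date_by_percentage,
)

print(move_date_by_percentage("2023-01-01", "2023-01-31", 50))  # "2023-01-16"
print(is_more_recent("2023-01-01", "2023-02-01"))               # True
print(find_pred_with_closest_date(
    "2023-01-10",
    [("2023-01-01", 0.4), ("2023-01-12", 0.6), ("2023-02-01", 0.7)],
))                                                              # ("2023-01-12", 0.6)
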
/llm_forecasting/evaluation.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import logging
3 |
4 | # Local application/library-specific imports
5 | import alignment
6 | from config.constants import (
7 | DEFAULT_RETRIEVAL_CONFIG,
8 | DEFAULT_REASONING_CONFIG,
9 | S3,
10 | S3_BUCKET_NAME,
11 | )
12 | import ensemble
13 | import ranking
14 | import summarize
15 | from utils import db_utils, visualize_utils
16 |
17 | # Set up logging
18 | logging.basicConfig(level=logging.INFO)
19 | logger = logging.getLogger(__name__)
20 |
21 |
22 | def to_eval(question, retrieval_number, output_dir):
23 | """
24 | Determines if a given question needs evaluation based on its presence in an S3 storage.
25 |
26 | Args:
27 | question (str): The question to be evaluated.
28 | retrieval_number (int): A unique identifier for the retrieval process.
29 |         output_dir (str): The directory path where the result file is
30 |             expected to be found.
31 |
32 | Returns:
33 | bool: True if the question needs evaluation, False otherwise.
34 | """
35 | question_formatted = question.replace(" ", "_").replace("/", "")
36 | file_name = f"{output_dir}/{retrieval_number}/{question_formatted}.pickle"
37 | try:
38 | # Try reading the file from S3
39 | _ = db_utils.read_pickle_from_s3(S3, S3_BUCKET_NAME, file_name)
40 | return False
41 | except BaseException:
42 | # If the file is not found in S3, return True to indicate it needs evaluation
43 | return True
44 |
45 |
46 | def save_results(save_dict, question, retrieval_number, output_dir, file_type="pickle"):
47 | """
48 | Save a dictionary of results to an S3 storage, formatted based on a specific question.
49 |
50 | Args:
51 | save_dict (dict): The dictionary of results to be saved.
52 | question (str): The title of the question related to the results.
53 | retrieval_number (int): A unique identifier for the retrieval process.
54 | output_dir (str): The directory path where the file will be saved.
55 | file_type (str, optional): The type of the file to save. Default is "pickle".
56 | """
57 | question_formatted = question.replace(" ", "_").replace("/", "")
58 | file_name = f"{output_dir}/{retrieval_number}/{question_formatted}.{file_type}"
59 | db_utils.upload_data_structure_to_s3(S3, save_dict, S3_BUCKET_NAME, file_name)
60 |
61 |
62 | async def retrieve_and_forecast(
63 | question_dict,
64 | question_raw,
65 | ir_config=DEFAULT_RETRIEVAL_CONFIG,
66 | reason_config=DEFAULT_REASONING_CONFIG,
67 | calculate_alignment=False,
68 | return_articles=False,
69 | ):
70 | """
71 | Asynchronously evaluates the forecasting question using the end-to-end system.
72 |
73 | This function integrates steps of information retrieval, summarization, reasoning,
74 | and alignment scoring.
75 |
76 | Args:
77 | question_dict (dict): A dictionary containing detailed data about the question.
78 | question_raw (dict): The raw question data.
79 | ir_config (dict, optional): The configuration for information retrieval.
80 | Defaults to DEFAULT_RETRIEVAL_CONFIG.
81 | reason_config (dict, optional): The configuration for reasoning.
82 | Defaults to DEFAULT_REASONING_CONFIG.
83 | calculate_alignment (bool, optional): Flag to determine if alignment scores
84 | should be calculated. Defaults to False.
85 | return_articles (bool, optional): Flag to decide if the retrieved articles
86 | should be returned. Defaults to False.
87 |
88 | Returns:
89 | dict: The reasoning for the questions, along with metadata and intermediate
90 | outputs from the system. If return_articles is True, also includes retrieved
91 | articles. Returns None if the question does not meet evaluation criteria.
92 | """
93 | assert isinstance(
94 | reason_config["BASE_REASONING_PROMPT_TEMPLATES"], list
95 | ), "BASE_REASONING_PROMPT_TEMPLATES must be a list."
96 | assert len(reason_config["BASE_REASONING_PROMPT_TEMPLATES"]) == len(
97 | reason_config["BASE_REASONING_MODEL_NAMES"]
98 | ), "BASE_REASONING_PROMPT_TEMPLATES and BASE_REASONING_MODEL_NAMES must have the same length."
99 |
100 | question = question_dict["question"]
101 | background_info = question_dict["background"]
102 | resolution_criteria = question_dict["resolution_criteria"]
103 | answer = question_dict["answer"]
104 | question_dates = question_dict["question_dates"]
105 | retrieval_dates = question_dict["retrieval_dates"]
106 | urls_in_background = question_dict["urls_in_background"]
107 | # Information retrieval
108 | try:
109 | (
110 | ranked_articles,
111 | all_articles,
112 | search_queries_list_gnews,
113 | search_queries_list_nc,
114 | ) = await ranking.retrieve_summarize_and_rank_articles(
115 | question,
116 | background_info,
117 | resolution_criteria,
118 | retrieval_dates,
119 | urls=urls_in_background,
120 | config=ir_config,
121 | return_intermediates=True,
122 | )
123 | except Exception as e: # skip the question if failed
124 | logger.error(f"Error message: {e}")
125 | logger.info(f"IR failed at question: {question}.")
126 | return None
127 | subset_ranked_articles = ranked_articles[
128 | : ir_config.get("NUM_SUMMARIES_THRESHOLD", 100)
129 | ].copy()
130 | all_summaries = summarize.concat_summaries(subset_ranked_articles)
131 | logger.info(f"Information retrieval complete for question: {question}.")
132 | logger.info(f"Number of summaries: {len(subset_ranked_articles)}.")
133 |
134 | # Reasoning (using ensemble)
135 | today_to_close_date = [retrieval_dates[1], question_dates[1]]
136 | ensemble_dict = await ensemble.meta_reason(
137 | question=question,
138 | background_info=background_info,
139 | resolution_criteria=resolution_criteria,
140 | today_to_close_date_range=today_to_close_date,
141 | retrieved_info=all_summaries,
142 | reasoning_prompt_templates=reason_config["BASE_REASONING_PROMPT_TEMPLATES"],
143 | base_model_names=reason_config["BASE_REASONING_MODEL_NAMES"],
144 | base_temperature=reason_config["BASE_REASONING_TEMPERATURE"],
145 | aggregation_method=reason_config["AGGREGATION_METHOD"],
146 | answer_type="probability",
147 | weights=reason_config["AGGREGATION_WEIGTHTS"],
148 | meta_model_name=reason_config["AGGREGATION_MODEL_NAME"],
149 | meta_prompt_template=reason_config["AGGREGATION_PROMPT_TEMPLATE"],
150 | meta_temperature=reason_config["AGGREGATION_TEMPERATURE"],
151 | )
152 |
153 | alignment_scores = None
154 | if calculate_alignment:
155 | alignment_scores = alignment.get_alignment_scores(
156 | ensemble_dict["base_reasonings"],
157 | alignment_prompt=reason_config["ALIGNMENT_PROMPT"],
158 | model_name=reason_config["ALIGNMENT_MODEL_NAME"],
159 | temperature=reason_config["ALIGNMENT_TEMPERATURE"],
160 | question=question,
161 | background=background_info,
162 | resolution_criteria=resolution_criteria,
163 | )
164 |
165 | # Compute brier score (base_predictions is a list of lists of
166 | # probabilities)
167 | base_brier_scores = []
168 | # For each sublist (corresponding to a base model name)
169 | for base_predictions in ensemble_dict["base_predictions"]:
170 | base_brier_scores.append(
171 | [(base_prediction - answer) ** 2 for base_prediction in base_predictions]
172 | )
173 | # Visualization (draw the HTML)
174 | base_html = visualize_utils.visualize_all(
175 | question_data=question_raw,
176 | retrieval_dates=retrieval_dates,
177 | search_queries_gnews=search_queries_list_gnews,
178 | search_queries_nc=search_queries_list_nc,
179 | all_articles=all_articles,
180 | ranked_articles=ranked_articles,
181 | all_summaries=all_summaries,
182 | model_names=reason_config["BASE_REASONING_MODEL_NAMES"],
183 | base_reasoning_prompt_templates=reason_config[
184 | "BASE_REASONING_PROMPT_TEMPLATES"
185 | ],
186 | base_reasoning_full_prompts=ensemble_dict["base_reasoning_full_prompts"],
187 | base_reasonings=ensemble_dict["base_reasonings"],
188 | base_predictions=ensemble_dict["base_predictions"],
189 | base_brier_scores=base_brier_scores,
190 | )
191 | meta_html = visualize_utils.visualize_all_ensemble(
192 | question_data=question_raw,
193 | ranked_articles=ranked_articles,
194 | all_articles=all_articles,
195 | search_queries_gnews=search_queries_list_gnews,
196 | search_queries_nc=search_queries_list_nc,
197 | retrieval_dates=retrieval_dates,
198 | meta_reasoning=ensemble_dict["meta_reasoning"],
199 | meta_full_prompt=ensemble_dict["meta_prompt"],
200 | meta_prediction=ensemble_dict["meta_prediction"],
201 | )
202 | # Generate outputs, one dict per question
203 | output = {
204 | "question": question,
205 | "answer": int(answer),
206 | "data_source": question_raw["data_source"],
207 | "background_info": background_info,
208 | "resolution_criteria": resolution_criteria,
209 | "retrieval_dates": retrieval_dates,
210 | "search_queries_gnews": search_queries_list_gnews,
211 | "search_queries_nc": search_queries_list_nc,
212 | "retrieved_info": all_summaries,
213 | "base_model_names": reason_config["BASE_REASONING_MODEL_NAMES"],
214 | "base_reasoning_full_prompts": ensemble_dict["base_reasoning_full_prompts"],
215 | "base_reasonings": ensemble_dict["base_reasonings"],
216 | "base_predictions": ensemble_dict["base_predictions"],
217 | "alignment_scores": alignment_scores,
218 | "meta_reasoning_full_prompt": ensemble_dict["meta_prompt"],
219 | "meta_reasoning": ensemble_dict["meta_reasoning"],
220 | "meta_prediction": ensemble_dict["meta_prediction"],
221 | "community_prediction": question_dict["community_pred_at_retrieval"],
222 | "base_html": base_html,
223 | "meta_html": meta_html,
224 | "base_brier_score": base_brier_scores,
225 | "meta_brier_score": (ensemble_dict["meta_prediction"] - answer) ** 2,
226 | "community_brier_score": (question_dict["community_pred_at_retrieval"] - answer)
227 | ** 2,
228 | }
229 | if return_articles:
230 | return output, all_articles, ranked_articles
231 | return output
232 |
--------------------------------------------------------------------------------
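A note on the Brier-score bookkeeping above: each score is simply the squared difference between a forecast probability and the binary resolution (0 or 1), computed per base-model prediction, for the meta prediction, and for the community prediction. A minimal, self-contained sketch with invented numbers (the real values come from ensemble_dict and answer):

    # Hedged sketch of the Brier-score computation above; the numbers are invented.
    answer = 1  # the question resolved "yes"
    base_predictions = [[0.7, 0.65], [0.4]]  # one sublist of probabilities per base model
    base_brier_scores = [
        [(p - answer) ** 2 for p in sublist] for sublist in base_predictions
    ]
    print(base_brier_scores)  # ~[[0.09, 0.1225], [0.36]]; lower is better
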
/llm_forecasting/prompts/search_query.py:
--------------------------------------------------------------------------------
1 | SEARCH_QUERY_PROMPT_0 = (
2 | """I will provide you with a forecasting question and the background information for the question. I will then ask you to generate short search queries (up to {max_words} words each) that I'll use to find articles on Google News to help answer the question.
3 |
4 | Question:
5 | {question}
6 |
7 | Question Background:
8 | {background}
9 |
10 | Today's date: {date_begin}
11 | Question close date: {date_end}
12 |
13 | You must generate this exact amount of queries: {num_keywords}
14 |
15 | Start off by writing down sub-questions. Then use your sub-questions to help steer the search queries you produce.
16 |
17 | Your response should take the following structure:
18 | Thoughts:
19 | {{ Insert your thinking here. }}
20 | Search Queries:
21 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
22 | (
23 | "QUESTION",
24 | "BACKGROUND",
25 | "DATES",
26 | "NUM_KEYWORDS",
27 | "MAX_WORDS",
28 | ),
29 | )
30 |
31 | SEARCH_QUERY_PROMPT_1 = (
32 | """I will provide you with a forecasting question and the background information for the question.
33 |
34 | Question:
35 | {question}
36 |
37 | Question Background:
38 | {background}
39 |
40 | Today's date: {date_begin}
41 | Question close date: {date_end}
42 |
43 | Task:
44 | - Generate brief search queries (up to {max_words} words each) to gather information on Google that could influence the forecast.
45 |
46 | You must generate this exact amount of queries: {num_keywords}
47 |
48 | Your response should take the following structure:
49 | Thoughts:
50 | {{ Insert your thinking here. }}
51 | Search Queries:
52 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
53 | (
54 | "QUESTION",
55 | "BACKGROUND",
56 | "DATES",
57 | "NUM_KEYWORDS",
58 | "MAX_WORDS",
59 | ),
60 | )
61 |
62 | SEARCH_QUERY_PROMPT_2 = (
63 | """In this task, I will present a forecasting question along with relevant background information. Your goal is to create {num_keywords} concise search queries (up to {max_words} words each) to gather information that could influence the forecast. Consider different angles and aspects that might impact the outcome.
64 |
65 | Question:
66 | {question}
67 |
68 | Question Background:
69 | {background}
70 |
71 | Today's date: {date_begin}
72 | Question close date: {date_end}
73 |
74 | Now, generate {num_keywords} short search queries to search for information on Google News.
75 | You must generate this exact amount of queries: {num_keywords}.
76 | When formulating your search queries, think about various factors that could affect the forecast, such as recent trends, historical data, or external influences.
77 |
78 | Your response should take the following structure:
79 | Thoughts:
80 | {{ Insert your thinking here. }}
81 | Search Queries:
82 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
83 | (
84 | "QUESTION",
85 | "BACKGROUND",
86 | "DATES",
87 | "NUM_KEYWORDS",
88 | "MAX_WORDS",
89 | ),
90 | )
91 |
92 | SEARCH_QUERY_PROMPT_3 = (
93 | """I will provide you with a forecasting question and the background information for the question. I will then ask you to generate {num_keywords} short search queries (up to {max_words} words each) that I'll use to find articles on Google News to help answer the question.
94 |
95 | Question:
96 | {question}
97 |
98 | Question Background:
99 | {background}
100 |
101 | Today's date: {date_begin}
102 | Question close date: {date_end}
103 |
104 | You must generate this exact amount of queries: {num_keywords}
105 |
106 | Your response should take the following structure:
107 | Thoughts:
108 | {{ Insert your thinking here. }}
109 | Search Queries:
110 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
111 | (
112 | "QUESTION",
113 | "BACKGROUND",
114 | "DATES",
115 | "NUM_KEYWORDS",
116 | "MAX_WORDS",
117 | ),
118 | )
119 |
120 | SEARCH_QUERY_PROMPT_4 = (
121 | """Generate short search queries (up to {max_words} words) for the forecasting question below.
122 |
123 | I will use them to query Google News for articles. These search queries should result in articles that help me make an informed prediction.
124 |
125 | Question:
126 | {question}
127 |
128 | Question Background:
129 | {background}
130 |
131 | Today's date: {date_begin}
132 | Question close date: {date_end}
133 |
134 | You must generate this exact amount of queries: {num_keywords}.
135 |
136 | Your response should take the following structure:
137 | Thoughts:
138 | {{ Insert your thinking here. }}
139 | Search Queries:
140 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
141 | (
142 | "QUESTION",
143 | "BACKGROUND",
144 | "DATES",
145 | "NUM_KEYWORDS",
146 | "MAX_WORDS",
147 | ),
148 | )
149 |
150 | SEARCH_QUERY_PROMPT_5 = (
151 | """
152 | Please provide {num_keywords} search queries to input into Google to help me research this forecasting question:
153 |
154 | Question: {question}
155 |
156 | Background: {background}
157 |
158 | Today's Date: {date_begin}
159 | Close Date: {date_end}
160 |
161 | Guidelines:
162 | - Include terms related to influential factors that could sway the outcome.
163 | - Use different keyword approaches to get balanced perspectives.
164 | - Each search query should be up to {max_words} words.
165 |
166 | You must generate this exact amount of queries: {num_keywords}.
167 |
168 | Your response should take the following structure:
169 | Thoughts:
170 | {{ Insert your thinking here. }}
171 | Search Queries:
172 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
173 | (
174 | "QUESTION",
175 | "BACKGROUND",
176 | "DATES",
177 | "NUM_KEYWORDS",
178 | "MAX_WORDS",
179 | ),
180 | )
181 |
182 | SEARCH_QUERY_PROMPT_6 = (
183 | """
184 | In this task, you will receive a forecasting question along with its background information. Your objective is to create {num_keywords} targeted search queries, each not exceeding {max_words} words, to unearth information that could shape the forecast.
185 |
186 | Question:
187 | {question}
188 |
189 | Background:
190 | {background}
191 |
192 | Current Date:
193 | {date_begin}
194 | Question Close Date:
195 | {date_end}
196 |
197 | Your job is to formulate {num_keywords} distinct and concise search queries. These queries will be used to query Google News to capture diverse perspectives and relevant data from various sources. Think about different elements that could influence the outcome.
198 |
199 | Structure your response as follows:
200 | Thoughts:
201 | {{ Insert your thinking here. }}
202 | Search Queries:
203 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
204 | (
205 | "QUESTION",
206 | "BACKGROUND",
207 | "DATES",
208 | "NUM_KEYWORDS",
209 | "MAX_WORDS",
210 | ),
211 | )
212 |
213 | SEARCH_QUERY_PROMPT_7 = (
214 | """
215 | In this task, I will present a forecasting question along with relevant background information. Your goal is to create {num_keywords} concise search queries (up to {max_words} words each) to gather information on Google that could influence the forecast.
216 |
217 | Question:
218 | {question}
219 |
220 | Question Background:
221 | {background}
222 |
223 | Today's date: {date_begin}
224 | Question close date: {date_end}
225 |
226 | Now, generate {num_keywords} short search queries to search for information on Google News.
227 | Begin by formulating sub-questions related to the main question. Use these sub-questions to guide the creation of your search queries.
228 |
229 | Your response should take the following structure:
230 | Thoughts:
231 | {{ Insert your thinking here. }}
232 | Search Queries:
233 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
234 | (
235 | "QUESTION",
236 | "BACKGROUND",
237 | "DATES",
238 | "NUM_KEYWORDS",
239 | "MAX_WORDS",
240 | ),
241 | )
242 |
243 | SEARCH_QUERY_PROMPT_8 = (
244 | """In this task, I will present a forecasting question along with relevant background information. Your goal is to create {num_keywords} concise search queries (up to {max_words} words each) to gather information that could influence the forecast. Consider different angles and aspects that might impact the outcome.
245 |
246 | Question:
247 | {question}
248 |
249 | Question Background:
250 | {background}
251 |
252 | Today's date: {date_begin}
253 | Question close date: {date_end}
254 |
255 | Now, generate {num_keywords} short search queries to search for information on Google News.
256 |
257 | Your response should take the following structure:
258 | Thoughts:
259 | {{ Insert your thinking here. }}
260 | Search Queries:
261 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
262 | (
263 | "QUESTION",
264 | "BACKGROUND",
265 | "DATES",
266 | "NUM_KEYWORDS",
267 | "MAX_WORDS",
268 | ),
269 | )
270 |
271 | # To be evaluated
272 | SEARCH_QUERY_PROMPT_NO_DATE_0 = (
273 | """Generate {num_keywords} search queries (up to {max_words} words each) for the forecasting question below.
274 |
275 | I will use them to query Google News for articles. These search queries should result in articles that help me make an informed prediction.
276 |
277 | ---
278 | Question:
279 | {question}
280 | ---
281 |
282 | ---
283 | Question Background:
284 | {background}
285 | ---
286 |
287 | Your response should take the following structure:
288 |
289 | Thoughts:
290 | {{ insert your thinking here }}
291 |
292 | Search Queries:
293 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
294 | (
295 | "QUESTION",
296 | "BACKGROUND",
297 | "NUM_KEYWORDS",
298 | "MAX_WORDS",
299 | ),
300 | )
301 |
302 |
303 | # To be evaluated
304 | SEARCH_QUERY_PROMPT_NO_DATE_1 = (
305 | """I will give you a forecasting question and its background information.
306 |
307 | Your goal is to generate {num_keywords} search queries (up to {max_words} words each).
308 | The search queries will be used to query Google News for articles.
309 | They should result in articles that help me make an informed prediction.
310 |
311 | ---
312 | Question:
313 | {question}
314 | ---
315 |
316 | ---
317 | Question Background:
318 | {background}
319 | ---
320 |
321 | Your response should take the following structure:
322 |
323 | Thoughts:
324 | {{ insert your thinking here }}
325 |
326 | Search Queries:
327 | {{ Insert the queries here. Use semicolons to separate the queries. }}""",
328 | (
329 | "QUESTION",
330 | "BACKGROUND",
331 | "NUM_KEYWORDS",
332 | "MAX_WORDS",
333 | ),
334 | )
335 |
--------------------------------------------------------------------------------
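All of the templates above instruct the model to put its queries after a "Search Queries:" header, separated by semicolons. A hedged, illustrative sketch of how such a response could be parsed downstream (the helper and response text below are invented for illustration and are not the repository's actual parser):

    # Illustrative parser for the semicolon-separated query format requested above.
    def parse_search_queries(response: str) -> list[str]:
        _, _, tail = response.partition("Search Queries:")
        return [q.strip() for q in tail.split(";") if q.strip()]

    response = (
        "Thoughts:\nSome reasoning here.\n"
        "Search Queries:\nUS election polls 2024; swing state polling; electoral college forecast"
    )
    print(parse_search_queries(response))
    # -> ['US election polls 2024', 'swing state polling', 'electoral college forecast']
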
/scripts/data_scraping/polymarket.py:
--------------------------------------------------------------------------------
1 | # Run this script in a Python 3.9 virtual environment so that it is
2 | # compatible with py_clob_client (the Polymarket API Python client).
3 |
4 | # Standard library imports
5 | from concurrent.futures import ThreadPoolExecutor
6 | from datetime import datetime, timedelta
7 | import argparse
8 | import ast
9 | import logging
10 | import time
11 |
12 | # Related third-party imports
13 | import requests
14 | from tqdm import tqdm
15 |
16 | # Local application/library-specific imports
17 | import data_scraping
18 | import information_retrieval
19 | from config.keys import keys
20 | from py_clob_client.client import ClobClient
21 | from py_clob_client.constants import POLYGON
22 |
23 | # Setup logging and other configurations
24 | logger = logging.getLogger(__name__)
25 |
26 | client = ClobClient(
27 | "https://clob.polymarket.com", key=keys["CRYPTO_PRIVATE_KEY"], chain_id=POLYGON
28 | )
29 |
30 |
31 | def get_market_query(offset_value):
32 | """
33 | Construct and return a GraphQL query string for fetching market data.
34 |
35 | Args:
36 | offset_value (int): The offset value to use in the query for pagination.
37 |
38 | Returns:
39 | str: A GraphQL query string.
40 | """
41 | query = f"""{{
42 | markets(offset: {offset_value}) {{
43 | id
44 | conditionId
45 | question
46 | description
47 | category
48 | createdAt
49 | questionID
50 | outcomes
51 | outcomePrices
52 | clobTokenIds
53 | endDate
54 | volume
55 | closed
56 | }}
57 | }}"""
58 | return query
59 |
60 |
61 | def get_comment_query(market_id, offset_value):
62 | """
63 | Construct and return a GraphQL query string for fetching comments related
64 | to a specific market ID with pagination.
65 |
66 | Args:
67 | market_id (int): The unique identifier of the market.
68 | offset_value (int): The offset value to use for pagination.
69 |
70 | Returns:
71 | str: A GraphQL query string for fetching comments.
72 | """
73 | query = f"""{{
74 | comments(marketID: {market_id}, offset: {offset_value}) {{
75 | id
76 | body
77 | createdAt
78 | }}
79 | }}"""
80 | return query
81 |
82 |
83 | def generate_json_markets(field, market_id=None):
84 | """
85 | Perform a POST request to retrieve market or comment data from the
86 | Polymarket API.
87 |
88 | Args:
89 | field (str): Specifies the type of data to fetch ('markets' or 'comments').
90 | market_id (int, optional): The market ID for which comments are to be fetched.
91 | Required if 'field' is 'comments'.
92 |
93 | Returns:
94 | list: A list of dictionaries containing market or comment data.
95 | """
96 | url = "https://gamma-api.polymarket.com/query"
97 | offset_value = 0
98 | all_data = []
99 |
100 | while True:
101 | if field == "comments":
102 | query = get_comment_query(market_id, offset_value)
103 | elif field == "markets":
104 | query = get_market_query(offset_value)
105 |         else:
106 |             raise ValueError(f"Unknown field name: {field}")
107 |
108 | try:
109 | response = requests.post(url, json={"query": query})
110 | # Will raise an HTTPError if the HTTP request returned an
111 | # unsuccessful status code
112 | response.raise_for_status()
113 | data = response.json().get("data", {}).get(field, [])
114 |
115 | if not data:
116 | break # Exit loop if no more markets are returned
117 |
118 | all_data.extend(data)
119 | offset_value += 1000
120 |
121 | except requests.exceptions.RequestException as e:
122 | print(f"Request failed: {e}")
123 | break # Break the loop in case of request failure
124 |
125 | return all_data
126 |
127 |
128 | def question_to_url(question, base_url="https://polymarket.com/event/"):
129 | """
130 | Convert a Polymarket question into a URL format.
131 |
132 | Args:
133 | question (str): The question text to be formatted.
134 | base_url (str, optional): The base URL to prepend to the formatted question.
135 |
136 | Returns:
137 | str: The formatted URL representing the Polymarket question.
138 | """
139 | cleaned_question = "".join(
140 | "" if char == "$" else char
141 | for char in question.strip()
142 | if char.isalnum() or char in [" ", "-", "$"]
143 | )
144 |
145 | # Replace spaces with hyphens and convert to lowercase
146 | url_formatted_question = cleaned_question.replace(" ", "-").lower()
147 |
148 | # Concatenate with the base URL
149 | url = base_url + url_formatted_question
150 | return url
151 |
152 |
153 | def fetch_price_history(market_id):
154 | """
155 | Retrieve the price history of a market from the Polymarket API.
156 |
157 | Args:
158 | market_id (str): The unique identifier of the market.
159 |
160 | Returns:
161 | list: A list of dictionaries containing the price history data, or an empty list
162 | if the data retrieval fails.
163 | """
164 | url = (
165 | f"https://clob.polymarket.com/prices-history?interval=all&market="
166 | f"{market_id}&fidelity=60"
167 | )
168 |
169 | response = requests.get(url)
170 |
171 | if response.status_code == 200:
172 | data = response.json()
173 | history_data = data.get("history", [])
174 | return history_data
175 | else:
176 | print("Failed to retrieve data:", response.status_code)
177 | return []
178 |
179 |
180 | def process_market(m):
181 | """
182 | Process a single market dictionary by adding additional information
183 | such as comments, URLs, community predictions, and other metadata.
184 | It also renames fields to align with a specific format.
185 |
186 | Args:
187 |         m (dict): A dictionary representing a single market with its initial data.
188 |
189 | Returns:
190 | dict: The processed market dictionary with additional fields and formatted data.
191 | """
192 | m["comments"] = generate_json_markets("comments", market_id=int(m["id"]))
193 | m["url"] = question_to_url(m["question"])
194 |
195 | # Resolution
196 | try:
197 | m["outcomes"] = ast.literal_eval(m["outcomes"])
198 | m["outcomePrices"] = ast.literal_eval(m["outcomePrices"])
199 |
200 | # Make sure that 'outcomes' and 'outcomePrices' have the same length
201 | if len(m["outcomes"]) != len(m["outcomePrices"]):
202 | raise ValueError(
203 | "The lengths of 'outcomes' and 'outcomePrices' do not match."
204 | )
205 |
206 | # Find the outcome with the highest price
207 | highest_price = max(m["outcomePrices"], key=lambda x: float(x))
208 | highest_price_index = m["outcomePrices"].index(highest_price)
209 | resolution_outcome = m["outcomes"][highest_price_index]
210 |
211 | m["resolution"] = resolution_outcome
212 |
213 | if m["resolution"] == "Yes":
214 | m["resolution"] = 1
215 | elif m["resolution"] == "No":
216 | m["resolution"] = 0
217 | except Exception as e:
218 | print(f"An error occurred: {e}")
219 | m["resolution"] = "Error"
220 |
221 | # Question type
222 | m["question_type"] = "multiple_choice"
223 | if set(m["outcomes"]) == {"Yes", "No"}:
224 | m["question_type"] = "binary"
225 |
226 | # Community predictions for the first outcome
227 | try:
228 | if m["clobTokenIds"] is not None:
229 | # Attempt to fetch community predictions
230 | m["community_predictions"] = fetch_price_history(
231 | m["clobTokenIds"].split('"')[1]
232 | )
233 | else:
234 | m["community_predictions"] = []
235 | except IndexError as e:
236 | # Print the error and the problematic clobTokenIds
237 | print(f"Error: {e}, clobTokenIds: {m.get('clobTokenIds')}")
238 | m["community_predictions"] = []
239 |
240 |     # Rename field names so they align with the Metaculus format
241 | m["title"] = m.pop("question")
242 | m["close_time"] = m.pop("endDate")
243 | m["created_time"] = m.pop("createdAt")
244 | m["background"] = m.pop("description")
245 | m["is_resolved"] = m.pop("closed")
246 |
247 | # Data source
248 | m["data_source"] = "polymarket"
249 |
250 | return m
251 |
252 |
253 | def main(n_days):
254 | """
255 | Main function to fetch and process Polymarket data.
256 |
257 | Args:
258 | n_days (int): Number of days in the past to limit the data fetching.
259 | If None, fetches all available data.
260 |
261 | Returns:
262 |         None. The processed market dictionaries are uploaded to S3.
263 | """
264 | logger.info("Starting the polymarket script...")
265 |
266 | start_time = time.time()
267 |
268 | all_markets = generate_json_markets("markets")
269 |
270 | if n_days is not None:
271 | date_limit = datetime.now() - timedelta(days=n_days)
272 | all_markets = [
273 | market
274 | for market in all_markets
275 | if datetime.fromisoformat(market["createdAt"][:-1]) >= date_limit
276 | ]
277 |
278 | logger.info(f"Number of polymarket questions fetched: {len(all_markets)}")
279 |
280 | logger.info("Start preprocess the question...")
281 | with ThreadPoolExecutor(max_workers=50) as executor:
282 | results = list(
283 | tqdm(
284 | executor.map(process_market, all_markets),
285 | total=len(all_markets),
286 | )
287 | )
288 |
289 | logger.info("Start extracting articles links...")
290 |
291 | for question in results:
292 | question["extracted_articles_urls"] = information_retrieval.get_urls_from_text(
293 | question["background"]
294 | )
295 | if question["comments"]:
296 | for comment in question["comments"]:
297 | question["extracted_articles_urls"].extend(
298 | information_retrieval.get_urls_from_text(comment["body"])
299 | )
300 |
301 | elapsed_time = time.time() - start_time
302 | logger.info(f"Total execution time: {elapsed_time} seconds")
303 |
304 | logger.info("Uploading to s3...")
305 | question_types = ["binary", "multiple_choice"]
306 | data_scraping.upload_scraped_data(results, "polymarket", question_types, n_days)
307 |
308 |
309 | if __name__ == "__main__":
310 | parser = argparse.ArgumentParser(description="Fetch polymarket data.")
311 | parser.add_argument(
312 | "--n_days",
313 | type=int,
314 | help="Fetch questions created in the last N days",
315 | default=None,
316 | )
317 | args = parser.parse_args()
318 | main(args.n_days)
319 |
--------------------------------------------------------------------------------
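For reference, question_to_url above builds the event slug by keeping letters, digits, spaces, and hyphens, dropping dollar signs, then hyphenating and lowercasing. A quick illustration with an invented question:

    # Invented input, illustrating the slug construction in question_to_url.
    question = "Will $BTC close above 100,000 in 2024?"
    cleaned = "".join(
        "" if c == "$" else c
        for c in question.strip()
        if c.isalnum() or c in [" ", "-", "$"]
    )
    print("https://polymarket.com/event/" + cleaned.replace(" ", "-").lower())
    # -> https://polymarket.com/event/will-btc-close-above-100000-in-2024
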
/llm_forecasting/summarize.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import asyncio
3 | import logging
4 | import time
5 |
6 | # Local application/library-specific imports
7 | from config.constants import MODEL_TOKEN_LIMITS
8 | import model_eval
9 | from prompts.prompts import PROMPT_DICT
10 | from utils import model_utils
11 |
12 | # Set up logging
13 | logging.basicConfig(level=logging.INFO)
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def concat_summaries(articles, return_summaries_list=False):
18 | """
19 |     Combine the summaries of multiple articles into a single, cohesive string,
20 |     incorporating each article's title, publication date, and a sequential
21 |     index for easy reference.
22 |
23 | Args:
24 | articles (list of article objects): List of article objects with
25 | a 'summary' field.
26 | return_summaries_list (bool, optional): Whether to return the list of
27 | summaries as well. Defaults to False.
28 |
29 | Returns:
30 | str: A string containing the concatenated summaries of the articles.
31 |
32 | Example output:
33 | ---
34 | ARTICLES
35 | [1] Title 1 (published on 2021-01-01)
36 | ....
37 | [2] Title 2 (published on 2021-01-02)
38 | ....
39 | ----
40 |
41 | If return_summaries_list is True, then the function returns a tuple
42 | containing the concatenated summaries string and the list of summaries.
43 | """
44 | if not articles:
45 | return "---\nNo articles were retrieved for this question.\n----"
46 | article_summaries = [
47 | f"[{index}] {article.title} (published on {(article.publish_date.date() if article.publish_date else 'unknown date')})\nSummary: {article.summary}\n"
48 | for index, article in enumerate(articles, start=1)
49 | ]
50 | concatenated_summaries_str = (
51 | "---\nARTICLES\n" + "\n".join(article_summaries) + "----"
52 | )
53 | if return_summaries_list:
54 | return concatenated_summaries_str, article_summaries
55 | return concatenated_summaries_str
56 |
57 |
58 | def split_text_into_chunks(text, model_name, token_limit):
59 | """
60 | Split the text into chunks, ensuring each chunk is below the token limit.
61 |
62 | Args:
63 | text (str): Input text to be split.
64 | model_name (str): Name of the model to be used for token counting.
65 | token_limit (int): Maximum number of tokens allowed per chunk.
66 |
67 | Returns:
68 | list: List of text chunks.
69 | """
70 | words = text.split()
71 | current_chunk = []
72 | current_chunk_tokens = 0
73 | chunks = []
74 |
75 | for word in words:
76 | word_tokens = model_utils.count_tokens(word, model_name)
77 | if current_chunk_tokens + word_tokens > token_limit:
78 | chunks.append(" ".join(current_chunk))
79 | current_chunk = [word]
80 | current_chunk_tokens = word_tokens
81 | else:
82 | current_chunk.append(word)
83 | current_chunk_tokens += word_tokens
84 |
85 | if current_chunk:
86 | chunks.append(" ".join(current_chunk))
87 |
88 | return chunks
89 |
90 |
91 | def recursive_summarize(
92 | text,
93 | model_name,
94 | prompt,
95 | output_token_length=None,
96 | temperature=0,
97 | ):
98 | """
99 | Recursively summarize the text until the summary fits within the context
100 | window.
101 |
102 | Args:
103 | text (str): The input text to be summarized.
104 | model_name (str): Name of the model to be used for summarization.
105 |         prompt (str): Main prompt to guide the summarization.
106 | output_token_length (int, optional): Desired word count of the final
107 | summary. Defaults to None.
108 | temperature (float, optional): Sampling temperature for the completion.
109 | Defaults to 0.
110 |
111 | Returns:
112 | str: Summarized text.
113 | """
114 | start_time = time.time()
115 |
116 | total_tokens = model_utils.count_tokens(text, model_name)
117 | logger.info(f"Total number of tokens of the given text: {total_tokens}")
118 |
119 | if total_tokens <= MODEL_TOKEN_LIMITS[model_name]:
120 | if output_token_length:
121 | prompt += (
122 | f"\n\nAlso, ensure the summary is under {output_token_length} words.\n"
123 | )
124 |
125 | output = model_eval.get_response_from_model(
126 | model_name=model_name,
127 | prompt=prompt.format(article=text),
128 | max_tokens=output_token_length,
129 | )
130 |
131 | end_time = time.time()
132 | elapsed_time = end_time - start_time
133 | logger.info(f"Time taken for summarization: {elapsed_time:.2f} seconds")
134 | logger.info("Finished summarizing the article!")
135 | return output
136 | else:
137 | token_limit = MODEL_TOKEN_LIMITS[model_name] - 1000
138 | chunks = split_text_into_chunks(text, model_name, token_limit)
139 |
140 | summarized_chunks = []
141 | for chunk in chunks:
142 | summarized_chunk = recursive_summarize(
143 | chunk,
144 | model_name,
145 | prompt,
146 | output_token_length,
147 | temperature,
148 | )
149 | summarized_chunks.append(summarized_chunk)
150 |
151 | summarized_text = " ".join(summarized_chunks)
152 | return recursive_summarize(
153 | summarized_text,
154 | model_name,
155 | prompt,
156 | output_token_length,
157 | temperature,
158 | )
159 |
160 |
161 | async def summarize_articles(
162 | articles,
163 | model_name="gpt-3.5-turbo-1106",
164 | prompt=PROMPT_DICT["summarization"]["0"][0],
165 | temperature=0.2,
166 | update_object=True,
167 | inline_questions=[],
168 | ):
169 | """
170 | Summarizes a list of articles asynchronously.
171 |
172 | Long articles are truncated down to token limit and summarized separately.
173 |
174 | Example usage:
175 | >> summarized_results = await summarize_articles(articles)
176 |
177 | Args:
178 | articles (list of obj): List of article objects. Each article object should have the following fields:
179 | text_cleaned (str): Full text of the article.
180 | model_name (str, optional): Name of the OpenAI model to be used for summarization (defaults to "gpt-3.5-turbo-1106").
181 | prompt (str, optional): Prompt to use for the API call (defaults to PROMPT_DICT["summarization"]["0"][0]).
182 | This is not the full prompt, but contains a placeholder for the article text.
184 | temperature (float, optional): Sampling temperature for the completion. Defaults to 0.2.
185 | update_object (bool, optional): Whether to update the article object with the summary (defaults to True).
186 |         inline_questions (dict, optional): Dictionary with the question fields ("title", "background", "resolution_criteria") used to condition the summaries. Defaults to [] (no question conditioning).
187 |
188 | Returns:
189 | dict: Dictionary containing the summarized results for each article.
190 | Example output:
191 | ---
192 | {
193 | "Title 1": "Summary 1",
194 | "Title 2": "Summary 2",
195 | ...
196 | }
197 | ---
198 | Also, the article objects are updated with the summaries if update_object is True.
199 | """
200 | summarized_results = {}
201 | # Truncate articles that are too long
202 | for article in articles:
203 | if ( # exceeds token limit
204 | model_utils.count_tokens(article.text_cleaned, model_name)
205 | > MODEL_TOKEN_LIMITS[model_name] - 1000
206 | ):
207 | article.text_cleaned = split_text_into_chunks(
208 | article.text_cleaned, model_name, MODEL_TOKEN_LIMITS[model_name] - 1000
209 | )[0]
210 | # Summarize all articles asynchronously
211 | logger.info(f"Async summarizing {len(articles)} short articles")
212 | all_summaries = await async_summarize(
213 | articles,
214 | prompt,
215 | update_object=update_object,
216 | temperature=temperature,
217 | model_name=model_name,
218 | inline_questions=inline_questions,
219 | )
220 | for i, article in enumerate(articles):
221 | summarized_results[article.title] = all_summaries[i]
222 | return summarized_results
223 |
224 |
225 | async def async_summarize(
226 | articles,
227 | prompt=PROMPT_DICT["summarization"]["0"][0],
228 | update_object=True,
229 | model_name="gpt-3.5-turbo-1106",
230 | temperature=0.2,
231 | inline_questions=[],
232 | ):
233 | """
234 | Asynchronously summarizes a list of articles.
235 | Example usage:
236 | >> all_summaries = await async_summarize(articles)
237 | The order of the returned summaries is the same as the order of the input articles.
238 | All summaries are stored in the "summary" field of each article object, if update_object is True.
239 |
240 | Args:
241 | articles (list of obj): List of article objects. Each article object should have the following fields:
242 | text_cleaned (str): Full text of the article.
243 | prompt (str): Prompt to use for the API call (defaults to PROMPT_DICT["summarization"]["0"][0]).
244 | This is not the full prompt, but contains a placeholder for the article text.
245 | update_object (bool): Whether to update the article object with the summary (defaults to True).
246 |
247 | Returns:
248 | list of str: List of summaries (str) for each article.
249 | """
250 |     if not articles:
251 | return []
252 | if inline_questions:
253 | question = inline_questions["title"]
254 | background = inline_questions["background"]
255 | resolution_criteria = inline_questions["resolution_criteria"]
256 | prompts = [
257 | prompt.format(
258 | question=question,
259 | background=background,
260 | resolution_criteria=resolution_criteria,
261 | article=article.text_cleaned,
262 | )
263 | for article in articles
264 | ]
265 | else:
266 | prompts = [prompt.format(article=article.text_cleaned) for article in articles]
267 |
268 | summarization_tasks = [
269 | model_eval.get_async_response(
270 | prompt, model_name=model_name, temperature=temperature
271 | )
272 | for prompt in prompts
273 | ]
274 | all_summaries = await asyncio.gather(*summarization_tasks)
275 | if update_object:
276 | for i, article in enumerate(articles):
277 | article.summary = all_summaries[i]
278 | return all_summaries
279 |
--------------------------------------------------------------------------------
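recursive_summarize above follows a map-reduce pattern: when the text exceeds the model's token limit it is split into under-limit chunks, each chunk is summarized, and the concatenated chunk summaries are summarized again until the whole text fits in one call. A toy sketch of that control flow (word counts stand in for tokens and a truncating function stands in for the LLM call; both are invented for illustration):

    # Toy sketch of the recursive map-reduce summarization loop above.
    LIMIT = 20  # stand-in for MODEL_TOKEN_LIMITS[model_name]

    def toy_summarize(text):
        return " ".join(text.split()[:5])  # stand-in for the LLM summarization call

    def toy_recursive_summarize(text):
        words = text.split()
        if len(words) <= LIMIT:
            return toy_summarize(text)
        chunks = [" ".join(words[i:i + LIMIT]) for i in range(0, len(words), LIMIT)]
        reduced = " ".join(toy_summarize(chunk) for chunk in chunks)
        return toy_recursive_summarize(reduced)

    print(toy_recursive_summarize("word " * 100))
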
/llm_forecasting/utils/string_utils.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import logging
3 | import re
4 |
5 | # Standard library imports (continued)
6 | import urllib.parse
7 |
8 | # Local application/library-specific imports
9 | from config.constants import TOKENS_TO_PROBS_DICT
10 |
11 | # Set up logging
12 | logging.basicConfig(level=logging.INFO)
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def is_string_in_list(target_string, string_list):
17 | """
18 | Check if the target string is in the list of strings; case insensitive.
19 | """
20 | # Convert the target string to lowercase
21 | target_string_lower = target_string.lower()
22 |
23 | # Check if the lowercase target string is in the list of lowercase strings
24 | return any(s.lower() == target_string_lower for s in string_list)
25 |
26 |
27 | def find_end_word(paragraph, end_words, window_size=50):
28 | """
29 | Find one of the end_words in the last window_size words of the paragraph.
30 | Return the found word or None if no word is found.
31 |
32 | TODO: Lowercase the paragraph and end_words before searching so that the search is case-insensitive?
33 |
34 | Args:
35 | - paragraph (str): The paragraph to search in.
36 | - end_words (list of str): The words to search for.
37 | - window_size (int): The number of words from the end to search within.
38 |
39 | Returns:
40 | str: found word or None
41 | """
42 | sorted_words = sorted(end_words, key=lambda s: len(s.split(" ")), reverse=True)
43 | for end_word in sorted_words:
44 | if end_word in paragraph[-window_size:]:
45 | return end_word
46 | logger.debug(f"Could not find any end word in {paragraph[-window_size:]}.")
47 | return None
48 |
49 |
50 | def get_prompt(
51 | prompt_template,
52 | fields,
53 | question=None,
54 | data_source=None,
55 | dates=None,
56 | background=None,
57 | resolution_criteria=None,
58 | num_keywords=None,
59 | retrieved_info=None,
60 | reasoning=None,
61 | article=None,
62 | summary=None,
63 | few_shot_examples=None,
64 | max_words=None,
65 | ):
66 | """
67 | Fill in a prompt template with specific data based on provided fields.
68 |
69 | Args:
70 | prompt_template (str): The template containing placeholders for data
71 | insertion.
72 |         fields (list): Placeholders within the template to be filled (e.g. 'QUESTION',
73 |             'DATES', 'FEW_SHOT_EXAMPLES', 'RETRIEVED_INFO').
74 | question (str, optional): The question text for the 'QUESTION'
75 | placeholder.
76 | data_source (str, optional): The platform (e.g. "metaculus").
77 | background (str, optional): Background information for the 'BACKGROUND'
78 | placeholder.
79 | resolution_criteria (str, optional): Resolution criteria for the
80 | 'RESOLUTION_CRITERIA' placeholder.
81 | dates (tuple or list of strs, optional): Start and end dates for the
82 | 'DATES' placeholder (length == 2)
83 | retrieved_info (str, optional): Information text for the 'RETRIEVED_INFO'
84 | placeholder.
85 | reasoning (str, optional): Reasoning text for the 'REASONING' placeholder.
86 | article (str, optional): Article text for the 'ARTICLE' placeholder.
87 | summary (str, optional): Summary text for the 'SUMMARY' placeholder.
88 | few_shot_examples (list, optional): List of (question, answer) tuples for
89 | the 'FEW_SHOT_EXAMPLES' placeholder.
90 |         max_words (int, optional): Maximum number of words per query, for the
91 |             'MAX_WORDS' placeholder used in search query generation.
92 |
93 | Returns:
94 | str: A string with the placeholders in the template replaced with the provided
95 | data.
96 | """
97 | mapping = {}
98 | for f in fields:
99 | if f == "QUESTION":
100 | mapping["question"] = question
101 | elif f == "DATES":
102 | mapping["date_begin"] = dates[0]
103 | mapping["date_end"] = dates[1]
104 | elif f == "FEW_SHOT_EXAMPLES":
105 | examples = few_shot_examples or []
106 | for j, (q, answer) in enumerate(examples, 1):
107 | mapping[f"question_{j}"] = q
108 | mapping[f"answer_{j}"] = answer
109 | elif f == "RETRIEVED_INFO":
110 | mapping["retrieved_info"] = retrieved_info
111 | elif f == "BACKGROUND":
112 | mapping["background"] = background
113 | elif f == "RESOLUTION_CRITERIA":
114 | mapping["resolution_criteria"] = resolution_criteria
115 | elif f == "REASONING":
116 | mapping["reasoning"] = reasoning
117 | elif f == "BASE_REASONINGS":
118 | mapping["base_reasonings"] = reasoning
119 | elif f == "NUM_KEYWORDS":
120 | mapping["num_keywords"] = num_keywords
121 | elif f == "MAX_WORDS":
122 | mapping["max_words"] = str(max_words)
123 | elif f == "ARTICLE":
124 | mapping["article"] = article
125 | elif f == "SUMMARY":
126 | mapping["summary"] = summary
127 | elif f == "DATA_SOURCE":
128 | mapping["data_source"] = data_source
129 | return prompt_template.format(**mapping)
130 |
131 |
132 | def extract_probability_with_stars(text):
133 | """
134 | Extract a probability value from a given text string.
135 |
136 | The function searches for numbers enclosed in asterisks (*), interpreting
137 | them as potential probability values. If a percentage sign is found with
138 | the number, it's converted to a decimal. The function returns the last
139 | number found that is less than or equal to 1, as a probability should be.
140 | If no such number is found, a default probability of 0.5 is returned.
141 |
142 | Args:
143 | - text (str): The text string from which the probability value is to be
144 | extracted.
145 |
146 | Returns:
147 | - float: The extracted probability value, if found. Otherwise, returns 0.5.
148 | """
149 | # Regular expression to find numbers between stars
150 | pattern = r"\*(.*?[\d\.]+.*?)\*"
151 | matches = re.findall(pattern, text)
152 |
153 | # Extracting the numerical values from the matches
154 | extracted_numbers = []
155 | for match in matches:
156 | # Extract only the numerical part (ignoring potential non-numeric
157 | # characters)
158 | number_match = re.search(r"[\d\.]+", match)
159 | if number_match:
160 | try:
161 | number = float(number_match.group())
162 | if "%" in match:
163 | number /= 100
164 | extracted_numbers.append(number)
165 | except BaseException:
166 | continue
167 |
168 | if len(extracted_numbers) > 0 and extracted_numbers[-1] <= 1:
169 | return extracted_numbers[-1]
170 |
171 |     # Fallback: regular expression to find numbers immediately preceding a star
172 | pattern = r"([\d\.]+.*?)\*"
173 | matches = re.findall(pattern, text)
174 |
175 | # Extracting the numerical values from the matches
176 | extracted_numbers = []
177 | for match in matches:
178 | # Extract only the numerical part (ignoring potential non-numeric
179 | # characters)
180 | number_matches = re.findall(r"[\d\.]+", match)
181 | for num_match in number_matches:
182 | try:
183 | number = float(num_match)
184 | extracted_numbers.append(number)
185 | except BaseException:
186 | continue
187 |
188 | if len(extracted_numbers) > 0 and extracted_numbers[-1] <= 1:
189 | return extracted_numbers[-1]
190 |
191 | return 0.5
192 |
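# Illustrative examples (invented inputs, added for clarity):
#   extract_probability_with_stars("... final answer: *0.7*")  -> 0.7
#   extract_probability_with_stars("... probability is *65%*") -> 0.65
#   extract_probability_with_stars("no starred number here")   -> 0.5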
193 |
194 | def extract_prediction(
195 | response,
196 | answer_type="probability",
197 | end_words=list(TOKENS_TO_PROBS_DICT["ten_options"].keys()),
198 | ):
199 | """
200 | A generic function to extract a prediction from a response string.
201 |
202 | Args:
203 | response (str): The response string from which the prediction is to be
204 | extracted.
205 | answer_type (str): The type of answer to extract. Can be "probability"
206 | or "tokens".
207 | end_words (list): The list of end words to search for in the response
208 | string. The first end word found in the response string will be
209 | used to extract the prediction.
210 | Only used if answer_type == "tokens".
211 |
212 | Returns:
213 | str or float: The extracted prediction.
214 | """
215 | if answer_type == "probability":
216 | return extract_probability_with_stars(response)
217 | elif answer_type == "tokens":
218 | return find_end_word(response, end_words)
219 | else:
220 | raise ValueError(f"Invalid answer_type: {answer_type}")
221 |
222 |
223 | def extract_and_decode_title_from_wikiurl(url):
224 | """
225 | Extract the title from a Wikipedia URL and decode it.
226 |
227 | Args:
228 | url (str): The Wikipedia URL.
229 |
230 | Returns:
231 | str: The decoded title. None if the URL is invalid.
232 | """
233 | if "wikipedia.org" in url and "upload.wikimedia.org" not in url:
234 | # Extract the part after '/wiki/' and before the first '#?' if present
235 | match = re.search(r"/wiki/([^#?]+)", url)
236 | if match:
237 | # Replace underscores with spaces and decode percent encoding
238 | return urllib.parse.unquote(re.sub(r"_", " ", match.group(1)))
239 | return None
240 |
241 |
242 | def concat_summaries_from_fields(summary_texts, titles, publish_dates):
243 | """
244 | Concatenate summaries from a list of summary texts. Fill in the titles and
245 | publish dates for each summary.
246 |
247 | The length of the summary_texts, titles, and publish_dates lists should be
248 | the same.
249 |
250 | Args:
251 | summary_texts (list of str): A list of summary texts.
252 | titles (list of str): A list of titles.
253 |         publish_dates (list of str): A list of publish dates.
254 |
255 | Returns:
256 | str: The concatenated summaries with titles and publish dates.
257 | """
258 | logger.info(
259 | f"Concatenating summaries from {len(summary_texts)} articles ({len(titles)} titles and {len(publish_dates)} dates)."
260 | )
261 | if (len(summary_texts) != len(titles)) or (
262 | len(summary_texts) != len(publish_dates)
263 | ):
264 | logger.error(
265 | f"Lengths of summary_texts, titles, and publish_dates should be the same. Got {len(summary_texts)}, {len(titles)}, and {len(publish_dates)}."
266 | )
267 | return "Not available."
268 | article_summaries = [
269 | f"[{i+1}] {titles[i]} (published on {(publish_dates[i] if publish_dates[i] else 'unknown date')})\nSummary: {article_summary}\n"
270 | for i, article_summary in enumerate(summary_texts)
271 | ]
272 | concatenated_summaries_str = (
273 | "---\nARTICLES\n" + "\n".join(article_summaries) + "----"
274 | )
275 | return concatenated_summaries_str
--------------------------------------------------------------------------------
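get_prompt above is the bridge between the (template, fields) pairs defined in prompts/ and the rest of the pipeline: the fields tuple selects which keyword arguments are substituted into the template. A hedged usage sketch with invented values, assuming the llm_forecasting/ directory is on the import path:

    # Hedged usage sketch; the question, dates, and counts are invented.
    from prompts.search_query import SEARCH_QUERY_PROMPT_0
    from utils import string_utils

    template, fields = SEARCH_QUERY_PROMPT_0  # fields: QUESTION, BACKGROUND, DATES, ...
    prompt = string_utils.get_prompt(
        template,
        fields,
        question="Will X happen by 2025?",
        background="Background text.",
        dates=("2024-01-01", "2024-12-31"),
        num_keywords=5,
        max_words=10,
    )
    print(prompt)
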
/llm_forecasting/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import concurrent.futures
3 | import logging
4 | import re
5 |
6 | # Local application/library-specific imports
7 | from config.constants import S3, S3_BUCKET_NAME
8 | import model_eval
9 | from prompts.prompts import PROMPT_DICT
10 | from utils import db_utils, time_utils, string_utils
11 |
12 | # Set up logging
13 | logging.basicConfig(level=logging.INFO)
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def get_formatted_data(
18 | s3_path,
19 | retrieval_index=0,
20 | num_retrievals=5,
21 | questions_after="2015",
22 | return_raw_question_data=False,
23 | data=None,
24 | ):
25 | """
26 | Retrieve and process training data from an S3 path.
27 |
28 | This function reads data from S3, processes it, and structures it for training purposes.
29 | It calculates retrieval dates and filters out data based on these dates. The function can
30 | optionally return raw question data.
31 |
32 | Also, the function can optionally take in the |data| directly.
33 |
34 | Parameters:
35 | s3_path (str): Path to the data file in S3.
36 | retrieval_index (int, optional): Index for calculating the retrieval date. Defaults to 0.
37 | num_retrievals (int, optional): Total number of retrievals to consider. Defaults to 5.
38 | return_raw_question_data (bool, optional): Flag to return raw question data. Defaults to False.
39 |     data (list, optional): List of forecasting questions to use directly instead of reading from S3.
40 |
41 | Returns:
42 | dict or tuple: A dictionary containing structured training data, or a tuple with the dictionary
43 | and raw data if return_raw_question_data is True.
44 |
45 | Note:
46 | This function expects specific keys in the data (e.g., 'date_close', 'date_resolve_at', etc.),
47 | and logs an error if reading from S3 fails.
48 | """
49 | if not data:
50 | try:
51 | data = db_utils.read_pickle_from_s3(S3, S3_BUCKET_NAME, s3_path)
52 | except Exception as e:
53 | logger.error(f"Error reading data from S3: {e}")
54 | return {}
55 |
56 | question_dict = {
57 | "question_list": [],
58 | "background_list": [],
59 | "resolution_criteria_list": [],
60 | "question_dates_list": [],
61 | "resolve_dates_list": [],
62 | "retrieval_dates_list": [],
63 | "answer_list": [],
64 | "data_source_list": [],
65 | "community_pred_at_retrieval_list": [],
66 | "urls_in_background_list": [],
67 | "category_list": [],
68 | }
69 | raw_data = []
70 | for q in data:
71 | q["date_close"] = q["date_close"] or q["date_resolve_at"]
72 | retrieval_date = time_utils.get_retrieval_date(
73 | retrieval_index,
74 | num_retrievals,
75 | q["date_begin"],
76 | q["date_close"],
77 | q["date_resolve_at"],
78 | )
79 |
80 | if retrieval_date is None:
81 | continue
82 | elif retrieval_date == q["date_resolve_at"]:
83 | continue
84 | elif not time_utils.is_more_recent(
85 | f"{questions_after}-01-01", q["date_begin"], or_equal_to=True
86 | ):
87 | continue
88 |
89 | raw_data.append(q)
90 | for key, value in {
91 | "question_list": q["question"],
92 | "background_list": q["background"],
93 | "resolution_criteria_list": q["resolution_criteria"],
94 | "question_dates_list": (
95 | time_utils.extract_date(q["date_begin"]),
96 | time_utils.extract_date(q["date_close"]),
97 | ),
98 | "resolve_dates_list": q["date_resolve_at"],
99 | "retrieval_dates_list": (
100 | time_utils.extract_date(q["date_begin"]),
101 | retrieval_date,
102 | ),
103 | "answer_list": int(q["resolution"]),
104 | "data_source_list": q["data_source"],
105 | "community_pred_at_retrieval_list": time_utils.find_pred_with_closest_date(
106 | retrieval_date, q["community_predictions"]
107 | )[1],
108 | "urls_in_background_list": q["urls_in_background"],
109 | "category_list": q["gpt_3p5_category"],
110 | }.items():
111 | question_dict[key].append(value)
112 |
113 | return (question_dict, raw_data) if return_raw_question_data else question_dict
114 |
115 |
116 | def format_single_question(data_dict, index):
117 | """
118 | Format a single question, located by |index|, from a dictionary of all
119 | questions.
120 |
121 | Args:
122 | data_dict (dict): Dictionary containing all question data.
123 | index (int): Index of question.
124 |
125 | Returns:
126 | dict: Formatted question
127 | """
128 | return {
129 | "question": data_dict["question_list"][index],
130 | "background": data_dict["background_list"][index],
131 | "resolution_criteria": data_dict["resolution_criteria_list"][index],
132 | "answer": data_dict["answer_list"][index],
133 | "question_dates": data_dict["question_dates_list"][index],
134 | "retrieval_dates": data_dict["retrieval_dates_list"][index],
135 | "data_source": data_dict["data_source_list"][index],
136 | "resolve_date": data_dict["resolve_dates_list"][index],
137 | "community_pred_at_retrieval": data_dict["community_pred_at_retrieval_list"][
138 | index
139 | ],
140 | "urls_in_background": data_dict["urls_in_background_list"][index],
141 | "category": data_dict["category_list"][index],
142 | }
143 |
144 |
145 | def is_question_ill_defined(question, model_name):
146 | """
147 | Determine if a given question is ill-defined using a specified model.
148 | Returns True if ill-defined, False if not, and None if the determination cannot be made.
149 | """
150 | prompt = string_utils.get_prompt(
151 | PROMPT_DICT["data_wrangling"]["is_bad_title"][0],
152 | PROMPT_DICT["data_wrangling"]["is_bad_title"][1],
153 | question=question,
154 | )
155 | response = model_eval.get_response_from_model(
156 | model_name, prompt, max_tokens=500, temperature=0.1
157 | )
158 |
159 | if "Classification:" not in response:
160 | logger.error(
161 | f"'Classification:' is not in the response for question: {question}"
162 | )
163 | return None
164 |
165 | end_resp = response.split("Classification:")[1]
166 | if "ok" in end_resp:
167 | return False
168 | elif "flag" in end_resp:
169 | logger.info(f"The following question is ill-defined: {question}")
170 | return True
171 |
172 | logger.error(f"Ambiguous response for question: {question}")
173 | return True
174 |
175 |
176 | def assign_ill_defined_questions(data_list, model_name="gpt-3.5-turbo-1106"):
177 | """
178 | Evaluate each question in data_list to determine if it's ill-defined using the specified model.
179 | Modifies data_list in place by adding a key 'ill-defined' with a Boolean value.
180 | """
181 | number_of_workers = 50
182 |
183 | with concurrent.futures.ThreadPoolExecutor(
184 | max_workers=number_of_workers
185 | ) as executor:
186 | future_to_question = {
187 | executor.submit(is_question_ill_defined, item["question"], model_name): item
188 | for item in data_list
189 | if "is_ill_defined" not in item
190 | }
191 |
192 | for future in concurrent.futures.as_completed(future_to_question):
193 | question_item = future_to_question[future]
194 | try:
195 | result = future.result()
196 | if result is not None:
197 | question_item["is_ill_defined"] = result
198 | else:
199 | logger.warning(
200 | f"Could not determine if question is ill-defined: {question_item['question']}"
201 | )
202 | except Exception as exc:
203 | logger.error(
204 | f"Error processing question {question_item['question']}: {exc}"
205 | )
206 | return None
207 |
208 |
209 | def assign_category(question, background, model_name):
210 | try:
211 | prompt = string_utils.get_prompt(
212 | PROMPT_DICT["data_wrangling"]["assign_category"][0],
213 | PROMPT_DICT["data_wrangling"]["assign_category"][1],
214 | question=question,
215 | background=background,
216 | )
217 | response = model_eval.get_response_from_model(
218 | model_name, prompt, max_tokens=500, temperature=0.1
219 | )
220 | return response.strip('"').strip("'").strip(" ").strip(".")
221 | except Exception as e:
222 | logger.error(f"Error in assign_category: {e}")
223 | return None
224 |
225 |
226 | def assign_categories(data_list, model_name="gpt-3.5-turbo-1106"):
227 | number_of_workers = 100
228 | updated_items = []
229 |
230 | with concurrent.futures.ThreadPoolExecutor(
231 | max_workers=number_of_workers
232 | ) as executor:
233 | future_to_question = {
234 | executor.submit(
235 | assign_category, item["question"], item["background"], model_name
236 | ): item
237 | for item in data_list
238 | if "gpt_3p5_category" not in item
239 | }
240 |
241 | for future in concurrent.futures.as_completed(future_to_question):
242 | question_item = future_to_question[future]
243 | try:
244 | result = future.result()
245 | if result is not None:
246 | question_item["gpt_3p5_category"] = result
247 | else:
248 | logger.warning(
249 | f"Could not assign category: {question_item['question']}"
250 | )
251 | updated_items.append(question_item)
252 | except Exception as exc:
253 | logger.error(
254 | f"Error processing question {question_item['question']}: {exc}"
255 | )
256 |
257 | return None
258 |
259 |
260 | def reformat_metaculus_questions(
261 | data,
262 | model_name="gpt-3.5-turbo-1106",
263 | prompt=PROMPT_DICT["data_wrangling"]["reformat"],
264 | ):
265 | """
266 | Reformat questions from Metaculus to be more readable.
267 |
268 | In particular, some questions have a title that ends with a parenthesis,
269 | containing the actual subject.
270 | This function rephrases it to be a Yes/No question.
271 |
272 | For example,
273 | >>> "Who will win the 2020 US presidential election? (Biden)"
274 |     will be reformatted by the language model to
275 | >>> "Will Biden win the 2020 US presidential election?"
276 |
277 | Args:
278 | data (list of dict): List of questions in dictionary format.
279 | model_name (str, optional): Language model name, default is
280 | "gpt-3.5-turbo-1106".
281 | prompt (tuple of str, optional): Prompt to use for model evaluation.
282 | Default is PROMPT_DICT["data_cleaning"]["reformat"].
283 |
284 | Returns:
285 | Modifies the input data in-place, and returns None.
286 | """
287 |
288 | def find_text_between_stars(text):
289 | match = re.search(r"\*([^*]+)\*", text)
290 | return match.group(1) if match else None
291 |
292 | for d in data:
293 | if "? (" in d["title"]:
294 |             full_prompt = string_utils.get_prompt(
295 |                 prompt[0],
296 |                 prompt[1],
297 |                 question=d["title"],
298 |             )
299 |             response = model_eval.get_response_from_model(
300 |                 model_name=model_name, prompt=full_prompt
301 | )
302 | transformed_title = find_text_between_stars(response)
303 | if transformed_title:
304 | d["title"] = transformed_title
305 |
306 | return None
307 |
--------------------------------------------------------------------------------
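get_formatted_data above returns parallel lists (one entry per question for each key), and format_single_question re-assembles one index into the per-question dict consumed by the forecasting loop shown earlier. A hedged sketch of that round trip on an invented one-question structure, again assuming llm_forecasting/ is on the import path:

    # Hedged sketch; every value below is invented for illustration.
    from utils import data_utils

    question_dict = {
        "question_list": ["Will X happen by 2025?"],
        "background_list": ["Background text."],
        "resolution_criteria_list": ["Resolves YES if ..."],
        "answer_list": [1],
        "question_dates_list": [("2024-01-01", "2024-12-31")],
        "retrieval_dates_list": [("2024-01-01", "2024-06-30")],
        "data_source_list": ["metaculus"],
        "resolve_dates_list": ["2025-01-05"],
        "community_pred_at_retrieval_list": [0.62],
        "urls_in_background_list": [[]],
        "category_list": ["Politics"],
    }
    single_question = data_utils.format_single_question(question_dict, index=0)
    print(single_question["question"], single_question["answer"])
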
/scripts/data_scraping/manifold.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import certifi
3 | import concurrent.futures
4 | import logging
5 | import requests
6 | from datetime import datetime, timedelta
7 | from tqdm import tqdm
8 |
9 | import data_scraping
10 | import information_retrieval
11 | from config.keys import keys
12 | from utils import time_utils
13 |
14 | # Logger configuration
15 | logger = logging.getLogger(__name__)
16 |
17 | # Manifold API configuration
18 | MANIFOLD_API = keys["MANIFOLD_KEY"]
19 | BASE_URL = "https://api.manifold.markets/v0/markets?"
20 |
21 | # Headers for API requests
22 | headers = {"Authorization": f"Key {MANIFOLD_API}"}
23 |
24 |
25 | def fetch_all_manifold_questions(base_url, headers, limit=1000):
26 | """
27 | Fetch all questions from the Manifold API.
28 |
29 | Retrieve a list of questions from the Manifold API, paginating through the results
30 | based on the specified limit until all questions are fetched.
31 |
32 | Args:
33 | base_url (str): The base URL for the Manifold API.
34 | headers (dict): The headers for the API request, including authorization.
35 | limit (int): The maximum number of questions to fetch in each request.
36 |
37 | Returns:
38 | list: A list of all questions fetched from the API.
39 | """
40 | all_questions = []
41 | last_id = None
42 |
43 | while True:
44 | params = {"limit": limit, "sort": "created-time", "order": "desc"}
45 | if last_id:
46 | params["before"] = last_id
47 |
48 | response = requests.get(
49 | base_url, headers=headers, params=params, verify=certifi.where()
50 | )
51 | if response.status_code != 200:
52 | print(f"Error: {response.status_code}")
53 | break
54 |
55 | data = response.json()
56 | if not data:
57 | break
58 |
59 | all_questions.extend(data)
60 | last_id = data[-1]["id"]
61 |
62 | return all_questions
63 |
64 |
65 | def fetch_market_details(market_id, headers):
66 | """
67 | Fetch detailed information for a specific market from the Manifold API.
68 |
69 | Retrieve detailed information about a specific market, identified by its market ID,
70 | from the Manifold API.
71 |
72 | Args:
73 | market_id (str): The unique identifier of the market.
74 | headers (dict): Headers for the API request, including authorization.
75 |
76 | Returns:
77 | dict or None: The detailed market information or None if an error occurs.
78 | """
79 | try:
80 | url = f"https://api.manifold.markets/v0/market/{market_id}"
81 | response = requests.get(url, headers=headers, verify=certifi.where())
82 | response.raise_for_status() # Raises an HTTPError for non-200 responses
83 | return response.json()
84 | except requests.exceptions.HTTPError as e:
85 | print(f"HTTP error fetching market details for ID {market_id}: {e}")
86 | except requests.exceptions.RequestException as e:
87 | print(f"Error fetching market details for ID {market_id}: {e}")
88 | return None
89 |
90 |
91 | def fetch_bets_for_market(market_id, headers):
92 | """
93 | Fetch a list of bets for a specific market from the Manifold API.
94 |
95 | Retrieve all bets placed in a specific market, identified by its market ID, from
96 | the Manifold API. Handle any HTTP errors encountered during the request and return
97 | an empty list if an error occurs.
98 |
99 | Args:
100 | market_id (str): The unique identifier of the market.
101 | headers (dict): Headers for the API request, including authorization.
102 |
103 | Returns:
104 | list: A list of bets for the specified market or an empty list if an error occurs.
105 | """
106 | try:
107 | url = "https://api.manifold.markets/v0/bets"
108 | params = {"contractId": market_id, "limit": 1000}
109 | response = requests.get(
110 | url, headers=headers, params=params, verify=certifi.where()
111 | )
112 | response.raise_for_status()
113 | return response.json()
114 | except requests.exceptions.HTTPError as e:
115 | print(f"HTTP error fetching bets for market ID {market_id}: {e}")
116 | except requests.exceptions.RequestException as e:
117 | print(f"Error fetching bets for market ID {market_id}: {e}")
118 | return []
119 |
120 |
121 | def fetch_comments_for_market(market_id, headers):
122 | """
123 | Fetch a list of comments for a specific market from the Manifold API.
124 |
125 | Retrieve all comments made in a specific market, identified by its market ID, from
126 | the Manifold API. Handle any HTTP errors encountered during the request and return
127 | an empty list if an error occurs.
128 |
129 | Args:
130 | market_id (str): The unique identifier of the market.
131 | headers (dict): Headers for the API request, including authorization.
132 |
133 | Returns:
134 | list: A list of comments for the specified market or an empty list if an error occurs.
135 | """
136 | try:
137 | url = "https://api.manifold.markets/v0/comments"
138 | params = {"contractId": market_id, "limit": 1000}
139 | response = requests.get(
140 | url, headers=headers, params=params, verify=certifi.where()
141 | )
142 | response.raise_for_status()
143 | return response.json()
144 | except requests.exceptions.HTTPError as e:
145 | print(f"HTTP error fetching comments for market ID {market_id}: {e}")
146 | except requests.exceptions.RequestException as e:
147 | print(f"Error fetching comments for market ID {market_id}: {e}")
148 | return []
149 |
150 |
151 | def process_market(market, headers):
152 | """
153 | Process a single market by fetching its details, bets, comments, and transforming the data.
154 |
155 | Fetch and add market descriptions, and map bets and comments for a given market.
156 | Convert timestamps and restructure market data for consistency and clarity.
157 |
158 | Args:
159 | market (dict): The market data to process.
160 | headers (dict): Headers for the API requests.
161 |
162 | Returns:
163 | dict: The processed market data.
164 | """
165 | market_id = market["id"]
166 |
167 | # Fetch and add market descriptions
168 | try:
169 | market_details = fetch_market_details(market_id, headers)
170 | if market_details:
171 | market["background"] = market_details.get("description")
172 |
173 | background_content = market.get("background", {}).get("content", [])
174 | background_text = " ".join(
175 | [
176 | item["content"][0].get("text", "")
177 | for item in background_content
178 | if item.get("type") == "paragraph" and item.get("content")
179 | ]
180 | )
181 | market["background"] = background_text
182 | except Exception as exc:
183 | logger.error(
184 | f"Market id {market_id} got processing error when fetching description: {exc}"
185 | )
186 | try:
187 | market["resolution"] = 1 if market["resolution"] == "YES" else 0
188 | except Exception as exc:
189 | logger.error(
190 | f"Market id {market_id} got processing error when reformatting resolution: {exc}"
191 | )
192 |
193 | # Fetch and map bets and comments
194 | market["community_predictions"] = fetch_bets_for_market(market_id, headers)
195 | market["comments"] = fetch_comments_for_market(market_id, headers)
196 |
197 | for comment in market.get("comments", []):
198 | if "createdTime" in comment:
199 | comment["createdTime"] = time_utils.convert_timestamp(
200 | comment["createdTime"]
201 | )
202 |
203 | for bet in market.get("community_predictions", []):
204 | if "createdTime" in bet:
205 | bet["createdTime"] = time_utils.convert_timestamp(bet["createdTime"])
206 |
207 | try:
208 | market["date_close"] = market.pop("closeTime")
209 | except Exception as exc:
210 | logger.error(
211 | f"Market id {market_id} got processing error when popping closeTime: {exc}"
212 | )
213 | try:
214 | market["date_begin"] = market.pop("createdTime")
215 | except Exception as exc:
216 | logger.error(
217 | f"Market id {market_id} got processing error when popping createdTime: {exc}"
218 | )
219 | try:
220 | market["is_resolved"] = market.pop("isResolved")
221 | except Exception as exc:
222 | logger.error(
223 | f"Market id {market_id} got processing error when popping isResolved: {exc}"
224 | )
225 | try:
226 | market["question_type"] = market.pop("outcomeType")
227 | except Exception as exc:
228 | logger.error(
229 | f"Market id {market_id} got processing error when popping outcomeType: {exc}"
230 | )
231 |
232 | try:
233 | market["resolved_time"] = market.pop("resolutionTime")
234 | except Exception as exc:
235 | logger.error(
236 | f"Market id {market_id} got processing error when reformatting resolution time: {exc}"
237 | )
238 |
239 | if not market.get("background"):
240 | market["background"] = "Not applicable/available for this question."
241 |
242 | market["resolution_criteria"] = "Not applicable/available for this question."
243 | market["data_source"] = "manifold"
244 |
245 | return market
246 |
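  | # After process_market, each market dict carries (roughly) these fields; the
  | # values shown are placeholders:
  | #   {"id": ..., "question_type": ..., "resolution": 0 or 1, "is_resolved": ...,
  | #    "date_begin": ..., "date_close": ..., "resolved_time": ...,
  | #    "background": "...", "resolution_criteria": "...", "data_source": "manifold",
  | #    "community_predictions": [...], "comments": [...]}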
247 |
248 | def main(n_days):
249 | """
250 | Process and upload Manifold market data.
251 |
252 | Fetch market data from Manifold, process it, and upload the processed data
253 | to an AWS S3 bucket.
254 | Limit the fetched data to markets created within the last N days, if specified.
255 |
256 | Args:
257 | n_days (int or None): Number of days to look back for markets.
258 | If None, fetches all markets.
259 |
260 | Returns:
261 | list: A list of processed market data.
262 | """
263 | logger.info("Starting the manifold script...")
264 |
265 | all_markets = fetch_all_manifold_questions(BASE_URL, headers)
266 |
267 | # Transform time format
268 | ids_with_date_errors = []
269 | for key in [
270 | "createdTime",
271 | "closeTime",
272 | "resolutionTime",
273 | "lastUpdatedTime",
274 | "lastCommentTime",
275 | ]:
276 | for market in all_markets:
277 | if key in market:
278 | try:
279 | market[key] = time_utils.convert_timestamp(market[key])
280 | except BaseException:
281 | ids_with_date_errors.append(market["id"])
282 | continue
283 |
284 | logger.info(
285 | f"Number of manifold questions with date errors: {len(set(ids_with_date_errors))}"
286 | )
287 |
288 | if n_days is not None:
289 | date_limit = datetime.now() - timedelta(days=n_days)
290 | all_markets = [
291 | q
292 | for q in all_markets
293 | if datetime.fromisoformat(q["createdTime"]) >= date_limit
294 | ]
295 |
296 | logger.info(f"Number of manifold questions fetched: {len(all_markets)}")
297 |
298 | processed_markets = []
299 | number_of_workers = 50
300 |
301 | with concurrent.futures.ThreadPoolExecutor(
302 | max_workers=number_of_workers
303 | ) as executor:
304 | # Create a list to hold the futures
305 | futures = [
306 | executor.submit(process_market, market, headers) for market in all_markets
307 | ]
308 |
309 | # Use tqdm to create a progress bar. Wrap futures with tqdm.
310 | for future in tqdm(
311 | concurrent.futures.as_completed(futures),
312 | total=len(all_markets),
313 | desc="Processing",
314 | ):
315 | result = future.result()
316 | processed_markets.append(result)
317 |
318 | processed_markets = [
319 | q
320 | for q in processed_markets
321 | if "question_type" in q.keys() and q["community_predictions"]
322 | ]
323 |
324 | # Save intermediate data files
325 | logger.info("Uploading the intermediate data files to s3...")
326 | question_types = list(set([q["question_type"] for q in processed_markets]))
327 | data_scraping.upload_scraped_data(processed_markets, "manifold", question_types)
328 |
329 | for q in processed_markets:
330 | q["date_resolve_at"] = q["community_predictions"][0]["createdTime"]
331 |
332 | logger.info("Start extracting articles from links...")
333 |
334 | for question in processed_markets:
335 | question["extracted_urls"] = information_retrieval.get_urls_from_text(
336 | question["background"]
337 | )
338 | if question["comments"]:
339 | for comment in question["comments"]:
340 | try:
341 | comment_text = comment["content"]["content"][0]["content"][-1]["text"]
342 | question["extracted_urls"].extend(
343 | information_retrieval.retrieve_webpage_from_background(
344 | comment_text, question["date_close"]
345 | )
346 | )
347 | except Exception:
348 | continue
349 |
350 | logger.info("Uploading to s3...")
351 | question_types = list(set([q["question_type"] for q in processed_markets]))
352 | data_scraping.upload_scraped_data(processed_markets, "manifold", question_types)
353 |
354 |
355 | if __name__ == "__main__":
356 | parser = argparse.ArgumentParser(description="Fetch Manifold data.")
357 | parser.add_argument(
358 | "--n_days",
359 | type=int,
360 | help="Fetch questions created in the last N days",
361 | default=None,
362 | )
363 | args = parser.parse_args()
364 | main(args.n_days)
365 |
--------------------------------------------------------------------------------
/llm_forecasting/model_eval.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import asyncio
3 | import logging
4 | import time
5 |
6 | # Related third-party imports
7 | import openai
8 | import together
9 | import anthropic
10 | import google.generativeai as google_ai
11 |
12 | # Local application/library-specific imports
13 | from config.constants import (
14 | OAI_SOURCE,
15 | ANTHROPIC_SOURCE,
16 | TOGETHER_AI_SOURCE,
17 | GOOGLE_SOURCE,
18 | )
19 | from config.keys import (
20 | ANTHROPIC_KEY,
21 | OPENAI_KEY,
22 | TOGETHER_KEY,
23 | GOOGLE_AI_KEY,
24 | )
25 | from utils import model_utils, string_utils
26 |
27 | # Setup code
28 | if ANTHROPIC_KEY:
29 | anthropic_console = anthropic.Anthropic(api_key=ANTHROPIC_KEY)
30 | anthropic_async_client = anthropic.AsyncAnthropic(api_key=ANTHROPIC_KEY)
31 |
32 | if OPENAI_KEY:
33 | oai_async_client = openai.AsyncOpenAI(api_key=OPENAI_KEY)
34 | oai = openai.OpenAI(api_key=OPENAI_KEY)
35 |
36 | if TOGETHER_KEY:
37 | together.api_key = TOGETHER_KEY
38 | client = openai.OpenAI(
39 | api_key=TOGETHER_KEY,
40 | base_url="https://api.together.xyz/v1",
41 | )
42 |
43 | if GOOGLE_AI_KEY:
44 | google_ai.configure(api_key=GOOGLE_AI_KEY)
45 |
46 | # Set up logging
47 | logging.basicConfig(level=logging.INFO)
48 | logger = logging.getLogger(__name__)
49 |
50 |
51 | def get_response_with_retry(api_call, wait_time, error_msg):
52 | """
53 | Make an API call and retry on failure after a specified wait time.
54 |
55 | Args:
56 | api_call (function): API call to make.
57 | wait_time (int): Time to wait before retrying, in seconds.
58 | error_msg (str): Error message to print on failure.
59 | """
60 | while True:
61 | try:
62 | return api_call()
63 | except Exception as e:
64 | logger.info(f"{error_msg}: {e}")
65 | logger.info(f"Waiting for {wait_time} seconds before retrying...")
66 | time.sleep(wait_time)
67 |
68 |
69 | def get_response_from_oai_model(
70 | model_name, prompt, system_prompt, max_tokens, temperature, wait_time
71 | ):
72 | """
73 | Make an API call to the OpenAI API and retry on failure after a specified
74 | wait time.
75 |
76 | Args:
77 | model_name (str): Name of the model to use (such as "gpt-4").
78 | prompt (str): Fully specified prompt to use for the API call.
79 | system_prompt (str): Prompt to use as the system prompt.
80 | max_tokens (int): Maximum number of tokens to sample.
81 | temperature (float): Sampling temperature.
82 | wait_time (int): Time to wait before retrying, in seconds.
83 |
84 | Returns:
85 | str: Response string from the API call.
86 | """
87 |
88 | def api_call():
89 | """
90 | Make an API call to the OpenAI API, without retrying on failure.
91 |
92 | Returns:
93 | str: Response string from the API call.
94 | """
95 | model_input = (
96 | [{"role": "system", "content": system_prompt}] if system_prompt else []
97 | )
98 | model_input.append({"role": "user", "content": prompt})
99 | response = oai.chat.completions.create(
100 | model=model_name,
101 | messages=model_input,
102 | max_tokens=max_tokens,
103 | temperature=temperature,
104 | )
105 | # logger.info(f"full prompt: {prompt}")
106 | return response.choices[0].message.content
107 |
108 | return get_response_with_retry(
109 | api_call, wait_time, "OpenAI API request exceeded rate limit."
110 | )
111 |
112 |
113 | def get_response_from_anthropic_model(
114 | model_name, prompt, max_tokens, temperature, wait_time
115 | ):
116 | """
117 | Make an API call to the Anthropic API and retry on failure after a
118 | specified wait time.
119 |
120 | Args:
121 | model_name (str): Name of the model to use (such as "claude-2").
122 | prompt (str): Fully specified prompt to use for the API call.
123 | max_tokens (int): Maximum number of tokens to sample.
124 | temperature (float): Sampling temperature.
125 | wait_time (int): Time to wait before retrying, in seconds.
126 |
127 | Returns:
128 | str: Response string from the API call.
129 | """
130 | if max_tokens > 4096:
131 | max_tokens = 4096
132 |
133 | def api_call():
134 | completion = anthropic_console.messages.create(
135 | model=model_name,
136 | messages=[{"role": "user", "content": prompt}],
137 | temperature=temperature,
138 | max_tokens=max_tokens,
139 | )
140 | return completion.content[0].text
141 |
142 | return get_response_with_retry(
143 | api_call, wait_time, "Anthropic API request exceeded rate limit."
144 | )
145 |
146 |
147 | def get_response_from_together_ai_model(
148 | model_name, prompt, max_tokens, temperature, wait_time
149 | ):
150 | """
151 | Make an API call to the Together AI API and retry on failure after a
152 | specified wait time.
153 |
154 | Args:
155 | model_name (str): Name of the model to use (such as "togethercomputer/
156 | llama-2-13b-chat").
157 | prompt (str): Fully specified prompt to use for the API call.
158 | max_tokens (int): Maximum number of tokens to sample.
159 | temperature (float): Sampling temperature.
160 | wait_time (int): Time to wait before retrying, in seconds.
161 |
162 | Returns:
163 | str: Response string from the API call.
164 | """
165 |
166 | def api_call():
167 | chat_completion = client.chat.completions.create(
168 | model=model_name,
169 | messages=[
170 | {"role": "user", "content": prompt},
171 | ],
172 | temperature=temperature,
173 | max_tokens=max_tokens,
174 | )
175 | response = chat_completion.choices[0].message.content
176 |
177 | return response
178 |
179 | return get_response_with_retry(
180 | api_call, wait_time, "Together AI API request exceeded rate limit."
181 | )
182 |
183 |
184 | def get_response_from_google_model(
185 | model_name, prompt, max_tokens, temperature, wait_time
186 | ):
187 | """
188 | Make an API call to the Google AI API (this helper does not retry on failure).
189 |
190 | Args:
191 | model_name (str): Name of the model to use (such as "gemini-pro").
192 | prompt (str): Initial prompt for the API call.
193 | max_tokens (int): Maximum number of tokens to sample.
194 | temperature (float): Sampling temperature.
195 | wait_time (int): Unused; accepted for a consistent interface with the other callers.
196 |
197 | Returns:
198 | str: Response string from the API call.
199 | """
200 | model = google_ai.GenerativeModel(model_name)
201 |
202 | response = model.generate_content(
203 | prompt,
204 | generation_config=google_ai.types.GenerationConfig(
205 | candidate_count=1,
206 | max_output_tokens=max_tokens,
207 | temperature=temperature,
208 | ),
209 | )
210 | return response.text
211 |
212 |
213 | def get_response_from_model(
214 | model_name,
215 | prompt,
216 | system_prompt="",
217 | max_tokens=2000,
218 | temperature=0.8,
219 | wait_time=30,
220 | ):
221 | """
222 | Make an API call to the specified model and retry on failure after a
223 | specified wait time.
224 |
225 | Args:
226 | model_name (str): Name of the model to use (such as "gpt-4").
227 | prompt (str): Fully specified prompt to use for the API call.
228 | system_prompt (str, optional): Prompt to use as the system prompt.
229 | max_tokens (int, optional): Maximum number of tokens to generate.
230 | temperature (float, optional): Sampling temperature.
231 | wait_time (int, optional): Time to wait before retrying, in seconds.
232 | """
233 | model_source = model_utils.infer_model_source(model_name)
234 | if model_source == OAI_SOURCE:
235 | return get_response_from_oai_model(
236 | model_name, prompt, system_prompt, max_tokens, temperature, wait_time
237 | )
238 | elif model_source == ANTHROPIC_SOURCE:
239 | return get_response_from_anthropic_model(
240 | model_name, prompt, max_tokens, temperature, wait_time
241 | )
242 | elif model_source == TOGETHER_AI_SOURCE:
243 | return get_response_from_together_ai_model(
244 | model_name, prompt, max_tokens, temperature, wait_time
245 | )
246 | elif model_source == GOOGLE_SOURCE:
247 | return get_response_from_google_model(
248 | model_name, prompt, max_tokens, temperature, wait_time
249 | )
250 | else:
251 | return "Not a valid model source."
252 |
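  | # Usage sketch (the model name is illustrative; any name that
  | # model_utils.infer_model_source recognizes is routed the same way):
  | #   answer = get_response_from_model(
  | #       "gpt-4-1106-preview",
  | #       prompt="Will event X happen by June? Explain your reasoning.",
  | #       temperature=0.2,
  | #   )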
253 |
254 | async def get_async_response(
255 | prompt,
256 | model_name="gpt-3.5-turbo-1106",
257 | temperature=0.0,
258 | max_tokens=8000,
259 | ):
260 | """
261 | Asynchronously get a response from the OpenAI API.
262 |
263 | Args:
264 | prompt (str): Fully specified prompt to use for the API call.
265 | model_name (str, optional): Name of the model to use (such as "gpt-3.5-turbo").
266 | temperature (float, optional): Sampling temperature.
267 | max_tokens (int, optional): Maximum number of tokens to sample.
268 |
269 | Returns:
270 | str: Response string from the API call (not the dictionary).
271 | """
272 | model_source = model_utils.infer_model_source(model_name)
273 | while True:
274 | try:
275 | if model_source == OAI_SOURCE:
276 | response = await oai_async_client.chat.completions.create(
277 | model=model_name,
278 | messages=[{"role": "user", "content": prompt}],
279 | temperature=temperature,
280 | )
281 | return response.choices[0].message.content
282 | elif model_source == ANTHROPIC_SOURCE:
283 | response = await anthropic_async_client.messages.create(
284 | model=model_name,
285 | messages=[{"role": "user", "content": prompt}],
286 | temperature=temperature,
287 | max_tokens=4096,
288 | )
289 | return response.content[0].text
290 | elif model_source == GOOGLE_SOURCE:
291 | model = google_ai.GenerativeModel(model_name)
292 | response = await model.generate_content_async(
293 | prompt,
294 | generation_config=google_ai.types.GenerationConfig(
295 | candidate_count=1,
296 | max_output_tokens=max_tokens,
297 | temperature=temperature,
298 | ),
299 | )
300 | return response.text
301 | elif model_source == TOGETHER_AI_SOURCE:
302 | chat_completion = await asyncio.to_thread(
303 | client.chat.completions.create,
304 | model=model_name,
305 | messages=[
306 | {"role": "user", "content": prompt},
307 | ],
308 | temperature=temperature,
309 | max_tokens=max_tokens,
310 | )
311 | return chat_completion.choices[0].message.content
312 | else:
313 | logger.debug(f"Not a valid model source: {model_source}")
314 | return ""
315 | except (Exception, BaseException) as e:
316 | logger.info(f"Exception, erorr message: {e}")
317 | logger.info("Waiting for 30 seconds before retrying...")
318 | time.sleep(30)
319 | continue
320 |
321 |
322 | def get_openai_embedding(texts, model="text-embedding-3-large"):
323 | """
324 | Query OpenAI's text embedding model to get the embedding of the given text.
325 |
326 | Args:
327 | texts (list of str): List of texts to embed.
328 |
329 | Returns:
330 | list of Embedding objects: List of embeddings, where embedding[i].embedding is a list of floats.
331 | """
332 | texts = [text.replace("\n", " ") for text in texts]
333 | while True:
334 | try:
335 | embedding = oai.embeddings.create(input=texts, model=model)
336 | return embedding.data
337 | except Exception as e:
338 | logger.info(f"erorr message: {e}")
339 | logger.info("Waiting for 30 seconds before retrying...")
340 | time.sleep(30)
341 | continue
342 |
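  | # Usage sketch: each returned item exposes .embedding, a list of floats.
  | #   embeddings = get_openai_embedding(["first question text", "second question text"])
  | #   vector = embeddings[0].embedding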
343 |
344 | async def async_make_forecast(
345 | question,
346 | background_info,
347 | resolution_criteria,
348 | dates,
349 | retrieved_info,
350 | reasoning_prompt_templates,
351 | model_name="gpt-4-1106-preview",
352 | temperature=1.0,
353 | return_prompt=False,
354 | ):
355 | """
356 | Asynchronously make forecasts using the given information.
357 |
358 | Args:
359 | question (str): Question to ask the model.
360 | background_info (str): Background information to provide to the model.
361 | resolution_criteria (str): Resolution criteria to provide to the model.
362 | dates (str): Dates to provide to the model.
363 | retrieved_info (str): Retrieved information to provide to the model.
364 | reasoning_prompt_templates (list of str): List of reasoning prompt templates to use.
365 | model_name (str, optional): Name of the model to use (such as "gpt-4-1106-preview").
366 | temperature (float, optional): Sampling temperature.
367 | return_prompt (bool, optional): Whether to return the full prompt or not.
368 |
369 | Returns:
370 | list of str: List of forecasts and reasonings from the model.
371 | """
372 | assert (
373 | len(reasoning_prompt_templates) > 0
374 | ), "No reasoning prompt templates provided."
375 | reasoning_full_prompts = []
376 | for reasoning_prompt_template in reasoning_prompt_templates:
377 | template, fields = reasoning_prompt_template
378 | reasoning_full_prompts.append(
379 | string_utils.get_prompt(
380 | template,
381 | fields,
382 | question=question,
383 | retrieved_info=retrieved_info,
384 | background=background_info,
385 | resolution_criteria=resolution_criteria,
386 | dates=dates,
387 | )
388 | )
389 | # Get all reasonings from the model
390 | reasoning_tasks = [
391 | get_async_response(
392 | prompt,
393 | model_name=model_name,
394 | temperature=temperature,
395 | )
396 | for prompt in reasoning_full_prompts
397 | ]
398 | # a list of strings
399 | all_reasonings = await asyncio.gather(*reasoning_tasks)
400 | logger.info(
401 | "Finished {} base reasonings generated by {}".format(
402 | len(reasoning_full_prompts), model_name
403 | )
404 | )
405 | if return_prompt:
406 | return all_reasonings, reasoning_full_prompts
407 | return all_reasonings
408 |
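  | # Usage sketch (the prompt key is illustrative; any (template, fields) pairs
  | # from prompts.prompts.PROMPT_DICT can be passed):
  | #   reasonings = asyncio.run(
  | #       async_make_forecast(
  | #           question="Will X happen by June 30?",
  | #           background_info="...",
  | #           resolution_criteria="...",
  | #           dates="2024-01-01 to 2024-06-30",
  | #           retrieved_info="(concatenated article titles and summaries)",
  | #           reasoning_prompt_templates=[PROMPT_DICT["binary"]["scratch_pad"]["1"]],
  | #       )
  | #   )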
--------------------------------------------------------------------------------
/llm_forecasting/ensemble.py:
--------------------------------------------------------------------------------
1 | # Standard library imports
2 | import logging
3 |
4 | # Related third-party imports
5 | import numpy as np
6 |
7 | # Local application/library-specific imports
8 | from config.constants import TOKENS_TO_PROBS_DICT
9 | import model_eval
10 | from prompts.prompts import PROMPT_DICT
11 | from utils import string_utils, utils
12 |
13 | # Set up logging
14 | logging.basicConfig(level=logging.INFO)
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | def concatenate_reasonings(reasonings):
19 | """
20 | Concatenate a list of reasonings into a single string.
21 |
22 | Each reasoning is numbered ("Response from forecaster 1", 2, 3, ...); the
23 | block is wrapped in "---" separators, with a "-" between responses.
24 |
25 | Args:
26 | reasonings (list[str]): A list of reasonings.
27 |
28 | Returns:
29 | str: A single string containing all reasonings.
30 | """
31 | concat_reasonings = []
32 | for i, reasoning in enumerate(reasonings):
33 | reason_str = f"Response from forecaster {i + 1}:\n{reasoning}"
34 | concat_reasonings.append(reason_str)
35 | return "---\n" + "\n\n-\n".join(concat_reasonings) + "\n---"
36 |
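  | # For example, concatenate_reasonings(["Reasoning A", "Reasoning B"]) returns:
  | #   ---
  | #   Response from forecaster 1:
  | #   Reasoning A
  | #
  | #   -
  | #   Response from forecaster 2:
  | #   Reasoning B
  | #   ---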
37 |
38 | async def meta_reason(
39 | question,
40 | background_info,
41 | resolution_criteria,
42 | today_to_close_date_range,
43 | retrieved_info,
44 | reasoning_prompt_templates,
45 | base_model_names=["gpt-4-1106-preview", "claude-2.1"],
46 | base_temperature=1.0, # temperature for the base reasonings
47 | aggregation_method="meta",
48 | answer_type="probability",
49 | weights=None,
50 | end_words=list(TOKENS_TO_PROBS_DICT["ten_options"].keys()),
51 | meta_model_name="gpt-4-1106-preview",
52 | meta_prompt_template=PROMPT_DICT["meta_reasoning"]["0"],
53 | meta_temperature=0.2,
54 | ):
55 | """
56 | Given a question and its retrieved articles, elicit model reasonings via
57 | reasoning_prompts, aggregate the reasonings and return the answer.
58 |
59 | Args:
60 | question (str): Forecast question to be answered.
61 | background_info (str): Background information of the question.
62 | resolution_criteria (str): Resolution criteria for the question.
63 | retrieved_info (str): Retrieved articles from our news retrieval system
64 | (a concatenation of the article titles and summaries).
65 | today_to_close_date_range (str): A string containing today's date
66 | and the close date.
67 | reasoning_prompt_templates (list[list[str]]): A list of reasoning
68 | prompts; string templates that must have two fields, {question}
69 | and {retrieved_info}. There should be one list of reasoning
70 | prompts for each base model.
71 | base_model_names (list[str], optional): A list of base model names.
72 | base_temperature (float, optional): Sampling temperature for the base reasonings.
73 | aggregation_method (str, optional): The method to aggregate the reasonings.
74 | Must be one of 'vote-or-median', 'mean', 'weighted-mean', or 'meta'.
75 | answer_type (str, optional): The type of the answer to return. Must be either 'probability' or 'tokens'.
76 | weights (np.array, optional): A numpy array of weights for the reasonings.
77 | It will only be used if aggregation_method is 'weighted-mean'.
78 | It must have the same length as reasoning_prompt_templates: shape[0] ==
79 | len(reasoning_prompt_templates).
80 | end_words (list, optional): A list of words like "Very Unlikely" and "Very Likely" that represent the answer.
81 | It will only be used if answer_type is 'tokens' and aggregation_method
82 | is 'vote-or-median'.
83 | meta_model_name (str, optional): The name of the meta model.
84 | meta_prompt_template (tuple, optional): A meta-reasoning prompt template, given as a (template, fields) pair.
85 | meta_temperature (float, optional): Sampling temperature for the meta-reasoning.
86 |
87 | Returns:
88 | tuple: The method returns the final answer, all base reasonings, and the meta-reasoning
89 | (if aggregation_method is 'meta').
90 |
91 | For the final answer:
92 | If answer_type is 'probability' and aggregation_method is 'vote-or-median',
93 | the function returns the median of all answers.
94 | If answer_type is 'tokens' and aggregation_method is 'vote-or-median',
95 | the function returns the most frequent answer.
96 | If the aggregation_method is 'meta', the function returns an answer
97 | by eliciting another meta-reasoning using the meta_prompt_template.
98 | """
99 | assert answer_type in [
100 | "probability",
101 | "tokens",
102 | ], "answer_type must be either 'probability' or 'tokens'"
103 | assert aggregation_method in [
104 | "vote-or-median",
105 | "meta",
106 | "mean",
107 | "weighted-mean",
108 | ], "aggregation_method must be either 'vote-or-median', 'meta', 'mean', or 'weighted-mean'"
109 | if aggregation_method == "weighted-mean":
110 | assert (
111 | weights is not None
112 | ), "weights must be provided if aggregation_method is 'weighted-mean'"
113 | assert weights.shape[0] == len(
114 | reasoning_prompt_templates
115 | ), "weights must have the same length as reasoning_prompt_templates"
116 | all_base_reasonings = []
117 | all_base_reasoning_full_prompts = []
118 | for i, base_model_name in enumerate(base_model_names):
119 | (
120 | base_reasonings,
121 | base_reasoning_full_prompts,
122 | ) = await model_eval.async_make_forecast(
123 | question=question,
124 | background_info=background_info,
125 | resolution_criteria=resolution_criteria,
126 | dates=today_to_close_date_range,
127 | retrieved_info=retrieved_info,
128 | reasoning_prompt_templates=reasoning_prompt_templates[i],
129 | model_name=base_model_name,
130 | temperature=base_temperature,
131 | return_prompt=True,
132 | )
133 | # list of lists (not flattened)
134 | all_base_reasonings.append(base_reasonings)
135 | all_base_reasoning_full_prompts.append(base_reasoning_full_prompts)
136 | aggregation_dict = aggregate_base_reasonings(
137 | base_reasonings=all_base_reasonings,
138 | question=question,
139 | background_info=background_info,
140 | today_to_close_date_range=today_to_close_date_range,
141 | resolution_criteria=resolution_criteria,
142 | retrieved_info=retrieved_info,
143 | aggregation_method=aggregation_method,
144 | answer_type=answer_type,
145 | weights=weights,
146 | end_words=end_words,
147 | model_name=meta_model_name, # meta model name
148 | meta_prompt_template=meta_prompt_template,
149 | meta_temperature=meta_temperature,
150 | )
151 | aggregation_dict["base_reasoning_full_prompts"] = all_base_reasoning_full_prompts
152 | return aggregation_dict
153 |
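  | # Usage sketch (prompt keys and model names are illustrative; meta_reason is a
  | # coroutine, so it needs an event loop such as asyncio.run):
  | #   result = asyncio.run(meta_reason(
  | #       question="Will X happen by June 30?",
  | #       background_info="...",
  | #       resolution_criteria="...",
  | #       today_to_close_date_range="2024-01-01 to 2024-06-30",
  | #       retrieved_info="(concatenated article titles and summaries)",
  | #       reasoning_prompt_templates=[
  | #           [PROMPT_DICT["binary"]["scratch_pad"]["1"]],  # prompts for base model 1
  | #           [PROMPT_DICT["binary"]["scratch_pad"]["1"]],  # prompts for base model 2
  | #       ],
  | #       base_model_names=["gpt-4-1106-preview", "claude-2.1"],
  | #   ))
  | #   print(result["meta_prediction"])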
154 |
155 | def aggregate_base_reasonings(
156 | base_reasonings,
157 | question,
158 | background_info,
159 | today_to_close_date_range,
160 | resolution_criteria,
161 | retrieved_info,
162 | aggregation_method="meta",
163 | answer_type="probability",
164 | weights=None,
165 | end_words=list(TOKENS_TO_PROBS_DICT["ten_options"].keys()),
166 | model_name="gpt-4-1106-preview", # meta model name
167 | meta_prompt_template=PROMPT_DICT["meta_reasoning"]["0"],
168 | meta_temperature=0.2,
169 | ):
170 | """
171 | Aggregate a list of lists of base reasonings via ensembling method.
172 |
173 | Args:
174 | base_reasonings (list[list[str]]): A list of lists of base reasonings.
175 | question (str): Forecast question to be answered.
176 | background_info (str): Background information of the question.
177 | today_to_close_date_range (str): A string containing today's date and the close date.
178 | resolution_criteria (str): Resolution criteria for the question.
179 | retrieved_info (str): Retrieved articles from our news retrieval system
180 | (a concatenation of the article titles and summaries).
181 | aggregation_method (str, optional): The method to aggregate the reasonings.
182 | Must be one of 'vote-or-median', 'mean', 'weighted-mean', or 'meta'.
183 | answer_type (str, optional): The type of the answer to return. Must be either 'probability' or 'tokens'.
184 | weights (np.array, optional): A numpy array of weights for the reasonings.
185 | It will only be used if aggregation_method is 'weighted-mean'.
186 | It must have the same length as the flattened list of base
187 | predictions (one weight per base reasoning).
188 | end_words (list, optional): A list of words like "Very Unlikely" and "Very Likely" that represent the answer.
189 | It will only be used if answer_type is 'tokens' and aggregation_method
190 | is 'vote-or-median'.
191 | model_name (str, optional): The name of the meta model.
192 | meta_prompt_template (tuple, optional): A (template, fields) pair representing the meta-reasoning prompt.
193 | meta_temperature (float, optional): Sampling temperature for the meta-reasoning.
194 |
195 | Returns:
196 | dict: A dictionary containing the final answer, all base reasonings, and the meta-reasoning
197 | (if aggregation_method is 'meta').
198 | """
199 | assert len(base_reasonings) > 0, "base_reasonings must be a non-empty list"
200 | # Extract final prediction from each reasoning
201 | all_base_predictions = [] # list of lists of floats
202 | for base_reasonings_list in base_reasonings:
203 | base_predictions = [ # for one model; list of floats
204 | string_utils.extract_prediction(
205 | reasoning, answer_type=answer_type, end_words=end_words
206 | )
207 | for reasoning in base_reasonings_list
208 | ]
209 | all_base_predictions.append(base_predictions)
210 | flattened_all_base_predictions = [
211 | item for sublist in all_base_predictions for item in sublist
212 | ]
213 | if len(flattened_all_base_predictions) == 1: # no aggregation needed
214 | return {
215 | "base_reasonings": base_reasonings,
216 | "base_predictions": all_base_predictions,
217 | "meta_prediction": flattened_all_base_predictions[0],
218 | "meta_prompt": None,
219 | "meta_reasoning": None,
220 | }
221 | if answer_type == "probability" and aggregation_method != "meta":
222 | if aggregation_method == "mean" or aggregation_method is None:
223 | meta_prediction = np.mean(flattened_all_base_predictions) # default to mean
224 | if aggregation_method == "vote-or-median":
225 | meta_prediction = np.median(flattened_all_base_predictions)
226 | if aggregation_method == "weighted-mean":
227 | meta_prediction = np.average(
228 | flattened_all_base_predictions, weights=weights
229 | )
230 | if meta_prediction is None or meta_prediction < 0.0 or meta_prediction > 1.0:
231 | logger.debug(
232 | "final_answer {} is not between 0 and 1".format(meta_prediction)
233 | )
234 | meta_prediction = 0.5 # default to 0.5
235 | return {
236 | "base_reasonings": base_reasonings,
237 | "base_predictions": all_base_predictions,
238 | "meta_prediction": meta_prediction,
239 | "meta_prompt": None,
240 | "meta_reasoning": None,
241 | }
242 | elif answer_type == "tokens" and aggregation_method == "vote-or-median":
243 | meta_prediction = utils.most_frequent_item(
244 | flattened_all_base_predictions
245 | ) # majority vote
246 | if meta_prediction is None or not string_utils.is_string_in_list(
247 | meta_prediction, end_words
248 | ):
249 | logger.debug("final_answer {} is not valid".format(meta_prediction))
250 | meta_prediction = "Slightly Unlikely" # default to "Slightly Unlikely"
251 | return {
252 | "base_reasonings": base_reasonings,
253 | "base_predictions": all_base_predictions,
254 | "meta_prediction": meta_prediction,
255 | "meta_prompt": None,
256 | "meta_reasoning": None,
257 | }
258 |
259 | # If aggregation_method is 'meta', elicit a meta-reasoning using the
260 | # meta_prompt_template
261 | prompt, fields = meta_prompt_template
262 | flattened_base_reasonings = [
263 | item for sublist in base_reasonings for item in sublist
264 | ]
265 | meta_full_prompt = string_utils.get_prompt(
266 | prompt,
267 | fields,
268 | question=question,
269 | background=background_info,
270 | dates=today_to_close_date_range,
271 | retrieved_info=retrieved_info,
272 | reasoning=concatenate_reasonings(flattened_base_reasonings),
273 | resolution_criteria=resolution_criteria,
274 | )
275 | meta_reasoning = model_eval.get_response_from_model(
276 | model_name=model_name,
277 | prompt=meta_full_prompt,
278 | temperature=meta_temperature,
279 | ) # raw response
280 | # Extract final prediction from raw response
281 | if answer_type == "probability":
282 | # Get the probability from the meta-reasoning
283 | meta_prediction = string_utils.extract_probability_with_stars(meta_reasoning)
284 | if meta_prediction is None or meta_prediction < 0.0 or meta_prediction > 1.0:
285 | logger.debug(
286 | "final_answer {} is not between 0 and 1".format(meta_prediction)
287 | )
288 | meta_prediction = 0.5
289 | elif answer_type == "tokens":
290 | # Get the final token answer from the meta-reasoning
291 | meta_prediction = string_utils.find_end_word(meta_reasoning, end_words)
292 | if meta_prediction is None or not string_utils.is_string_in_list(
293 | meta_prediction, end_words
294 | ):
295 | logger.debug("final_answer {} is not valid".format(meta_prediction))
296 | meta_prediction = "Slightly Unlikely"
297 |
298 | return {
299 | "base_reasonings": base_reasonings,
300 | "base_predictions": all_base_predictions,
301 | "meta_prompt": meta_full_prompt,
302 | "meta_reasoning": meta_reasoning,
303 | "meta_prediction": meta_prediction,
304 | }
305 |
306 |
307 | def calculate_normalized_weighted_trimmed_mean(predictions):
308 | """
309 | Calculate the normalized weighted trimmed mean of a set of predictions.
310 |
311 | This function performs the following steps:
312 | 1. Compute the median of the predictions to serve as a reference point.
313 | 2. Identify the prediction that is furthest from the median value.
314 | 3. Reduce the weight of the furthest prediction by half, acknowledging its potential outlier status.
315 | 4. Equally distribute the reduced weight from the furthest prediction among the remaining predictions, ensuring the total weight remains constant.
316 | 5. Compute and return the weighted mean of the predictions, using the adjusted weights to account for outliers and variance.
317 |
318 | Args:
319 | predictions (np.ndarray): An array of numerical prediction values.
320 |
321 | Returns:
322 | float: The normalized weighted trimmed mean of the predictions.
323 | """
324 | # Step 1: Find the median
325 | median_prediction = np.median(predictions)
326 |
327 | # Step 2: Determine the prediction farthest from the median
328 | distances = np.abs(predictions - median_prediction)
329 | max_distance = np.max(distances)
330 |
331 | # Step 3: Down-weight the furthest prediction by half
332 | weights = np.ones(len(predictions))
333 | weights[distances == max_distance] *= 0.5
334 |
335 | # Step 4: Distribute the saved weight among other predictions
336 | saved_weight = (1.0 - 0.5) / (len(predictions) - 1)
337 | weights[distances != max_distance] += saved_weight
338 |
339 | # Step 5: Calculate the weighted mean
340 | weighted_mean = np.average(predictions, weights=weights)
341 |
342 | return weighted_mean
343 |
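  | # Worked example (illustrative numbers): for predictions [0.2, 0.5, 0.9], the
  | # median is 0.5 and 0.9 lies furthest from it, so the weights become
  | # [1.25, 1.25, 0.5] and the result is
  | # (0.2*1.25 + 0.5*1.25 + 0.9*0.5) / 3.0 ≈ 0.44, versus a plain mean of ≈ 0.53.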
--------------------------------------------------------------------------------