Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions code_tests/integration_tests/test_ai_models/test_general_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def _all_tests() -> list[ModelTest]:
GeneralLlm(model="openai/gpt-4o"),
[{"role": "user", "content": test_data.get_cheap_user_message()}],
),
ModelTest(
GeneralLlm(model="openai/gpt-4o"),
[
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": test_data.get_cheap_user_message()},
],
),
ModelTest(
GeneralLlm(model="o3-mini", reasoning_effort="low"),
test_data.get_cheap_user_message(),
Expand Down Expand Up @@ -99,6 +106,21 @@ def test_general_llm_instances_run(
assert cost_manager.current_usage > 0, "No cost was incurred"


def test_system_prompt_parameter() -> None:
    """The system_prompt kwarg of GeneralLlm.invoke should steer the model's reply."""
    llm = GeneralLlm(model="gpt-4o")
    # Instruct the model to include two distinctive words we can assert on.
    system_prompt = "If the user says 'Hello, world!', say 'Hello there friend! I am a turkey.' back to them."
    reply = asyncio.run(llm.invoke("Hello, world!", system_prompt=system_prompt))
    assert reply is not None, "Response is None"
    assert reply != "", "Response is an empty string"
    logger.info(f"Response: {reply}")
    lowered = reply.lower()
    assert "friend" in lowered, "Response is not friendly"
    assert "turkey" in lowered, "Response is not a turkey"


def test_timeout_works() -> None:
model = GeneralLlm(model="gpt-4o", timeout=0.1)
model_input = "Hello, world!"
Expand Down
114 changes: 114 additions & 0 deletions code_tests/integration_tests/test_metaculus_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import os
from datetime import datetime, timezone

import pendulum
Expand All @@ -15,6 +16,7 @@
from forecasting_tools.data_models.questions import (
BinaryQuestion,
CanceledResolution,
Category,
ConditionalQuestion,
DateQuestion,
DiscreteQuestion,
Expand Down Expand Up @@ -56,6 +58,11 @@ def test_get_binary_question_type_from_id(self) -> None:
assert question.question_type == "binary"
assert question.question_ids_of_group is None
assert question.is_in_main_feed is True
assert set([category.name for category in question.categories]) == {
"Health & Pandemics",
"Nuclear Technology & Risks",
"Technology",
}, f"Categories are not correct for post ID {question.id_of_post}. Categories: {question.categories}"
assert_basic_question_attributes_not_none(question, question.id_of_post)

def test_get_numeric_question_type_from_id(self) -> None:
Expand All @@ -71,6 +78,9 @@ def test_get_numeric_question_type_from_id(self) -> None:
assert question.question_weight == 1.0
assert question.get_question_type() == "numeric"
assert question.question_type == "numeric"
assert set([category.name for category in question.categories]) == {
"Health & Pandemics"
}, f"Categories are not correct for post ID {question.id_of_post}. Categories: {question.categories}"
assert_basic_question_attributes_not_none(question, question.id_of_post)
assert question.lower_bound == 0
assert question.upper_bound == 200
Expand Down Expand Up @@ -635,6 +645,7 @@ def test_log_scale_all_repeated_values(self) -> None:
]
self._check_cdf_processes_and_posts_correctly(percentiles, question)

@pytest.mark.skip(reason="The test question is now closed for forecasting")
def test_log_scale_another_edge_case(self) -> None:
# This test was able to replicate a floating point epsilon error at one point.
url = "https://dev.metaculus.com/questions/7546"
Expand Down Expand Up @@ -1103,19 +1114,34 @@ def assert_basic_question_attributes_not_none(
question.question_text is not None
), f"Question text is None for post ID {post_id}"
assert question.close_time is not None, f"Close time is None for post ID {post_id}"
assert (
question.close_time.tzinfo is not None
), f"Close time is not timezone aware for post ID {post_id}"
assert question.open_time is not None, f"Open time is None for post ID {post_id}"
assert (
question.open_time.tzinfo is not None
), f"Open time is not timezone aware for post ID {post_id}"
assert (
question.published_time is not None
), f"Published time is None for post ID {post_id}"
assert (
question.published_time.tzinfo is not None
), f"Published time is not timezone aware for post ID {post_id}"
assert (
question.scheduled_resolution_time is not None
), f"Scheduled resolution time is None for post ID {post_id}"
assert (
question.scheduled_resolution_time.tzinfo is not None
), f"Scheduled resolution time is not timezone aware for post ID {post_id}"
assert (
question.includes_bots_in_aggregates is not None
), f"Includes bots in aggregates is None for post ID {post_id}"
assert (
question.cp_reveal_time is not None
), f"CP reveal time is None for post ID {post_id}"
assert (
question.cp_reveal_time.tzinfo is not None
), f"CP reveal time is not timezone aware for post ID {post_id}"
assert isinstance(
question.state, QuestionState
), f"State is not a QuestionState for post ID {post_id}"
Expand All @@ -1134,6 +1160,10 @@ def assert_basic_question_attributes_not_none(
assert question.actual_resolution_time is None or isinstance(
question.actual_resolution_time, datetime
), f"Actual resolution time is not a datetime for post ID {post_id}"
if question.actual_resolution_time is not None:
assert (
question.actual_resolution_time.tzinfo is not None
), f"Actual resolution time is not timezone aware for post ID {post_id}"
assert isinstance(
question.api_json, dict
), f"API JSON is not a dict for post ID {post_id}"
Expand All @@ -1156,13 +1186,19 @@ def assert_basic_question_attributes_not_none(
assert question.date_accessed > pendulum.now().subtract(
days=1
), f"Date accessed is not in the past for post ID {post_id}"
assert (
question.date_accessed.tzinfo is not None
), f"Date accessed is not timezone aware for post ID {post_id}"
assert isinstance(
question.already_forecasted, bool
), f"Already forecasted is not a boolean for post ID {post_id}"
if question.already_forecasted:
assert (
question.timestamp_of_my_last_forecast is not None
), f"Timestamp of my last forecast is None for post ID {post_id}"
assert (
question.timestamp_of_my_last_forecast.tzinfo is not None
), f"Timestamp of my last forecast is not timezone aware for post ID {post_id}"
if isinstance(question, NumericQuestion):
assert (
question.unit_of_measure is not None
Expand All @@ -1186,6 +1222,9 @@ def assert_basic_question_attributes_not_none(
assert isinstance(question.question_ids_of_group, list)
assert all(isinstance(q_id, int) for q_id in question.question_ids_of_group)
assert question.group_question_option is not None
assert all(
isinstance(category, Category) for category in question.categories
), f"Categories is not a list of strings for post ID {post_id}"


def assert_questions_match_filter( # NOSONAR
Expand Down Expand Up @@ -1313,3 +1352,78 @@ def assert_questions_match_filter( # NOSONAR
assert (
filter_passes
), f"Question {question.id_of_post} has no community prediction at access time"


class TestAdminFunctions:
    """Integration tests for admin-only Metaculus operations.

    Requires an ADMIN_METACULUS_TOKEN with admin rights on dev.metaculus.com
    and exercises the full create -> approve -> resolve -> unresolve lifecycle.
    """

    def test_all_admin_functions(self) -> None:
        # All lifecycle steps live in a single test because each step depends
        # on the remote state produced by the previous one.
        token = os.getenv("ADMIN_METACULUS_TOKEN")
        if token is None:
            raise ValueError("ADMIN_METACULUS_TOKEN is not set")
        client = MetaculusClient(
            base_url="https://dev.metaculus.com/api",
            token=token,
        )
        # Use an existing dev-site question as the template for the new one.
        question_to_create = client.get_question_by_url(
            "https://dev.metaculus.com/questions/39162/"
        )
        project_id = 1156  # https://dev.metaculus.com/tournament/beta-testing/
        slug = "beta-testing"
        question_to_create.default_project_id = project_id
        question_to_create.tournament_slugs = [slug]

        ### Create and approve question ###
        created_question = client.create_question(question_to_create)
        client.approve_question(created_question)

        # The created question must be a brand-new post/question, not the template.
        assert created_question is not None
        assert created_question.id_of_post is not None
        assert created_question.id_of_question is not None
        assert created_question.default_project_id == project_id
        assert created_question.id_of_post != question_to_create.id_of_post
        assert created_question.id_of_question != question_to_create.id_of_question

        # Content fields should round-trip through the create call unchanged.
        assert created_question.question_text == question_to_create.question_text
        assert created_question.open_time == question_to_create.open_time
        assert created_question.close_time == question_to_create.close_time
        assert (
            created_question.includes_bots_in_aggregates
            == question_to_create.includes_bots_in_aggregates
        )
        assert created_question.cp_reveal_time == question_to_create.cp_reveal_time
        assert created_question.question_weight == question_to_create.question_weight
        assert (
            created_question.resolution_string == question_to_create.resolution_string
        )
        assert created_question.conditional_type == question_to_create.conditional_type
        assert_basic_question_attributes_not_none(
            created_question, created_question.id_of_post
        )
        # str() comparison checks the full category payload; the set check
        # additionally pins the expected category names for the template question.
        assert str(created_question.categories) == str(question_to_create.categories)
        assert set(category.name for category in created_question.categories) == {
            "Artificial Intelligence"
        }
        assert (
            created_question.resolution_criteria
            == question_to_create.resolution_criteria
        )
        assert created_question.fine_print == question_to_create.fine_print
        assert created_question.background_info == question_to_create.background_info
        assert set(created_question.tournament_slugs) == {slug}
        assert created_question.published_time == question_to_create.published_time

        ### Resolve question ###
        client.resolve_question(created_question.id_of_question, "yes", pendulum.now())
        # Re-fetch from the API to confirm the resolution was persisted server-side.
        resolved_question = client.get_question_by_post_id(created_question.id_of_post)
        assert resolved_question.state == QuestionState.RESOLVED
        assert resolved_question.actual_resolution_time is not None
        assert resolved_question.resolution_string == "yes"

        ### Unresolve question ###
        client.unresolve_question(created_question.id_of_question)
        unresolved_question = client.get_question_by_post_id(
            created_question.id_of_post
        )
        # Unresolving must clear all resolution state on the question.
        assert unresolved_question.actual_resolution_time is None
        assert unresolved_question.resolution_string is None
        assert unresolved_question.state != QuestionState.RESOLVED
1 change: 1 addition & 0 deletions forecasting_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
)
from forecasting_tools.data_models.numeric_report import NumericReport as NumericReport
from forecasting_tools.data_models.questions import BinaryQuestion as BinaryQuestion
from forecasting_tools.data_models.questions import Category as Category
from forecasting_tools.data_models.questions import DateQuestion as DateQuestion
from forecasting_tools.data_models.questions import DiscreteQuestion as DiscreteQuestion
from forecasting_tools.data_models.questions import (
Expand Down
23 changes: 16 additions & 7 deletions forecasting_tools/ai_models/general_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class GeneralLlm(
OutputsText,
):
"""
A wrapper around litellm's acompletion function that adds some functionality
A wrapper around litellm's acompletion function that adds functionality
like rate limiting, retry logic, metaculus proxy, and cost callback handling.

Litellm support every model, most every parameter, and acts as one interface for every provider.
Expand Down Expand Up @@ -231,23 +231,32 @@ def __init__(

ModelTracker.give_cost_tracking_warning_if_needed(self._litellm_model)

async def invoke(self, prompt: ModelInputType, system_prompt: str | None = None) -> str:
async def invoke(
self, prompt: ModelInputType, system_prompt: str | None = None
) -> str:
if system_prompt is not None and (
isinstance(prompt, str) or isinstance(prompt, VisionMessageData)
):
prompt = self.model_input_to_message(prompt, system_prompt)
elif system_prompt is not None and isinstance(prompt, list):
raise ValueError(
"System prompt cannot be used with list of messages since the list may include a system message already"
)
response: TextTokenCostResponse = (
await self._invoke_with_request_cost_time_and_token_limits_and_retry(prompt, system_prompt=system_prompt)
await self._invoke_with_request_cost_time_and_token_limits_and_retry(prompt)
)
data = response.data
return data

@RetryableModel._retry_according_to_model_allowed_tries
async def _invoke_with_request_cost_time_and_token_limits_and_retry(
self, prompt: ModelInputType, system_prompt: str | None = None
self,
prompt: ModelInputType,
) -> Any:
logger.debug(f"Invoking model with prompt: {prompt}")

prompt = self.model_input_to_message(prompt, system_prompt)

with track_generation(
input=prompt,
input=self.model_input_to_message(prompt),
model=self.model,
) as span:
direct_call_response = await self._mockable_direct_call_to_model(prompt)
Expand Down
29 changes: 27 additions & 2 deletions forecasting_tools/data_models/questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ class OutOfBoundsResolution(Enum):
ConditionalSubQuestionType = Literal["parent", "child", "yes", "no"]


class Category(BaseModel, Jsonable):
    """A Metaculus category (topic tag) attached to a question.

    Parsed from the ``projects.category`` list of the Metaculus post API JSON
    (see ``MetaculusQuestion.from_metaculus_api_json``).
    """

    id: int  # Metaculus-side numeric identifier for the category
    name: str  # Human-readable display name, e.g. "Health & Pandemics"
    slug: str | None = None  # URL-safe identifier, when the API provides one
    emoji: str | None = None  # Optional emoji shown alongside the category
    description: str | None = None  # Optional longer description text
    type: Literal["category"] = "category"  # Discriminator value from the API payload


class MetaculusQuestion(BaseModel, Jsonable):
question_text: str
id_of_post: int | None = Field(
Expand Down Expand Up @@ -110,6 +119,7 @@ class MetaculusQuestion(BaseModel, Jsonable):
custom_metadata: dict = Field(
default_factory=dict
) # Additional metadata not tracked above or through the Metaculus API
categories: list[Category] = Field(default_factory=list)

@model_validator(mode="after")
def add_timezone_to_dates(self) -> MetaculusQuestion:
Expand All @@ -134,12 +144,26 @@ def from_metaculus_api_json(cls, post_api_json: dict) -> MetaculusQuestion:
tournaments: list[dict] = post_api_json["projects"]["tournament"] # type: ignore
tournament_slugs = [str(t["slug"]) for t in tournaments]
except KeyError:
tournament_slugs = []
try:
question_series: list[dict] = post_api_json["projects"]["question_series"] # type: ignore
tournament_slugs = [str(q["slug"]) for q in question_series]
except KeyError:
tournament_slugs = []

group_question_option = question_json.get("label", None)
if group_question_option is not None and group_question_option.strip() == "":
group_question_option = None

try:
categories_dict_list = post_api_json["projects"]["category"]
except KeyError:
categories_dict_list = []
categories = [
Category(**category_dict) for category_dict in categories_dict_list
]

published_time = cls._parse_api_date(post_api_json.get("published_at"))

question = MetaculusQuestion(
# NOTE: Reminder - When adding new fields, consider if group questions
# need to be parsed differently (i.e. if the field information is part of the post_json)
Expand All @@ -162,7 +186,7 @@ def from_metaculus_api_json(cls, post_api_json: dict) -> MetaculusQuestion:
scheduled_resolution_time=cls._parse_api_date(
question_json.get("scheduled_resolve_time")
),
published_time=cls._parse_api_date(post_api_json.get("published_at")),
published_time=published_time,
cp_reveal_time=cls._parse_api_date(question_json.get("cp_reveal_time")),
open_time=cls._parse_api_date(question_json.get("open_time")),
already_forecasted=is_forecasted,
Expand All @@ -180,6 +204,7 @@ def from_metaculus_api_json(cls, post_api_json: dict) -> MetaculusQuestion:
group_question_option=group_question_option,
api_json=post_api_json,
conditional_type=question_json.get("conditional_type", None),
categories=categories,
)
return question

Expand Down
22 changes: 22 additions & 0 deletions forecasting_tools/helpers/metaculus_api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from datetime import datetime
from typing import List, Literal, overload

import typing_extensions
Expand Down Expand Up @@ -231,3 +232,24 @@ def get_benchmark_questions(
error_if_question_target_missed,
group_question_mode,
)

### ADMIN FUNCTIONS ###
@classmethod
def resolve_question(
cls, question_id: int, resolution: str | None, resolve_time: datetime
) -> None:
return cls.METACULUS_CLIENT.resolve_question(
question_id, resolution, resolve_time
)

@classmethod
def unresolve_question(cls, question_id: int) -> None:
return cls.METACULUS_CLIENT.unresolve_question(question_id)

@classmethod
def create_question(cls, question: MetaculusQuestion) -> MetaculusQuestion:
return cls.METACULUS_CLIENT.create_question(question)

@classmethod
def approve_question(cls, question: MetaculusQuestion) -> None:
return cls.METACULUS_CLIENT.approve_question(question)
Loading