diff --git a/code_tests/integration_tests/test_ai_models/test_general_llm.py b/code_tests/integration_tests/test_ai_models/test_general_llm.py
index 19d6f175..c4775b00 100644
--- a/code_tests/integration_tests/test_ai_models/test_general_llm.py
+++ b/code_tests/integration_tests/test_ai_models/test_general_llm.py
@@ -24,6 +24,13 @@ def _all_tests() -> list[ModelTest]:
             GeneralLlm(model="openai/gpt-4o"),
             [{"role": "user", "content": test_data.get_cheap_user_message()}],
         ),
+        ModelTest(
+            GeneralLlm(model="openai/gpt-4o"),
+            [
+                {"role": "system", "content": "You are a helpful assistant."},
+                {"role": "user", "content": test_data.get_cheap_user_message()},
+            ],
+        ),
         ModelTest(
             GeneralLlm(model="o3-mini", reasoning_effort="low"),
             test_data.get_cheap_user_message(),
@@ -99,6 +106,21 @@ def test_general_llm_instances_run(
     assert cost_manager.current_usage > 0, "No cost was incurred"


+def test_system_prompt_parameter() -> None:
+    model = GeneralLlm(model="gpt-4o")
+    response = asyncio.run(
+        model.invoke(
+            "Hello, world!",
+            system_prompt="If the user says 'Hello, world!', say 'Hello there friend! I am a turkey.' back to them.",
+        )
+    )
+    assert response is not None, "Response is None"
+    assert response != "", "Response is an empty string"
+    logger.info(f"Response: {response}")
+    assert "friend" in response.lower(), "Response is not friendly"
+    assert "turkey" in response.lower(), "Response is not a turkey"
+
+
 def test_timeout_works() -> None:
     model = GeneralLlm(model="gpt-4o", timeout=0.1)
     model_input = "Hello, world!"
diff --git a/code_tests/integration_tests/test_metaculus_api.py b/code_tests/integration_tests/test_metaculus_api.py
index 1dff1c10..fdb17f2c 100644
--- a/code_tests/integration_tests/test_metaculus_api.py
+++ b/code_tests/integration_tests/test_metaculus_api.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from datetime import datetime, timezone

 import pendulum
@@ -15,6 +16,7 @@ from forecasting_tools.data_models.questions import (
     BinaryQuestion,
     CanceledResolution,
+    Category,
     ConditionalQuestion,
     DateQuestion,
     DiscreteQuestion,
@@ -56,6 +58,11 @@ def test_get_binary_question_type_from_id(self) -> None:
         assert question.question_type == "binary"
         assert question.question_ids_of_group is None
         assert question.is_in_main_feed is True
+        assert set([category.name for category in question.categories]) == {
+            "Health & Pandemics",
+            "Nuclear Technology & Risks",
+            "Technology",
+        }, f"Categories are not correct for post ID {question.id_of_post}. Categories: {question.categories}"
         assert_basic_question_attributes_not_none(question, question.id_of_post)

     def test_get_numeric_question_type_from_id(self) -> None:
@@ -71,6 +78,9 @@ def test_get_numeric_question_type_from_id(self) -> None:
         assert question.question_weight == 1.0
         assert question.get_question_type() == "numeric"
         assert question.question_type == "numeric"
+        assert set([category.name for category in question.categories]) == {
+            "Health & Pandemics"
+        }, f"Categories are not correct for post ID {question.id_of_post}. Categories: {question.categories}"
         assert_basic_question_attributes_not_none(question, question.id_of_post)
         assert question.lower_bound == 0
         assert question.upper_bound == 200
@@ -635,6 +645,7 @@ def test_log_scale_all_repeated_values(self) -> None:
         ]
         self._check_cdf_processes_and_posts_correctly(percentiles, question)

+    @pytest.mark.skip(reason="The test question is now closed for forecasting")
     def test_log_scale_another_edge_case(self) -> None:
         # This test was able to replicate a floating point epsilon error at one point.
         url = "https://dev.metaculus.com/questions/7546"
@@ -1103,19 +1114,34 @@ def assert_basic_question_attributes_not_none(
         question.question_text is not None
     ), f"Question text is None for post ID {post_id}"
     assert question.close_time is not None, f"Close time is None for post ID {post_id}"
+    assert (
+        question.close_time.tzinfo is not None
+    ), f"Close time is not timezone aware for post ID {post_id}"
     assert question.open_time is not None, f"Open time is None for post ID {post_id}"
+    assert (
+        question.open_time.tzinfo is not None
+    ), f"Open time is not timezone aware for post ID {post_id}"
     assert (
         question.published_time is not None
     ), f"Published time is None for post ID {post_id}"
+    assert (
+        question.published_time.tzinfo is not None
+    ), f"Published time is not timezone aware for post ID {post_id}"
     assert (
         question.scheduled_resolution_time is not None
     ), f"Scheduled resolution time is None for post ID {post_id}"
+    assert (
+        question.scheduled_resolution_time.tzinfo is not None
+    ), f"Scheduled resolution time is not timezone aware for post ID {post_id}"
     assert (
         question.includes_bots_in_aggregates is not None
     ), f"Includes bots in aggregates is None for post ID {post_id}"
     assert (
         question.cp_reveal_time is not None
     ), f"CP reveal time is None for post ID {post_id}"
+    assert (
+        question.cp_reveal_time.tzinfo is not None
+    ), f"CP reveal time is not timezone aware for post ID {post_id}"
     assert isinstance(
         question.state, QuestionState
     ), f"State is not a QuestionState for post ID {post_id}"
@@ -1134,6 +1160,10 @@ def assert_basic_question_attributes_not_none(
     assert question.actual_resolution_time is None or isinstance(
         question.actual_resolution_time, datetime
     ), f"Actual resolution time is not a datetime for post ID {post_id}"
+    if question.actual_resolution_time is not None:
+        assert (
+            question.actual_resolution_time.tzinfo is not None
+        ), f"Actual resolution time is not timezone aware for post ID {post_id}"
     assert isinstance(
         question.api_json, dict
     ), f"API JSON is not a dict for post ID {post_id}"
@@ -1156,6 +1186,9 @@ def assert_basic_question_attributes_not_none(
     assert question.date_accessed > pendulum.now().subtract(
         days=1
     ), f"Date accessed is not in the past for post ID {post_id}"
+    assert (
+        question.date_accessed.tzinfo is not None
+    ), f"Date accessed is not timezone aware for post ID {post_id}"
     assert isinstance(
         question.already_forecasted, bool
     ), f"Already forecasted is not a boolean for post ID {post_id}"
@@ -1163,6 +1196,9 @@ def assert_basic_question_attributes_not_none(
     assert (
         question.timestamp_of_my_last_forecast is not None
     ), f"Timestamp of my last forecast is None for post ID {post_id}"
+    assert (
+        question.timestamp_of_my_last_forecast.tzinfo is not None
+    ), f"Timestamp of my last forecast is not timezone aware for post ID {post_id}"
     if isinstance(question, NumericQuestion):
         assert (
             question.unit_of_measure is not None
@@ -1186,6 +1222,9 @@
         assert isinstance(question.question_ids_of_group, list)
         assert all(isinstance(q_id, int) for q_id in question.question_ids_of_group)
         assert question.group_question_option is not None
+    assert all(
+        isinstance(category, Category) for category in question.categories
+    ), f"Categories is not a list of Category objects for post ID {post_id}"


 def assert_questions_match_filter(  # NOSONAR
@@ -1313,3 +1352,78 @@ def assert_questions_match_filter(  # NOSONAR
     assert (
         filter_passes
     ), f"Question {question.id_of_post} has no community prediction at access time"
+
+
+class TestAdminFunctions:
+
+    def test_all_admin_functions(self) -> None:
+        token = os.getenv("ADMIN_METACULUS_TOKEN")
+        if token is None:
+            raise ValueError("ADMIN_METACULUS_TOKEN is not set")
+        client = MetaculusClient(
+            base_url="https://dev.metaculus.com/api",
+            token=token,
+        )
+        question_to_create = client.get_question_by_url(
+            "https://dev.metaculus.com/questions/39162/"
+        )
+        project_id = 1156  # https://dev.metaculus.com/tournament/beta-testing/
+        slug = "beta-testing"
+        question_to_create.default_project_id = project_id
+        question_to_create.tournament_slugs = [slug]
+
+        ### Create and approve question ###
+        created_question = client.create_question(question_to_create)
+        client.approve_question(created_question)
+
+        assert created_question is not None
+        assert created_question.id_of_post is not None
+        assert created_question.id_of_question is not None
+        assert created_question.default_project_id == project_id
+        assert created_question.id_of_post != question_to_create.id_of_post
+        assert created_question.id_of_question != question_to_create.id_of_question
+
+        assert created_question.question_text == question_to_create.question_text
+        assert created_question.open_time == question_to_create.open_time
+        assert created_question.close_time == question_to_create.close_time
+        assert (
+            created_question.includes_bots_in_aggregates
+            == question_to_create.includes_bots_in_aggregates
+        )
+        assert created_question.cp_reveal_time == question_to_create.cp_reveal_time
+        assert created_question.question_weight == question_to_create.question_weight
+        assert (
+            created_question.resolution_string == question_to_create.resolution_string
+        )
+        assert created_question.conditional_type == question_to_create.conditional_type
+        assert_basic_question_attributes_not_none(
+            created_question, created_question.id_of_post
+        )
+        assert str(created_question.categories) == str(question_to_create.categories)
+        assert set(category.name for category in created_question.categories) == {
+            "Artificial Intelligence"
+        }
+        assert (
+            created_question.resolution_criteria
+            == question_to_create.resolution_criteria
+        )
+        assert created_question.fine_print == question_to_create.fine_print
+        assert created_question.background_info == question_to_create.background_info
+        assert set(created_question.tournament_slugs) == {slug}
+        assert created_question.published_time == question_to_create.published_time
+
+        ### Resolve question ###
+        client.resolve_question(created_question.id_of_question, "yes", pendulum.now())
+        resolved_question = client.get_question_by_post_id(created_question.id_of_post)
+        assert resolved_question.state == QuestionState.RESOLVED
+        assert resolved_question.actual_resolution_time is not None
+        assert resolved_question.resolution_string == "yes"
+
+        ### Unresolve question ###
+        client.unresolve_question(created_question.id_of_question)
+        unresolved_question = client.get_question_by_post_id(
+            created_question.id_of_post
+        )
+        assert unresolved_question.actual_resolution_time is None
+        assert unresolved_question.resolution_string is None
+        assert unresolved_question.state != QuestionState.RESOLVED
diff --git a/forecasting_tools/__init__.py b/forecasting_tools/__init__.py
index f325a0c6..c1e66953 100644
--- a/forecasting_tools/__init__.py
+++ b/forecasting_tools/__init__.py
@@ -112,6 +112,7 @@
 )
 from forecasting_tools.data_models.numeric_report import NumericReport as NumericReport
 from forecasting_tools.data_models.questions import BinaryQuestion as BinaryQuestion
+from forecasting_tools.data_models.questions import Category as Category
 from forecasting_tools.data_models.questions import DateQuestion as DateQuestion
 from forecasting_tools.data_models.questions import DiscreteQuestion as DiscreteQuestion
 from forecasting_tools.data_models.questions import (
diff --git a/forecasting_tools/ai_models/general_llm.py b/forecasting_tools/ai_models/general_llm.py
index 581c0dc0..97d869a4 100644
--- a/forecasting_tools/ai_models/general_llm.py
+++ b/forecasting_tools/ai_models/general_llm.py
@@ -52,7 +52,7 @@ class GeneralLlm(
     OutputsText,
 ):
     """
-    A wrapper around litellm's acompletion function that adds some functionality
+    A wrapper around litellm's acompletion function that adds functionality
     like rate limiting, retry logic, metaculus proxy, and cost callback handling.
     Litellm supports nearly every model and parameter, and acts as one interface for every provider.
@@ -231,23 +231,32 @@ def __init__(

         ModelTracker.give_cost_tracking_warning_if_needed(self._litellm_model)

-    async def invoke(self, prompt: ModelInputType, system_prompt: str | None = None) -> str:
+    async def invoke(
+        self, prompt: ModelInputType, system_prompt: str | None = None
+    ) -> str:
+        if system_prompt is not None and isinstance(prompt, (str, VisionMessageData)):
+            prompt = self.model_input_to_message(prompt, system_prompt)
+        elif system_prompt is not None and isinstance(prompt, list):
+            raise ValueError(
+                "A system prompt cannot be combined with a list of messages, since the list may already include a system message"
+            )
         response: TextTokenCostResponse = (
-            await self._invoke_with_request_cost_time_and_token_limits_and_retry(prompt, system_prompt=system_prompt)
+            await self._invoke_with_request_cost_time_and_token_limits_and_retry(prompt)
         )
         data = response.data
         return data

     @RetryableModel._retry_according_to_model_allowed_tries
     async def _invoke_with_request_cost_time_and_token_limits_and_retry(
-        self, prompt: ModelInputType, system_prompt: str | None = None
+        self,
+        prompt: ModelInputType,
     ) -> Any:
         logger.debug(f"Invoking model with prompt: {prompt}")
-        prompt = self.model_input_to_message(prompt, system_prompt)
-
         with track_generation(
-            input=prompt,
+            input=self.model_input_to_message(prompt),
             model=self.model,
         ) as span:
             direct_call_response = await self._mockable_direct_call_to_model(prompt)
diff --git a/forecasting_tools/data_models/questions.py b/forecasting_tools/data_models/questions.py
index daa1cf18..a797e829 100644
--- a/forecasting_tools/data_models/questions.py
+++ b/forecasting_tools/data_models/questions.py
@@ -59,6 +59,15 @@ class OutOfBoundsResolution(Enum):

 ConditionalSubQuestionType = Literal["parent", "child", "yes", "no"]


+class Category(BaseModel, Jsonable):
+    id: int
+    name: str
+    slug: str | None = None
+    emoji: str | None = None
+    description: str | None = None
+    type: Literal["category"] = "category"
+
+
 class MetaculusQuestion(BaseModel, Jsonable):
     question_text: str
     id_of_post: int | None = Field(
@@ -110,6 +119,7 @@ class MetaculusQuestion(BaseModel, Jsonable):
     custom_metadata: dict = Field(
         default_factory=dict
     )  # Additional metadata not tracked above or through the Metaculus API
+    categories: list[Category] = Field(default_factory=list)

     @model_validator(mode="after")
     def add_timezone_to_dates(self) -> MetaculusQuestion:
@@ -134,12 +144,26 @@ def from_metaculus_api_json(cls, post_api_json: dict) -> MetaculusQuestion:
             tournaments: list[dict] = post_api_json["projects"]["tournament"]  # type: ignore
             tournament_slugs = [str(t["slug"]) for t in tournaments]
         except KeyError:
-            tournament_slugs = []
+            try:
+                question_series: list[dict] = post_api_json["projects"]["question_series"]  # type: ignore
+                tournament_slugs = [str(q["slug"]) for q in question_series]
+            except KeyError:
+                tournament_slugs = []

         group_question_option = question_json.get("label", None)
         if group_question_option is not None and group_question_option.strip() == "":
             group_question_option = None

+        try:
+            categories_dict_list = post_api_json["projects"]["category"]
+        except KeyError:
+            categories_dict_list = []
+        categories = [
+            Category(**category_dict) for category_dict in categories_dict_list
+        ]
+
+        published_time = cls._parse_api_date(post_api_json.get("published_at"))
+
         question = MetaculusQuestion(
             # NOTE: Reminder - When adding new fields, consider if group questions
             # need to be parsed differently (i.e. if the field information is part of the post_json)
@@ -162,7 +186,7 @@ def from_metaculus_api_json(cls, post_api_json: dict) -> MetaculusQuestion:
             scheduled_resolution_time=cls._parse_api_date(
                 question_json.get("scheduled_resolve_time")
             ),
-            published_time=cls._parse_api_date(post_api_json.get("published_at")),
+            published_time=published_time,
             cp_reveal_time=cls._parse_api_date(question_json.get("cp_reveal_time")),
             open_time=cls._parse_api_date(question_json.get("open_time")),
             already_forecasted=is_forecasted,
@@ -180,6 +204,7 @@ def from_metaculus_api_json(cls, post_api_json: dict) -> MetaculusQuestion:
             group_question_option=group_question_option,
             api_json=post_api_json,
             conditional_type=question_json.get("conditional_type", None),
+            categories=categories,
         )
         return question
diff --git a/forecasting_tools/helpers/metaculus_api.py b/forecasting_tools/helpers/metaculus_api.py
index aaf0d963..2cc56b0c 100644
--- a/forecasting_tools/helpers/metaculus_api.py
+++ b/forecasting_tools/helpers/metaculus_api.py
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from datetime import datetime
 from typing import List, Literal, overload

 import typing_extensions
@@ -231,3 +232,24 @@ def get_benchmark_questions(
             error_if_question_target_missed,
             group_question_mode,
         )
+
+    ### ADMIN FUNCTIONS ###
+    @classmethod
+    def resolve_question(
+        cls, question_id: int, resolution: str | None, resolve_time: datetime
+    ) -> None:
+        return cls.METACULUS_CLIENT.resolve_question(
+            question_id, resolution, resolve_time
+        )
+
+    @classmethod
+    def unresolve_question(cls, question_id: int) -> None:
+        return cls.METACULUS_CLIENT.unresolve_question(question_id)
+
+    @classmethod
+    def create_question(cls, question: MetaculusQuestion) -> MetaculusQuestion:
+        return cls.METACULUS_CLIENT.create_question(question)
+
+    @classmethod
+    def approve_question(cls, question: MetaculusQuestion) -> None:
+        return cls.METACULUS_CLIENT.approve_question(question)
diff --git a/forecasting_tools/helpers/metaculus_client.py b/forecasting_tools/helpers/metaculus_client.py
index 5ae6c71b..20f93f8b 100644
--- a/forecasting_tools/helpers/metaculus_client.py
+++ b/forecasting_tools/helpers/metaculus_client.py
@@ -163,7 +163,12 @@ def __init__(
         self.timeout = timeout
         self.sleep_time_between_requests_min = sleep_seconds_between_requests
         self.sleep_jitter_seconds = sleep_jitter_seconds
-        self.token = token
+
+        self.token = token or os.getenv("METACULUS_TOKEN")
+        if self.token is None:
+            logger.warning(
+                "Neither the METACULUS_TOKEN environment variable nor the token parameter is set"
+            )

     @retry_with_exponential_backoff()
     def get_user_bots(self) -> list[UserResponse]:
@@ -344,9 +349,10 @@ def get_question_by_post_id(
         self._sleep_between_requests()

         url = f"{self.base_url}/posts/{post_id}/"
+        headers = self._get_auth_headers()
         response = requests.get(
             url,
-            **self._get_auth_headers(),  # type: ignore
+            **headers,  # type: ignore
             timeout=self.timeout,
         )
         raise_for_status_with_additional_info(response)
@@ -573,15 +579,17 @@ def get_needs_update_questions(
         return result

     def _get_auth_headers(self) -> dict[str, dict[str, str]]:
-        METACULUS_TOKEN = self.token or os.getenv("METACULUS_TOKEN")
-        if METACULUS_TOKEN is None:
-            raise ValueError("METACULUS_TOKEN environment variable or field not set")
-        return {
+        if self.token is None:
+            raise ValueError(
+                "Neither the METACULUS_TOKEN environment variable nor the token parameter is set"
+            )
+        headers = {
             "headers": {
-                "Authorization": f"Token {METACULUS_TOKEN}",
+                "Authorization": f"Token {self.token}",
                 "Accept-Language": "en",
             }
         }
+        return headers

     @retry_with_exponential_backoff()
     def _post_question_prediction(
@@ -1152,3 +1160,184 @@ def _sleep_between_requests(self) -> None:
             f"Sleeping for {random_sleep_time:.1f} seconds before next request"
         )
         time.sleep(random_sleep_time)
+
+    #################### ADMIN FUNCTIONS ####################
+
+    @retry_with_exponential_backoff()
+    def resolve_question(
+        self, question_id: int, resolution: str | None, resolve_time: datetime
+    ) -> None:
+        self._sleep_between_requests()
+        if resolution is None:
+            logger.warning(
+                f"Resolution is None; question ID {question_id} was not resolved"
+            )
+            return
+        response = requests.post(
+            f"{self.base_url}/questions/{question_id}/resolve/",
+            json={
+                "resolution": resolution,
+                "actual_resolve_time": resolve_time.strftime("%Y-%m-%dT%H:%M:%S"),
+            },  # TODO: Find all strftime calls that drop the timezone, and add it
+            **self._get_auth_headers(),  # type: ignore
+            timeout=self.timeout,
+        )
+        raise_for_status_with_additional_info(response)
+        logger.info(f"Resolved question ID {question_id} with resolution {resolution}")
+
+    @retry_with_exponential_backoff()
+    def unresolve_question(self, question_id: int) -> None:
+        self._sleep_between_requests()
+        response = requests.post(
+            f"{self.base_url}/questions/{question_id}/unresolve/",
+            **self._get_auth_headers(),  # type: ignore
+            timeout=self.timeout,
+        )
+        raise_for_status_with_additional_info(response)
+        logger.info(f"Unresolved question ID {question_id}")
+
+    @retry_with_exponential_backoff()
+    def create_question(self, question: MetaculusQuestion) -> MetaculusQuestion:
+        question_data = self._get_post_create_data(question)
+        self._sleep_between_requests()
+        response = requests.post(
+            f"{self.base_url}/posts/create/",
+            json=question_data,
+            headers={
+                "accept": "application/json",
+                "Authorization": f"Token {self.token}",
+            },
+            timeout=self.timeout,
+        )
+
+        raise_for_status_with_additional_info(response)
+        created_question_list = self._post_json_to_questions_while_handling_groups(
+            json.loads(response.content), "exclude"
+        )
+        if len(created_question_list) != 1:
+            raise ValueError(
+                f"Expected 1 question to be created, got {len(created_question_list)}. Questions created: {[question.page_url for question in created_question_list]}"
+            )
+        single_created_question = created_question_list[0]
+        logger.info(f"Created new question: {single_created_question.page_url}")
+
+        assert single_created_question.id_of_post is not None
+        full_created_question = self.get_question_by_post_id(
+            single_created_question.id_of_post
+        )
+
+        return full_created_question
+
+    @retry_with_exponential_backoff()
+    def approve_question(self, question: MetaculusQuestion) -> None:
+        self._sleep_between_requests()
+        response = requests.post(
+            f"{self.base_url}/posts/{question.id_of_post}/approve/",
+            data={
+                "published_at": (
+                    question.published_time.strftime("%Y-%m-%dT%H:%M:%S")
+                    if question.published_time
+                    else None
+                ),
+                "open_time": (
+                    question.open_time.strftime("%Y-%m-%dT%H:%M:%S")
+                    if question.open_time
+                    else None
+                ),
+                "cp_reveal_time": (
+                    question.cp_reveal_time.strftime("%Y-%m-%dT%H:%M:%S")
+                    if question.cp_reveal_time
+                    else None
+                ),
+                "scheduled_close_time": (
+                    question.close_time.strftime("%Y-%m-%dT%H:%M:%S")
+                    if question.close_time
+                    else None
+                ),
+                "scheduled_resolve_time": (
+                    question.scheduled_resolution_time.strftime("%Y-%m-%dT%H:%M:%S")
+                    if question.scheduled_resolution_time
+                    else None
+                ),
+            },
+            timeout=self.timeout,
+            headers={
+                "accept": "application/json",
+                "Authorization": f"Token {self.token}",
+            },
+        )
+
+        raise_for_status_with_additional_info(response)
+        logger.info(f"Approved question {question.page_url}")
+
+    @staticmethod
+    def _get_post_create_data(question: MetaculusQuestion) -> dict:
+        if question.published_time is None:
+            publish_time = question.open_time
+        else:
+            publish_time = question.published_time
+
+        return {
+            "title": question.question_text,
+            "short_title": question.custom_metadata.get("short_name", None)
+            or question.question_text,
+            "default_project": question.default_project_id,
+            "categories": [category.id for category in question.categories],
+            "published_at": (
+                publish_time.strftime("%Y-%m-%dT%H:%M:%S") if publish_time else None
+            ),
+            "is_automatically_translated": False,
+            "question": {
+                "title": question.question_text,
+                "description": question.background_info or "",
+                "type": getattr(question, "question_type", "binary"),
+                "possibilities": None,  # deprecated
+                "resolution": None,
+                "include_bots_in_aggregates": (
+                    question.includes_bots_in_aggregates
+                    if isinstance(question.includes_bots_in_aggregates, bool)
+                    else True  # defaults to True since mini bench is a bot thing
+                ),
+                "default_aggregation_method": "unweighted",
+                "default_score_type": "spot_peer",
+                "question_weight": (
+                    question.question_weight if question.question_weight else 1.0
+                ),
+                "scaling": {
+                    "range_max": getattr(question, "upper_bound", None),
+                    "range_min": getattr(question, "lower_bound", None),
+                    "zero_point": getattr(question, "zero_point", None),
+                },
+                "open_upper_bound": getattr(question, "open_upper_bound", None),
+                "open_lower_bound": getattr(question, "open_lower_bound", None),
+                "inbound_outcome_count": getattr(question, "cdf_size", 201) - 1,
+                "options": getattr(question, "options", None),
+                "group_variable": getattr(question, "option_is_instance_of", None)
+                or "",
+                "label": "",  # only group questions
+                "resolution_criteria": question.resolution_criteria or "",
+                "open_time": (
+                    question.open_time.strftime("%Y-%m-%dT%H:%M:%S")
+                    if question.open_time
+                    else None
+                ),
+                "cp_reveal_time": (
+                    question.cp_reveal_time.strftime("%Y-%m-%dT%H:%M:%S")
+                    if question.cp_reveal_time
+                    else None
+                ),
+                "scheduled_close_time": (
+                    question.close_time.strftime("%Y-%m-%dT%H:%M:%S")
+                    if question.close_time
+                    else None
+                ),
+                "scheduled_resolve_time": (
+                    question.scheduled_resolution_time.strftime("%Y-%m-%dT%H:%M:%S")
+                    if question.scheduled_resolution_time
+                    else None
+                ),
+                "unit": getattr(question, "unit_of_measure", None) or "",
+                "fine_print": question.fine_print or "",
+                "group_rank": None,  # only group questions
+            },
+        }
diff --git a/forecasting_tools/util/launch_utils.py b/forecasting_tools/util/launch_utils.py
index 7f588237..05722bfe 100644
--- a/forecasting_tools/util/launch_utils.py
+++ b/forecasting_tools/util/launch_utils.py
@@ -331,9 +331,18 @@ def schedule_questions(

     @staticmethod
     def compute_upcoming_day(
-        day_of_week: Literal["monday", "saturday", "friday"],
+        day_of_week: Literal[
+            "monday", "tuesday", "wednesday", "thursday", "friday", "saturday"
+        ],
     ) -> datetime:
-        day_number = {"monday": 0, "saturday": 5, "friday": 4}
+        day_number = {
+            "monday": 0,
+            "tuesday": 1,
+            "wednesday": 2,
+            "thursday": 3,
+            "friday": 4,
+            "saturday": 5,
+        }
         today = datetime.now().date()
         today_weekday = today.weekday()
         target_weekday = day_number[day_of_week]
diff --git a/forecasting_tools/util/misc.py b/forecasting_tools/util/misc.py
index 7aaa911a..88433f60 100644
--- a/forecasting_tools/util/misc.py
+++ b/forecasting_tools/util/misc.py
@@ -26,13 +26,25 @@ def raise_for_status_with_additional_info(
     try:
         response.raise_for_status()
     except requests.exceptions.HTTPError as e:
-        response_text = response.text
+        response_text = str(response.text)
         response_reason = response.reason
         try:
             response_json = response.json()
         except Exception:
             response_json = None
-        error_message = f"HTTPError. Url: {response.url}. Response reason: {response_reason}. Response text: {response_text}. Response JSON: {response_json}"
+        if "!doctype html" in response_text.lower():
+            response_text = "Response text is an HTML page"
+
+        try:
+            status_code = response.status_code  # type: ignore
+        except Exception:
+            status_code = None
+
+        error_message = (
+            f"HTTPError. Url: {response.url}. Status code: {status_code}. "
+            f"Response reason: {response_reason}. Response text: {response_text}. "
+            f"Response JSON: {response_json}."
+        )
         logger.error(error_message)
         raise requests.exceptions.HTTPError(error_message) from e
diff --git a/pyproject.toml b/pyproject.toml
index af3276a1..7b11d698 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "forecasting-tools"
-version = "0.2.81"
+version = "0.2.83"
 description = "AI forecasting and research tools to help humans reason about and forecast the future"
 authors = ["Benjamin Wilson "]
 license = "MIT"