diff --git a/unshackle/core/utilities.py b/unshackle/core/utilities.py index 210de02..784c037 100644 --- a/unshackle/core/utilities.py +++ b/unshackle/core/utilities.py @@ -1,6 +1,7 @@ import ast import contextlib import importlib.util +import logging import os import re import socket @@ -21,6 +22,7 @@ from langcodes import Language, closest_match from pymp4.parser import Box from unidecode import unidecode +from unshackle.core.cacher import Cacher from unshackle.core.config import config from unshackle.core.constants import LANGUAGE_MAX_DISTANCE @@ -248,6 +250,7 @@ def get_cached_ip_info(session: Optional[requests.Session] = None) -> Optional[d This function uses a global cache to avoid repeated API calls when the IP hasn't changed. Should only be used for local IP checks, not for proxy verification. + Implements smart provider rotation to handle rate limiting (429 errors). Args: session: Optional requests session (usually without proxy for local IP) @@ -255,26 +258,54 @@ def get_cached_ip_info(session: Optional[requests.Session] = None) -> Optional[d Returns: Dict with IP info including 'country' key, or None if all providers fail """ - from unshackle.core.cacher import Cacher + log = logging.getLogger("get_cached_ip_info") cache = Cacher("global").get("ip_info") if cache and not cache.expired: return cache.data - providers = [ - "https://ipinfo.io/json", - "https://ipapi.co/json", - ] + provider_state_cache = Cacher("global").get("ip_provider_state") + provider_state = provider_state_cache.data if provider_state_cache and not provider_state_cache.expired else {} + + providers = { + "ipinfo": "https://ipinfo.io/json", + "ipapi": "https://ipapi.co/json", + } session = session or requests.Session() + provider_order = ["ipinfo", "ipapi"] - for provider_url in providers: + current_time = time.time() + for provider_name in list(provider_order): + if provider_name in provider_state: + rate_limit_info = provider_state[provider_name] + if (current_time - rate_limit_info.get("rate_limited_at", 0)) < 300: + log.debug(f"Provider {provider_name} was rate limited recently, trying other provider first") + provider_order.remove(provider_name) + provider_order.append(provider_name) + break + + for provider_name in provider_order: + provider_url = providers[provider_name] try: + log.debug(f"Trying IP provider: {provider_name}") response = session.get(provider_url, timeout=10) - if response.status_code == 200: - data = response.json() + if response.status_code == 429: + log.warning(f"Provider {provider_name} returned 429 (rate limited), trying next provider") + if provider_name not in provider_state: + provider_state[provider_name] = {} + provider_state[provider_name]["rate_limited_at"] = current_time + provider_state[provider_name]["rate_limit_count"] = ( + provider_state[provider_name].get("rate_limit_count", 0) + 1 + ) + + provider_state_cache.set(provider_state, expiration=300) + continue + + elif response.status_code == 200: + data = response.json() normalized_data = {} if "country" in data: @@ -288,12 +319,23 @@ def get_cached_ip_info(session: Optional[requests.Session] = None) -> Optional[d } if normalized_data and "country" in normalized_data: + log.debug(f"Successfully got IP info from provider: {provider_name}") + + if provider_name in provider_state: + provider_state[provider_name].pop("rate_limited_at", None) + provider_state_cache.set(provider_state, expiration=300) + + normalized_data["_provider"] = provider_name cache.set(normalized_data, expiration=86400) return normalized_data + else: + log.debug(f"Provider {provider_name} returned status {response.status_code}") - except Exception: + except Exception as e: + log.debug(f"Provider {provider_name} failed with exception: {e}") continue + log.warning("All IP geolocation providers failed") return None