feat(ip-info): Fix a few more issues with get_ip_info: make sure we fail over to a different provider on 429 errors, and allow more API providers to be added in the future.

This commit is contained in:
Andy
2025-08-16 00:28:05 +00:00
parent 50a5a23341
commit 72f65adcb2

View File

@@ -1,6 +1,7 @@
import ast import ast
import contextlib import contextlib
import importlib.util import importlib.util
import logging
import os import os
import re import re
import socket import socket
@@ -21,6 +22,7 @@ from langcodes import Language, closest_match
from pymp4.parser import Box from pymp4.parser import Box
from unidecode import unidecode from unidecode import unidecode
from unshackle.core.cacher import Cacher
from unshackle.core.config import config from unshackle.core.config import config
from unshackle.core.constants import LANGUAGE_MAX_DISTANCE from unshackle.core.constants import LANGUAGE_MAX_DISTANCE
@@ -248,6 +250,7 @@ def get_cached_ip_info(session: Optional[requests.Session] = None) -> Optional[d
This function uses a global cache to avoid repeated API calls when the IP This function uses a global cache to avoid repeated API calls when the IP
hasn't changed. Should only be used for local IP checks, not for proxy verification. hasn't changed. Should only be used for local IP checks, not for proxy verification.
Implements smart provider rotation to handle rate limiting (429 errors).
Args: Args:
session: Optional requests session (usually without proxy for local IP) session: Optional requests session (usually without proxy for local IP)
@@ -255,26 +258,54 @@ def get_cached_ip_info(session: Optional[requests.Session] = None) -> Optional[d
Returns: Returns:
Dict with IP info including 'country' key, or None if all providers fail Dict with IP info including 'country' key, or None if all providers fail
""" """
from unshackle.core.cacher import Cacher
log = logging.getLogger("get_cached_ip_info")
cache = Cacher("global").get("ip_info") cache = Cacher("global").get("ip_info")
if cache and not cache.expired: if cache and not cache.expired:
return cache.data return cache.data
providers = [ provider_state_cache = Cacher("global").get("ip_provider_state")
"https://ipinfo.io/json", provider_state = provider_state_cache.data if provider_state_cache and not provider_state_cache.expired else {}
"https://ipapi.co/json",
] providers = {
"ipinfo": "https://ipinfo.io/json",
"ipapi": "https://ipapi.co/json",
}
session = session or requests.Session() session = session or requests.Session()
provider_order = ["ipinfo", "ipapi"]
for provider_url in providers: current_time = time.time()
for provider_name in list(provider_order):
if provider_name in provider_state:
rate_limit_info = provider_state[provider_name]
if (current_time - rate_limit_info.get("rate_limited_at", 0)) < 300:
log.debug(f"Provider {provider_name} was rate limited recently, trying other provider first")
provider_order.remove(provider_name)
provider_order.append(provider_name)
break
for provider_name in provider_order:
provider_url = providers[provider_name]
try: try:
log.debug(f"Trying IP provider: {provider_name}")
response = session.get(provider_url, timeout=10) response = session.get(provider_url, timeout=10)
if response.status_code == 200:
data = response.json()
if response.status_code == 429:
log.warning(f"Provider {provider_name} returned 429 (rate limited), trying next provider")
if provider_name not in provider_state:
provider_state[provider_name] = {}
provider_state[provider_name]["rate_limited_at"] = current_time
provider_state[provider_name]["rate_limit_count"] = (
provider_state[provider_name].get("rate_limit_count", 0) + 1
)
provider_state_cache.set(provider_state, expiration=300)
continue
elif response.status_code == 200:
data = response.json()
normalized_data = {} normalized_data = {}
if "country" in data: if "country" in data:
@@ -288,12 +319,23 @@ def get_cached_ip_info(session: Optional[requests.Session] = None) -> Optional[d
} }
if normalized_data and "country" in normalized_data: if normalized_data and "country" in normalized_data:
log.debug(f"Successfully got IP info from provider: {provider_name}")
if provider_name in provider_state:
provider_state[provider_name].pop("rate_limited_at", None)
provider_state_cache.set(provider_state, expiration=300)
normalized_data["_provider"] = provider_name
cache.set(normalized_data, expiration=86400) cache.set(normalized_data, expiration=86400)
return normalized_data return normalized_data
else:
log.debug(f"Provider {provider_name} returned status {response.status_code}")
except Exception: except Exception as e:
log.debug(f"Provider {provider_name} failed with exception: {e}")
continue continue
log.warning("All IP geolocation providers failed")
return None return None