- Fix return type annotation for get_default_premium_for_product - Add type narrowing for Weight|Money union using _as_money helper - Add isinstance checks before float() calls for object types - Add type guard for Decimal.exponent comparison - Use _unit_typed and _currency_typed properties for type narrowing - Cast option_type to OptionType Literal after validation - Fix provider type hierarchy in backtesting services - Add types-requests to dev dependencies - Remove '|| true' from CI type-check job All 36 mypy errors resolved across 15 files.
165 lines
6.8 KiB
Python
165 lines
6.8 KiB
Python
"""Live price feed service for fetching real-time GLD and other asset prices."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import math
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from typing import Mapping
|
|
|
|
import yfinance as yf
|
|
|
|
from app.services.cache import get_cache
|
|
|
|
# Module-level logger named after this module (used for cache warnings and fetch errors).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
class PriceData:
    """Immutable, validated snapshot of a single quoted price.

    On construction the fields are normalised in place:
    ``symbol`` and ``currency`` are stripped and upper-cased, and ``source``
    is stripped, falling back to ``"yfinance"`` when it ends up blank.

    Raises:
        ValueError: if ``symbol`` or ``currency`` is blank after stripping, or
            ``price`` is not a finite number greater than zero.
        TypeError: if ``timestamp`` is not a ``datetime`` instance.
    """

    symbol: str
    price: float
    currency: str
    timestamp: datetime
    source: str = "yfinance"

    def __post_init__(self) -> None:
        clean_symbol = self.symbol.strip().upper()
        if clean_symbol == "":
            raise ValueError("symbol is required")

        # NaN/inf are rejected by isfinite before the sign test is reached.
        price_is_valid = math.isfinite(self.price) and self.price > 0
        if not price_is_valid:
            raise ValueError("price must be a finite positive number")

        clean_currency = self.currency.strip().upper()
        if clean_currency == "":
            raise ValueError("currency is required")

        if not isinstance(self.timestamp, datetime):
            raise TypeError("timestamp must be a datetime")

        # frozen=True forbids normal attribute assignment, so write the
        # normalised values through object.__setattr__.
        normalised_fields = (
            ("symbol", clean_symbol),
            ("currency", clean_currency),
            ("source", self.source.strip() or "yfinance"),
        )
        for field_name, field_value in normalised_fields:
            object.__setattr__(self, field_name, field_value)
|
|
|
|
|
|
class PriceFeed:
    """Live price feed service using yfinance with Redis caching.

    ``get_price`` consults the cache first (when ``self._cache.enabled``). On a
    miss or an unusable cached entry it fetches from yfinance in a worker
    thread and writes the result back with a ``CACHE_TTL_SECONDS`` expiry.
    Cache problems are logged and never prevent returning a fetched price.
    """

    CACHE_TTL_SECONDS = 60
    DEFAULT_SYMBOLS = ["GLD", "TLT", "BTC-USD"]

    def __init__(self) -> None:
        self._cache = get_cache()

    # ------------------------------------------------------------------
    # Payload validation helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _required_payload_value(payload: Mapping[str, object], key: str, *, context: str) -> object:
        """Return ``payload[key]``.

        Raises:
            TypeError: if *key* is absent; *context* names the payload in the message.
        """
        if key not in payload:
            raise TypeError(f"{context} is missing required field: {key}")
        return payload[key]

    @staticmethod
    def _optional_str(payload: Mapping[str, object], key: str, default: str) -> str:
        """Return ``str(payload[key])``, or *default* when the key is missing or ``None``.

        ``None`` is treated as missing so it is not stringified into ``"None"``
        (which would otherwise pass validation as currency ``"NONE"``).
        """
        value = payload.get(key)
        return default if value is None else str(value)

    @staticmethod
    def _checked_symbol(payload: Mapping[str, object], expected_symbol: str, *, context: str) -> str:
        """Return the payload's normalised symbol after checking it matches *expected_symbol*.

        A payload without a ``symbol`` key is assumed to be for *expected_symbol*.

        Raises:
            ValueError: ``"<context> symbol mismatch: ..."`` when the symbols differ.
        """
        payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper()
        normalized_symbol = expected_symbol.strip().upper()
        if payload_symbol != normalized_symbol:
            raise ValueError(f"{context} symbol mismatch: {payload_symbol} != {normalized_symbol}")
        return payload_symbol

    @staticmethod
    def _coerce_price(value: object, *, context: str) -> float:
        """Convert a payload price to ``float``.

        Only ``int``/``float`` are accepted. ``bool`` is rejected even though it
        subclasses ``int``. Any other type raises here instead of being replaced
        by a placeholder, so the error names the actual problem.

        Raises:
            TypeError: if *value* is not a real number.
        """
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise TypeError(f"{context} price must be a number, got {type(value).__name__}")
        return float(value)

    @classmethod
    def _normalize_cached_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData:
        """Validate a payload read back from the cache and rebuild a ``PriceData``.

        The cached form stores ``timestamp`` as an ISO-8601 string
        (see ``_price_data_to_cache_payload``).

        Raises:
            TypeError: if the payload is not a mapping or a field is missing/mistyped.
            ValueError: on a symbol mismatch, an unparsable timestamp, or values
                rejected by ``PriceData``.
        """
        if not isinstance(payload, Mapping):
            raise TypeError("cached price payload must be an object")
        symbol = cls._checked_symbol(payload, expected_symbol, context="cached")
        timestamp = cls._required_payload_value(payload, "timestamp", context="cached price payload")
        if not isinstance(timestamp, str) or not timestamp.strip():
            raise TypeError("cached timestamp must be a non-empty ISO string")
        price = cls._coerce_price(
            cls._required_payload_value(payload, "price", context="cached price payload"),
            context="cached",
        )
        return PriceData(
            symbol=symbol,
            price=price,
            currency=cls._optional_str(payload, "currency", "USD"),
            timestamp=datetime.fromisoformat(timestamp),
            source=cls._optional_str(payload, "source", "yfinance"),
        )

    @classmethod
    def _normalize_provider_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData:
        """Validate a payload produced by the provider fetch and build a ``PriceData``.

        Unlike the cached form, ``timestamp`` must already be a ``datetime``.

        Raises:
            TypeError: if the payload is not a mapping or a field is missing/mistyped.
            ValueError: on a symbol mismatch or values rejected by ``PriceData``.
        """
        if not isinstance(payload, Mapping):
            raise TypeError("provider price payload must be an object")
        symbol = cls._checked_symbol(payload, expected_symbol, context="provider")
        timestamp = cls._required_payload_value(payload, "timestamp", context="provider price payload")
        if not isinstance(timestamp, datetime):
            raise TypeError("provider timestamp must be a datetime")
        price = cls._coerce_price(
            cls._required_payload_value(payload, "price", context="provider price payload"),
            context="provider",
        )
        return PriceData(
            symbol=symbol,
            price=price,
            currency=cls._optional_str(payload, "currency", "USD"),
            timestamp=timestamp,
            source=cls._optional_str(payload, "source", "yfinance"),
        )

    @staticmethod
    def _price_data_to_cache_payload(data: PriceData) -> dict[str, object]:
        """Serialise *data* to the JSON-friendly dict stored in the cache."""
        return {
            "symbol": data.symbol,
            "price": data.price,
            "currency": data.currency,
            "timestamp": data.timestamp.isoformat(),
            "source": data.source,
        }

    # ------------------------------------------------------------------
    # Cache access (best effort)
    # ------------------------------------------------------------------

    async def _read_cached_price(self, cache_key: str, normalized_symbol: str) -> PriceData | None:
        """Return a valid cached price for *cache_key*, or ``None`` on a miss.

        Read failures and malformed entries are logged and treated as misses so
        the caller falls through to the provider.
        """
        try:
            cached = await self._cache.get_json(cache_key)
        except Exception as exc:
            logger.warning("Price cache read failed for %s: %s", normalized_symbol, exc)
            return None
        if cached is None:
            return None
        try:
            return self._normalize_cached_price_payload(cached, expected_symbol=normalized_symbol)
        except (TypeError, ValueError) as exc:
            logger.warning("Discarding cached price payload for %s: %s", normalized_symbol, exc)
            return None

    async def _write_cached_price(self, cache_key: str, data: PriceData) -> None:
        """Store *data* under *cache_key*; failures are logged, never raised."""
        try:
            await self._cache.set_json(
                cache_key, self._price_data_to_cache_payload(data), ttl=self.CACHE_TTL_SECONDS
            )
        except Exception as exc:
            logger.warning("Failed to cache price for %s: %s", data.symbol, exc)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def get_price(self, symbol: str) -> PriceData | None:
        """Get current price for a symbol, with caching.

        Returns ``None`` when the provider has no data or the fetch fails.
        """
        normalized_symbol = symbol.strip().upper()
        cache_key = f"price:{normalized_symbol}"

        if self._cache.enabled:
            cached = await self._read_cached_price(cache_key, normalized_symbol)
            if cached is not None:
                return cached

        try:
            payload = await self._fetch_yfinance(normalized_symbol)
            if payload is None:
                return None
            data = self._normalize_provider_price_payload(payload, expected_symbol=normalized_symbol)
        except Exception as exc:
            # Boundary: any provider/network/validation failure means "no price".
            logger.error("Failed to fetch price for %s: %s", normalized_symbol, exc)
            return None

        # Kept outside the try above so a cache outage cannot discard a good price.
        if self._cache.enabled:
            await self._write_cached_price(cache_key, data)
        return data

    async def _fetch_yfinance(self, symbol: str) -> dict[str, object] | None:
        """Fetch price from yfinance (run in thread pool to avoid blocking)."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._sync_fetch_yfinance, symbol)

    def _sync_fetch_yfinance(self, symbol: str) -> dict[str, object] | None:
        """Blocking yfinance fetch of the latest 1-minute close for *symbol*.

        Returns ``None`` when no usable bar is available for the day.
        """
        from datetime import timezone

        ticker = yf.Ticker(symbol)
        hist = ticker.history(period="1d", interval="1m")
        if hist.empty:
            return None

        # NOTE(review): yfinance may report NaN closes for some bars; skip them
        # so a single bad trailing bar doesn't fail the whole fetch.
        closes = hist["Close"].dropna()
        if closes.empty:
            return None

        return {
            "symbol": symbol,
            "price": float(closes.iloc[-1]),
            # A None/empty currency falls back to USD instead of becoming "NONE".
            "currency": ticker.info.get("currency") or "USD",
            # Naive UTC, identical to the deprecated datetime.utcnow().
            "timestamp": datetime.now(timezone.utc).replace(tzinfo=None),
            "source": "yfinance",
        }

    async def get_prices(self, symbols: list[str]) -> dict[str, PriceData | None]:
        """Get prices for multiple symbols concurrently, keyed by the symbols as given."""
        results = await asyncio.gather(*(self.get_price(symbol) for symbol in symbols))
        return dict(zip(symbols, results, strict=True))
|