"""Live price feed service for fetching real-time GLD and other asset prices.""" from __future__ import annotations import asyncio import logging import math from dataclasses import dataclass from datetime import datetime from typing import Mapping import yfinance as yf from app.services.cache import get_cache logger = logging.getLogger(__name__) @dataclass(frozen=True) class PriceData: """Price data for a symbol.""" symbol: str price: float currency: str timestamp: datetime source: str = "yfinance" def __post_init__(self) -> None: normalized_symbol = self.symbol.strip().upper() if not normalized_symbol: raise ValueError("symbol is required") if not math.isfinite(self.price) or self.price <= 0: raise ValueError("price must be a finite positive number") normalized_currency = self.currency.strip().upper() if not normalized_currency: raise ValueError("currency is required") if not isinstance(self.timestamp, datetime): raise TypeError("timestamp must be a datetime") object.__setattr__(self, "symbol", normalized_symbol) object.__setattr__(self, "currency", normalized_currency) object.__setattr__(self, "source", self.source.strip() or "yfinance") class PriceFeed: """Live price feed service using yfinance with Redis caching.""" CACHE_TTL_SECONDS = 60 DEFAULT_SYMBOLS = ["GLD", "TLT", "BTC-USD"] def __init__(self): self._cache = get_cache() @staticmethod def _required_payload_value(payload: Mapping[str, object], key: str, *, context: str) -> object: if key not in payload: raise TypeError(f"{context} is missing required field: {key}") return payload[key] @classmethod def _normalize_cached_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData: if not isinstance(payload, Mapping): raise TypeError("cached price payload must be an object") payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper() normalized_symbol = expected_symbol.strip().upper() if payload_symbol != normalized_symbol: raise ValueError(f"cached symbol mismatch: {payload_symbol} != {normalized_symbol}") timestamp = cls._required_payload_value(payload, "timestamp", context="cached price payload") if not isinstance(timestamp, str) or not timestamp.strip(): raise TypeError("cached timestamp must be a non-empty ISO string") return PriceData( symbol=payload_symbol, price=float(cls._required_payload_value(payload, "price", context="cached price payload")), currency=str(payload.get("currency", "USD")), timestamp=datetime.fromisoformat(timestamp), source=str(payload.get("source", "yfinance")), ) @classmethod def _normalize_provider_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData: if not isinstance(payload, Mapping): raise TypeError("provider price payload must be an object") payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper() normalized_symbol = expected_symbol.strip().upper() if payload_symbol != normalized_symbol: raise ValueError(f"provider symbol mismatch: {payload_symbol} != {normalized_symbol}") timestamp = cls._required_payload_value(payload, "timestamp", context="provider price payload") if not isinstance(timestamp, datetime): raise TypeError("provider timestamp must be a datetime") return PriceData( symbol=payload_symbol, price=float(cls._required_payload_value(payload, "price", context="provider price payload")), currency=str(payload.get("currency", "USD")), timestamp=timestamp, source=str(payload.get("source", "yfinance")), ) @staticmethod def _price_data_to_cache_payload(data: PriceData) -> dict[str, object]: return { "symbol": data.symbol, "price": data.price, "currency": data.currency, "timestamp": data.timestamp.isoformat(), "source": data.source, } async def get_price(self, symbol: str) -> PriceData | None: """Get current price for a symbol, with caching.""" normalized_symbol = symbol.strip().upper() cache_key = f"price:{normalized_symbol}" if self._cache.enabled: cached = await self._cache.get_json(cache_key) if cached is not None: try: return self._normalize_cached_price_payload(cached, expected_symbol=normalized_symbol) except (TypeError, ValueError) as exc: logger.warning("Discarding cached price payload for %s: %s", normalized_symbol, exc) try: payload = await self._fetch_yfinance(normalized_symbol) if payload is None: return None data = self._normalize_provider_price_payload(payload, expected_symbol=normalized_symbol) if self._cache.enabled: await self._cache.set_json( cache_key, self._price_data_to_cache_payload(data), ttl=self.CACHE_TTL_SECONDS ) return data except Exception as exc: logger.error("Failed to fetch price for %s: %s", normalized_symbol, exc) return None async def _fetch_yfinance(self, symbol: str) -> dict[str, object] | None: """Fetch price from yfinance (run in thread pool to avoid blocking).""" loop = asyncio.get_event_loop() return await loop.run_in_executor(None, self._sync_fetch_yfinance, symbol) def _sync_fetch_yfinance(self, symbol: str) -> dict[str, object] | None: """Synchronous yfinance fetch.""" ticker = yf.Ticker(symbol) hist = ticker.history(period="1d", interval="1m") if hist.empty: return None last_price = hist["Close"].iloc[-1] return { "symbol": symbol, "price": float(last_price), "currency": ticker.info.get("currency", "USD"), "timestamp": datetime.utcnow(), "source": "yfinance", } async def get_prices(self, symbols: list[str]) -> dict[str, PriceData | None]: """Get prices for multiple symbols concurrently.""" tasks = [self.get_price(symbol) for symbol in symbols] results = await asyncio.gather(*tasks) return {symbol: result for symbol, result in zip(symbols, results, strict=True)}