From 98e3208b5e66a9058d3de89b58c97ca638121fc5 Mon Sep 17 00:00:00 2001 From: Bu5hm4nn Date: Mon, 30 Mar 2026 00:39:02 +0200 Subject: [PATCH] fix(review): address PR review findings for CORE-003 Critical fixes: - Add math.isfinite() check to reject NaN/Infinity in _safe_quote_price - Raise TypeError instead of silent 0.0 fallback in price_feed.py - Use dict instead of Mapping for external data validation Type improvements: - Add PortfolioSnapshot TypedDict for type safety - Add DisplayMode and EntryBasisMode Literal types - Add explicit dict[str, Any] annotation in to_dict() - Remove cast() in favor of type comment validation --- app/core/calculations.py | 5 ++--- app/domain/portfolio_math.py | 31 +++++++++++++++++++++++++++---- app/models/portfolio.py | 12 ++++++++---- app/services/price_feed.py | 19 +++++++++++-------- 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/app/core/calculations.py b/app/core/calculations.py index da1286b..3e2fe4b 100644 --- a/app/core/calculations.py +++ b/app/core/calculations.py @@ -2,7 +2,6 @@ from __future__ import annotations from collections.abc import Iterable, Mapping from datetime import date, datetime -from typing import cast from app.core.pricing.black_scholes import ( DEFAULT_RISK_FREE_RATE, @@ -153,7 +152,7 @@ def option_row_greeks( implied_volatility = float(iv_raw) if isinstance(iv_raw, (int, float)) else 0.0 volatility = implied_volatility if implied_volatility > 0 else DEFAULT_VOLATILITY - option_type_typed: OptionType = cast(OptionType, option_type) + # option_type is validated to be in {"call", "put"} above, so it's safe to pass try: pricing = black_scholes_price_and_greeks( BlackScholesInputs( @@ -162,7 +161,7 @@ def option_row_greeks( time_to_expiry=days_to_expiry / 365.0, risk_free_rate=risk_free_rate, volatility=volatility, - option_type=option_type_typed, + option_type=option_type, # type: ignore[arg-type] valuation_date=valuation, ) ) diff --git a/app/domain/portfolio_math.py b/app/domain/portfolio_math.py index 6779624..6d3afe5 100644 --- a/app/domain/portfolio_math.py +++ b/app/domain/portfolio_math.py @@ -1,8 +1,9 @@ from __future__ import annotations +import math from datetime import date from decimal import Decimal, InvalidOperation -from typing import Any, Mapping +from typing import Any, Mapping, TypedDict from app.domain.backtesting_math import PricePerAsset from app.domain.conversions import is_gld_mode @@ -16,6 +17,22 @@ _DECIMAL_ONE = Decimal("1") _DECIMAL_HUNDRED = Decimal("100") +class PortfolioSnapshot(TypedDict): + """Typed snapshot of portfolio state for metrics calculations.""" + + gold_value: float + loan_amount: float + ltv_ratio: float + net_equity: float + spot_price: float + gold_units: float + margin_call_ltv: float + margin_call_price: float + cash_buffer: float + hedge_budget: float + display_mode: str + + def _decimal_ratio(numerator: Decimal, denominator: Decimal) -> Decimal: if denominator == 0: return _DECIMAL_ZERO @@ -54,6 +71,12 @@ def _gold_weight(gold_ounces: float) -> Weight: def _safe_quote_price(value: object) -> float: + """Parse a price value, returning 0.0 for invalid/non-finite inputs. + + Rejects NaN, Infinity, and non-positive values by returning 0.0. + This defensive helper is used for quote data that may come from + untrusted sources like APIs or user input. + """ try: if isinstance(value, (int, float)): parsed = float(value) @@ -63,7 +86,7 @@ def _safe_quote_price(value: object) -> float: return 0.0 except (TypeError, ValueError): return 0.0 - if parsed <= 0: + if not math.isfinite(parsed) or parsed <= 0: return 0.0 return parsed @@ -245,7 +268,7 @@ def portfolio_snapshot_from_config( config: PortfolioConfig | None = None, *, runtime_spot_price: float | None = None, -) -> dict[str, float | str]: +) -> PortfolioSnapshot: """Build portfolio snapshot with display-mode-aware calculations. In GLD mode: @@ -369,7 +392,7 @@ def build_alert_context( def strategy_metrics_from_snapshot( - strategy: dict[str, Any], scenario_pct: int, snapshot: dict[str, Any] + strategy: dict[str, Any], scenario_pct: int, snapshot: PortfolioSnapshot ) -> dict[str, Any]: spot = decimal_from_float(float(snapshot["spot_price"])) gold_weight = _gold_weight(float(snapshot["gold_units"])) diff --git a/app/models/portfolio.py b/app/models/portfolio.py index 795063f..69dc460 100644 --- a/app/models/portfolio.py +++ b/app/models/portfolio.py @@ -8,10 +8,14 @@ from dataclasses import dataclass, field from datetime import date from decimal import Decimal from pathlib import Path -from typing import Any +from typing import Any, Literal from app.models.position import Position, create_position +# Type aliases for display mode and entry basis +DisplayMode = Literal["GLD", "XAU"] +EntryBasisMode = Literal["value_price", "weight"] + _DEFAULT_GOLD_VALUE = 215_000.0 _DEFAULT_ENTRY_PRICE = 2_150.0 _LEGACY_DEFAULT_ENTRY_PRICE = 215.0 @@ -102,7 +106,7 @@ class PortfolioConfig: gold_value: float | None = None entry_price: float | None = _DEFAULT_ENTRY_PRICE gold_ounces: float | None = None - entry_basis_mode: str = "value_price" + entry_basis_mode: EntryBasisMode = "value_price" loan_amount: float = 145000.0 margin_threshold: float = 0.75 monthly_budget: float = 8000.0 @@ -117,7 +121,7 @@ class PortfolioConfig: underlying: str = "GLD" # Display mode: how to show positions (GLD shares vs physical gold) - display_mode: str = "XAU" # "GLD" for share view, "XAU" for physical gold view + display_mode: DisplayMode = "XAU" # "GLD" for share view, "XAU" for physical gold view # Alert settings volatility_spike: float = 0.25 @@ -301,7 +305,7 @@ class PortfolioConfig: assert self.gold_ounces is not None # Sync legacy fields from positions before serializing self._sync_legacy_fields_from_positions() - result = { + result: dict[str, Any] = { "gold_value": self.gold_value, "entry_price": self.entry_price, "gold_ounces": self.gold_ounces, diff --git a/app/services/price_feed.py b/app/services/price_feed.py index 9ae70b9..17b50da 100644 --- a/app/services/price_feed.py +++ b/app/services/price_feed.py @@ -7,7 +7,6 @@ import logging import math from dataclasses import dataclass from datetime import datetime -from typing import Mapping import yfinance as yf @@ -52,15 +51,15 @@ class PriceFeed: self._cache = get_cache() @staticmethod - def _required_payload_value(payload: Mapping[str, object], key: str, *, context: str) -> object: + def _required_payload_value(payload: dict[str, object], key: str, *, context: str) -> object: if key not in payload: raise TypeError(f"{context} is missing required field: {key}") return payload[key] @classmethod def _normalize_cached_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData: - if not isinstance(payload, Mapping): - raise TypeError("cached price payload must be an object") + if not isinstance(payload, dict): + raise TypeError("cached price payload must be a plain dict") payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper() normalized_symbol = expected_symbol.strip().upper() if payload_symbol != normalized_symbol: @@ -69,7 +68,9 @@ class PriceFeed: if not isinstance(timestamp, str) or not timestamp.strip(): raise TypeError("cached timestamp must be a non-empty ISO string") price_val = cls._required_payload_value(payload, "price", context="cached price payload") - price = float(price_val) if isinstance(price_val, (int, float)) else 0.0 + if not isinstance(price_val, (int, float)): + raise TypeError(f"cached price must be numeric, got {type(price_val).__name__}") + price = float(price_val) return PriceData( symbol=payload_symbol, price=price, @@ -80,8 +81,8 @@ class PriceFeed: @classmethod def _normalize_provider_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData: - if not isinstance(payload, Mapping): - raise TypeError("provider price payload must be an object") + if not isinstance(payload, dict): + raise TypeError("provider price payload must be a plain dict") payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper() normalized_symbol = expected_symbol.strip().upper() if payload_symbol != normalized_symbol: @@ -90,7 +91,9 @@ class PriceFeed: if not isinstance(timestamp, datetime): raise TypeError("provider timestamp must be a datetime") price_val = cls._required_payload_value(payload, "price", context="provider price payload") - price = float(price_val) if isinstance(price_val, (int, float)) else 0.0 + if not isinstance(price_val, (int, float)): + raise TypeError(f"provider price must be numeric, got {type(price_val).__name__}") + price = float(price_val) return PriceData( symbol=payload_symbol, price=price,