fix(review): address PR review findings for CORE-003

Critical fixes:
- Add math.isfinite() check to reject NaN/Infinity in _safe_quote_price
- Raise TypeError instead of silent 0.0 fallback in price_feed.py
- Use dict instead of Mapping for external data validation

Type improvements:
- Add PortfolioSnapshot TypedDict for type safety
- Add DisplayMode and EntryBasisMode Literal types
- Add explicit dict[str, Any] annotation in to_dict()
- Remove cast() in favor of type comment validation
This commit is contained in:
Bu5hm4nn
2026-03-30 00:39:02 +02:00
parent 1dce5bfd23
commit 98e3208b5e
4 changed files with 48 additions and 19 deletions

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from datetime import date, datetime from datetime import date, datetime
from typing import cast
from app.core.pricing.black_scholes import ( from app.core.pricing.black_scholes import (
DEFAULT_RISK_FREE_RATE, DEFAULT_RISK_FREE_RATE,
@@ -153,7 +152,7 @@ def option_row_greeks(
implied_volatility = float(iv_raw) if isinstance(iv_raw, (int, float)) else 0.0 implied_volatility = float(iv_raw) if isinstance(iv_raw, (int, float)) else 0.0
volatility = implied_volatility if implied_volatility > 0 else DEFAULT_VOLATILITY volatility = implied_volatility if implied_volatility > 0 else DEFAULT_VOLATILITY
option_type_typed: OptionType = cast(OptionType, option_type) # option_type is validated to be in {"call", "put"} above, so it's safe to pass
try: try:
pricing = black_scholes_price_and_greeks( pricing = black_scholes_price_and_greeks(
BlackScholesInputs( BlackScholesInputs(
@@ -162,7 +161,7 @@ def option_row_greeks(
time_to_expiry=days_to_expiry / 365.0, time_to_expiry=days_to_expiry / 365.0,
risk_free_rate=risk_free_rate, risk_free_rate=risk_free_rate,
volatility=volatility, volatility=volatility,
option_type=option_type_typed, option_type=option_type, # type: ignore[arg-type]
valuation_date=valuation, valuation_date=valuation,
) )
) )

View File

@@ -1,8 +1,9 @@
from __future__ import annotations from __future__ import annotations
import math
from datetime import date from datetime import date
from decimal import Decimal, InvalidOperation from decimal import Decimal, InvalidOperation
from typing import Any, Mapping from typing import Any, Mapping, TypedDict
from app.domain.backtesting_math import PricePerAsset from app.domain.backtesting_math import PricePerAsset
from app.domain.conversions import is_gld_mode from app.domain.conversions import is_gld_mode
@@ -16,6 +17,22 @@ _DECIMAL_ONE = Decimal("1")
_DECIMAL_HUNDRED = Decimal("100") _DECIMAL_HUNDRED = Decimal("100")
class PortfolioSnapshot(TypedDict):
"""Typed snapshot of portfolio state for metrics calculations."""
gold_value: float
loan_amount: float
ltv_ratio: float
net_equity: float
spot_price: float
gold_units: float
margin_call_ltv: float
margin_call_price: float
cash_buffer: float
hedge_budget: float
display_mode: str
def _decimal_ratio(numerator: Decimal, denominator: Decimal) -> Decimal: def _decimal_ratio(numerator: Decimal, denominator: Decimal) -> Decimal:
if denominator == 0: if denominator == 0:
return _DECIMAL_ZERO return _DECIMAL_ZERO
@@ -54,6 +71,12 @@ def _gold_weight(gold_ounces: float) -> Weight:
def _safe_quote_price(value: object) -> float: def _safe_quote_price(value: object) -> float:
"""Parse a price value, returning 0.0 for invalid/non-finite inputs.
Rejects NaN, Infinity, and non-positive values by returning 0.0.
This defensive helper is used for quote data that may come from
untrusted sources like APIs or user input.
"""
try: try:
if isinstance(value, (int, float)): if isinstance(value, (int, float)):
parsed = float(value) parsed = float(value)
@@ -63,7 +86,7 @@ def _safe_quote_price(value: object) -> float:
return 0.0 return 0.0
except (TypeError, ValueError): except (TypeError, ValueError):
return 0.0 return 0.0
if parsed <= 0: if not math.isfinite(parsed) or parsed <= 0:
return 0.0 return 0.0
return parsed return parsed
@@ -245,7 +268,7 @@ def portfolio_snapshot_from_config(
config: PortfolioConfig | None = None, config: PortfolioConfig | None = None,
*, *,
runtime_spot_price: float | None = None, runtime_spot_price: float | None = None,
) -> dict[str, float | str]: ) -> PortfolioSnapshot:
"""Build portfolio snapshot with display-mode-aware calculations. """Build portfolio snapshot with display-mode-aware calculations.
In GLD mode: In GLD mode:
@@ -369,7 +392,7 @@ def build_alert_context(
def strategy_metrics_from_snapshot( def strategy_metrics_from_snapshot(
strategy: dict[str, Any], scenario_pct: int, snapshot: dict[str, Any] strategy: dict[str, Any], scenario_pct: int, snapshot: PortfolioSnapshot
) -> dict[str, Any]: ) -> dict[str, Any]:
spot = decimal_from_float(float(snapshot["spot_price"])) spot = decimal_from_float(float(snapshot["spot_price"]))
gold_weight = _gold_weight(float(snapshot["gold_units"])) gold_weight = _gold_weight(float(snapshot["gold_units"]))

View File

@@ -8,10 +8,14 @@ from dataclasses import dataclass, field
from datetime import date from datetime import date
from decimal import Decimal from decimal import Decimal
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any, Literal
from app.models.position import Position, create_position from app.models.position import Position, create_position
# Type aliases for display mode and entry basis
DisplayMode = Literal["GLD", "XAU"]
EntryBasisMode = Literal["value_price", "weight"]
_DEFAULT_GOLD_VALUE = 215_000.0 _DEFAULT_GOLD_VALUE = 215_000.0
_DEFAULT_ENTRY_PRICE = 2_150.0 _DEFAULT_ENTRY_PRICE = 2_150.0
_LEGACY_DEFAULT_ENTRY_PRICE = 215.0 _LEGACY_DEFAULT_ENTRY_PRICE = 215.0
@@ -102,7 +106,7 @@ class PortfolioConfig:
gold_value: float | None = None gold_value: float | None = None
entry_price: float | None = _DEFAULT_ENTRY_PRICE entry_price: float | None = _DEFAULT_ENTRY_PRICE
gold_ounces: float | None = None gold_ounces: float | None = None
entry_basis_mode: str = "value_price" entry_basis_mode: EntryBasisMode = "value_price"
loan_amount: float = 145000.0 loan_amount: float = 145000.0
margin_threshold: float = 0.75 margin_threshold: float = 0.75
monthly_budget: float = 8000.0 monthly_budget: float = 8000.0
@@ -117,7 +121,7 @@ class PortfolioConfig:
underlying: str = "GLD" underlying: str = "GLD"
# Display mode: how to show positions (GLD shares vs physical gold) # Display mode: how to show positions (GLD shares vs physical gold)
display_mode: str = "XAU" # "GLD" for share view, "XAU" for physical gold view display_mode: DisplayMode = "XAU" # "GLD" for share view, "XAU" for physical gold view
# Alert settings # Alert settings
volatility_spike: float = 0.25 volatility_spike: float = 0.25
@@ -301,7 +305,7 @@ class PortfolioConfig:
assert self.gold_ounces is not None assert self.gold_ounces is not None
# Sync legacy fields from positions before serializing # Sync legacy fields from positions before serializing
self._sync_legacy_fields_from_positions() self._sync_legacy_fields_from_positions()
result = { result: dict[str, Any] = {
"gold_value": self.gold_value, "gold_value": self.gold_value,
"entry_price": self.entry_price, "entry_price": self.entry_price,
"gold_ounces": self.gold_ounces, "gold_ounces": self.gold_ounces,

View File

@@ -7,7 +7,6 @@ import logging
import math import math
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Mapping
import yfinance as yf import yfinance as yf
@@ -52,15 +51,15 @@ class PriceFeed:
self._cache = get_cache() self._cache = get_cache()
@staticmethod @staticmethod
def _required_payload_value(payload: Mapping[str, object], key: str, *, context: str) -> object: def _required_payload_value(payload: dict[str, object], key: str, *, context: str) -> object:
if key not in payload: if key not in payload:
raise TypeError(f"{context} is missing required field: {key}") raise TypeError(f"{context} is missing required field: {key}")
return payload[key] return payload[key]
@classmethod @classmethod
def _normalize_cached_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData: def _normalize_cached_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData:
if not isinstance(payload, Mapping): if not isinstance(payload, dict):
raise TypeError("cached price payload must be an object") raise TypeError("cached price payload must be a plain dict")
payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper() payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper()
normalized_symbol = expected_symbol.strip().upper() normalized_symbol = expected_symbol.strip().upper()
if payload_symbol != normalized_symbol: if payload_symbol != normalized_symbol:
@@ -69,7 +68,9 @@ class PriceFeed:
if not isinstance(timestamp, str) or not timestamp.strip(): if not isinstance(timestamp, str) or not timestamp.strip():
raise TypeError("cached timestamp must be a non-empty ISO string") raise TypeError("cached timestamp must be a non-empty ISO string")
price_val = cls._required_payload_value(payload, "price", context="cached price payload") price_val = cls._required_payload_value(payload, "price", context="cached price payload")
price = float(price_val) if isinstance(price_val, (int, float)) else 0.0 if not isinstance(price_val, (int, float)):
raise TypeError(f"cached price must be numeric, got {type(price_val).__name__}")
price = float(price_val)
return PriceData( return PriceData(
symbol=payload_symbol, symbol=payload_symbol,
price=price, price=price,
@@ -80,8 +81,8 @@ class PriceFeed:
@classmethod @classmethod
def _normalize_provider_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData: def _normalize_provider_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData:
if not isinstance(payload, Mapping): if not isinstance(payload, dict):
raise TypeError("provider price payload must be an object") raise TypeError("provider price payload must be a plain dict")
payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper() payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper()
normalized_symbol = expected_symbol.strip().upper() normalized_symbol = expected_symbol.strip().upper()
if payload_symbol != normalized_symbol: if payload_symbol != normalized_symbol:
@@ -90,7 +91,9 @@ class PriceFeed:
if not isinstance(timestamp, datetime): if not isinstance(timestamp, datetime):
raise TypeError("provider timestamp must be a datetime") raise TypeError("provider timestamp must be a datetime")
price_val = cls._required_payload_value(payload, "price", context="provider price payload") price_val = cls._required_payload_value(payload, "price", context="provider price payload")
price = float(price_val) if isinstance(price_val, (int, float)) else 0.0 if not isinstance(price_val, (int, float)):
raise TypeError(f"provider price must be numeric, got {type(price_val).__name__}")
price = float(price_val)
return PriceData( return PriceData(
symbol=payload_symbol, symbol=payload_symbol,
price=price, price=price,