fix(review): address PR review findings for CORE-003

Critical fixes:
- Add math.isfinite() check to reject NaN/Infinity in _safe_quote_price
- Raise TypeError instead of silent 0.0 fallback in price_feed.py
- Use dict instead of Mapping for external data validation

Type improvements:
- Add PortfolioSnapshot TypedDict for type safety
- Add DisplayMode and EntryBasisMode Literal types
- Add explicit dict[str, Any] annotation in to_dict()
- Remove cast() in favor of type comment validation
This commit is contained in:
Bu5hm4nn
2026-03-30 00:39:02 +02:00
parent 1dce5bfd23
commit 98e3208b5e
4 changed files with 48 additions and 19 deletions

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
from collections.abc import Iterable, Mapping
from datetime import date, datetime
from typing import cast
from app.core.pricing.black_scholes import (
DEFAULT_RISK_FREE_RATE,
@@ -153,7 +152,7 @@ def option_row_greeks(
implied_volatility = float(iv_raw) if isinstance(iv_raw, (int, float)) else 0.0
volatility = implied_volatility if implied_volatility > 0 else DEFAULT_VOLATILITY
option_type_typed: OptionType = cast(OptionType, option_type)
# option_type is validated to be in {"call", "put"} above, so it's safe to pass
try:
pricing = black_scholes_price_and_greeks(
BlackScholesInputs(
@@ -162,7 +161,7 @@ def option_row_greeks(
time_to_expiry=days_to_expiry / 365.0,
risk_free_rate=risk_free_rate,
volatility=volatility,
option_type=option_type_typed,
option_type=option_type, # type: ignore[arg-type]
valuation_date=valuation,
)
)

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
import math
from datetime import date
from decimal import Decimal, InvalidOperation
from typing import Any, Mapping
from typing import Any, Mapping, TypedDict
from app.domain.backtesting_math import PricePerAsset
from app.domain.conversions import is_gld_mode
@@ -16,6 +17,22 @@ _DECIMAL_ONE = Decimal("1")
_DECIMAL_HUNDRED = Decimal("100")
class PortfolioSnapshot(TypedDict):
"""Typed snapshot of portfolio state for metrics calculations."""
gold_value: float
loan_amount: float
ltv_ratio: float
net_equity: float
spot_price: float
gold_units: float
margin_call_ltv: float
margin_call_price: float
cash_buffer: float
hedge_budget: float
display_mode: str
def _decimal_ratio(numerator: Decimal, denominator: Decimal) -> Decimal:
if denominator == 0:
return _DECIMAL_ZERO
@@ -54,6 +71,12 @@ def _gold_weight(gold_ounces: float) -> Weight:
def _safe_quote_price(value: object) -> float:
"""Parse a price value, returning 0.0 for invalid/non-finite inputs.
Rejects NaN, Infinity, and non-positive values by returning 0.0.
This defensive helper is used for quote data that may come from
untrusted sources like APIs or user input.
"""
try:
if isinstance(value, (int, float)):
parsed = float(value)
@@ -63,7 +86,7 @@ def _safe_quote_price(value: object) -> float:
return 0.0
except (TypeError, ValueError):
return 0.0
if parsed <= 0:
if not math.isfinite(parsed) or parsed <= 0:
return 0.0
return parsed
@@ -245,7 +268,7 @@ def portfolio_snapshot_from_config(
config: PortfolioConfig | None = None,
*,
runtime_spot_price: float | None = None,
) -> dict[str, float | str]:
) -> PortfolioSnapshot:
"""Build portfolio snapshot with display-mode-aware calculations.
In GLD mode:
@@ -369,7 +392,7 @@ def build_alert_context(
def strategy_metrics_from_snapshot(
strategy: dict[str, Any], scenario_pct: int, snapshot: dict[str, Any]
strategy: dict[str, Any], scenario_pct: int, snapshot: PortfolioSnapshot
) -> dict[str, Any]:
spot = decimal_from_float(float(snapshot["spot_price"]))
gold_weight = _gold_weight(float(snapshot["gold_units"]))

View File

@@ -8,10 +8,14 @@ from dataclasses import dataclass, field
from datetime import date
from decimal import Decimal
from pathlib import Path
from typing import Any
from typing import Any, Literal
from app.models.position import Position, create_position
# Type aliases for display mode and entry basis
DisplayMode = Literal["GLD", "XAU"]
EntryBasisMode = Literal["value_price", "weight"]
_DEFAULT_GOLD_VALUE = 215_000.0
_DEFAULT_ENTRY_PRICE = 2_150.0
_LEGACY_DEFAULT_ENTRY_PRICE = 215.0
@@ -102,7 +106,7 @@ class PortfolioConfig:
gold_value: float | None = None
entry_price: float | None = _DEFAULT_ENTRY_PRICE
gold_ounces: float | None = None
entry_basis_mode: str = "value_price"
entry_basis_mode: EntryBasisMode = "value_price"
loan_amount: float = 145000.0
margin_threshold: float = 0.75
monthly_budget: float = 8000.0
@@ -117,7 +121,7 @@ class PortfolioConfig:
underlying: str = "GLD"
# Display mode: how to show positions (GLD shares vs physical gold)
display_mode: str = "XAU" # "GLD" for share view, "XAU" for physical gold view
display_mode: DisplayMode = "XAU" # "GLD" for share view, "XAU" for physical gold view
# Alert settings
volatility_spike: float = 0.25
@@ -301,7 +305,7 @@ class PortfolioConfig:
assert self.gold_ounces is not None
# Sync legacy fields from positions before serializing
self._sync_legacy_fields_from_positions()
result = {
result: dict[str, Any] = {
"gold_value": self.gold_value,
"entry_price": self.entry_price,
"gold_ounces": self.gold_ounces,

View File

@@ -7,7 +7,6 @@ import logging
import math
from dataclasses import dataclass
from datetime import datetime
from typing import Mapping
import yfinance as yf
@@ -52,15 +51,15 @@ class PriceFeed:
self._cache = get_cache()
@staticmethod
def _required_payload_value(payload: Mapping[str, object], key: str, *, context: str) -> object:
def _required_payload_value(payload: dict[str, object], key: str, *, context: str) -> object:
if key not in payload:
raise TypeError(f"{context} is missing required field: {key}")
return payload[key]
@classmethod
def _normalize_cached_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData:
if not isinstance(payload, Mapping):
raise TypeError("cached price payload must be an object")
if not isinstance(payload, dict):
raise TypeError("cached price payload must be a plain dict")
payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper()
normalized_symbol = expected_symbol.strip().upper()
if payload_symbol != normalized_symbol:
@@ -69,7 +68,9 @@ class PriceFeed:
if not isinstance(timestamp, str) or not timestamp.strip():
raise TypeError("cached timestamp must be a non-empty ISO string")
price_val = cls._required_payload_value(payload, "price", context="cached price payload")
price = float(price_val) if isinstance(price_val, (int, float)) else 0.0
if not isinstance(price_val, (int, float)):
raise TypeError(f"cached price must be numeric, got {type(price_val).__name__}")
price = float(price_val)
return PriceData(
symbol=payload_symbol,
price=price,
@@ -80,8 +81,8 @@ class PriceFeed:
@classmethod
def _normalize_provider_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData:
if not isinstance(payload, Mapping):
raise TypeError("provider price payload must be an object")
if not isinstance(payload, dict):
raise TypeError("provider price payload must be a plain dict")
payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper()
normalized_symbol = expected_symbol.strip().upper()
if payload_symbol != normalized_symbol:
@@ -90,7 +91,9 @@ class PriceFeed:
if not isinstance(timestamp, datetime):
raise TypeError("provider timestamp must be a datetime")
price_val = cls._required_payload_value(payload, "price", context="provider price payload")
price = float(price_val) if isinstance(price_val, (int, float)) else 0.0
if not isinstance(price_val, (int, float)):
raise TypeError(f"provider price must be numeric, got {type(price_val).__name__}")
price = float(price_val)
return PriceData(
symbol=payload_symbol,
price=price,