fix(review): address PR review findings for CORE-003
Critical fixes: - Add math.isfinite() check to reject NaN/Infinity in _safe_quote_price - Raise TypeError instead of silent 0.0 fallback in price_feed.py - Use dict instead of Mapping for external data validation Type improvements: - Add PortfolioSnapshot TypedDict for type safety - Add DisplayMode and EntryBasisMode Literal types - Add explicit dict[str, Any] annotation in to_dict() - Remove cast() in favor of type comment validation
This commit is contained in:
@@ -2,7 +2,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Iterable, Mapping
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
from typing import cast
|
|
||||||
|
|
||||||
from app.core.pricing.black_scholes import (
|
from app.core.pricing.black_scholes import (
|
||||||
DEFAULT_RISK_FREE_RATE,
|
DEFAULT_RISK_FREE_RATE,
|
||||||
@@ -153,7 +152,7 @@ def option_row_greeks(
|
|||||||
implied_volatility = float(iv_raw) if isinstance(iv_raw, (int, float)) else 0.0
|
implied_volatility = float(iv_raw) if isinstance(iv_raw, (int, float)) else 0.0
|
||||||
volatility = implied_volatility if implied_volatility > 0 else DEFAULT_VOLATILITY
|
volatility = implied_volatility if implied_volatility > 0 else DEFAULT_VOLATILITY
|
||||||
|
|
||||||
option_type_typed: OptionType = cast(OptionType, option_type)
|
# option_type is validated to be in {"call", "put"} above, so it's safe to pass
|
||||||
try:
|
try:
|
||||||
pricing = black_scholes_price_and_greeks(
|
pricing = black_scholes_price_and_greeks(
|
||||||
BlackScholesInputs(
|
BlackScholesInputs(
|
||||||
@@ -162,7 +161,7 @@ def option_row_greeks(
|
|||||||
time_to_expiry=days_to_expiry / 365.0,
|
time_to_expiry=days_to_expiry / 365.0,
|
||||||
risk_free_rate=risk_free_rate,
|
risk_free_rate=risk_free_rate,
|
||||||
volatility=volatility,
|
volatility=volatility,
|
||||||
option_type=option_type_typed,
|
option_type=option_type, # type: ignore[arg-type]
|
||||||
valuation_date=valuation,
|
valuation_date=valuation,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import math
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from decimal import Decimal, InvalidOperation
|
from decimal import Decimal, InvalidOperation
|
||||||
from typing import Any, Mapping
|
from typing import Any, Mapping, TypedDict
|
||||||
|
|
||||||
from app.domain.backtesting_math import PricePerAsset
|
from app.domain.backtesting_math import PricePerAsset
|
||||||
from app.domain.conversions import is_gld_mode
|
from app.domain.conversions import is_gld_mode
|
||||||
@@ -16,6 +17,22 @@ _DECIMAL_ONE = Decimal("1")
|
|||||||
_DECIMAL_HUNDRED = Decimal("100")
|
_DECIMAL_HUNDRED = Decimal("100")
|
||||||
|
|
||||||
|
|
||||||
|
class PortfolioSnapshot(TypedDict):
|
||||||
|
"""Typed snapshot of portfolio state for metrics calculations."""
|
||||||
|
|
||||||
|
gold_value: float
|
||||||
|
loan_amount: float
|
||||||
|
ltv_ratio: float
|
||||||
|
net_equity: float
|
||||||
|
spot_price: float
|
||||||
|
gold_units: float
|
||||||
|
margin_call_ltv: float
|
||||||
|
margin_call_price: float
|
||||||
|
cash_buffer: float
|
||||||
|
hedge_budget: float
|
||||||
|
display_mode: str
|
||||||
|
|
||||||
|
|
||||||
def _decimal_ratio(numerator: Decimal, denominator: Decimal) -> Decimal:
|
def _decimal_ratio(numerator: Decimal, denominator: Decimal) -> Decimal:
|
||||||
if denominator == 0:
|
if denominator == 0:
|
||||||
return _DECIMAL_ZERO
|
return _DECIMAL_ZERO
|
||||||
@@ -54,6 +71,12 @@ def _gold_weight(gold_ounces: float) -> Weight:
|
|||||||
|
|
||||||
|
|
||||||
def _safe_quote_price(value: object) -> float:
|
def _safe_quote_price(value: object) -> float:
|
||||||
|
"""Parse a price value, returning 0.0 for invalid/non-finite inputs.
|
||||||
|
|
||||||
|
Rejects NaN, Infinity, and non-positive values by returning 0.0.
|
||||||
|
This defensive helper is used for quote data that may come from
|
||||||
|
untrusted sources like APIs or user input.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
if isinstance(value, (int, float)):
|
if isinstance(value, (int, float)):
|
||||||
parsed = float(value)
|
parsed = float(value)
|
||||||
@@ -63,7 +86,7 @@ def _safe_quote_price(value: object) -> float:
|
|||||||
return 0.0
|
return 0.0
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return 0.0
|
return 0.0
|
||||||
if parsed <= 0:
|
if not math.isfinite(parsed) or parsed <= 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
@@ -245,7 +268,7 @@ def portfolio_snapshot_from_config(
|
|||||||
config: PortfolioConfig | None = None,
|
config: PortfolioConfig | None = None,
|
||||||
*,
|
*,
|
||||||
runtime_spot_price: float | None = None,
|
runtime_spot_price: float | None = None,
|
||||||
) -> dict[str, float | str]:
|
) -> PortfolioSnapshot:
|
||||||
"""Build portfolio snapshot with display-mode-aware calculations.
|
"""Build portfolio snapshot with display-mode-aware calculations.
|
||||||
|
|
||||||
In GLD mode:
|
In GLD mode:
|
||||||
@@ -369,7 +392,7 @@ def build_alert_context(
|
|||||||
|
|
||||||
|
|
||||||
def strategy_metrics_from_snapshot(
|
def strategy_metrics_from_snapshot(
|
||||||
strategy: dict[str, Any], scenario_pct: int, snapshot: dict[str, Any]
|
strategy: dict[str, Any], scenario_pct: int, snapshot: PortfolioSnapshot
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
spot = decimal_from_float(float(snapshot["spot_price"]))
|
spot = decimal_from_float(float(snapshot["spot_price"]))
|
||||||
gold_weight = _gold_weight(float(snapshot["gold_units"]))
|
gold_weight = _gold_weight(float(snapshot["gold_units"]))
|
||||||
|
|||||||
@@ -8,10 +8,14 @@ from dataclasses import dataclass, field
|
|||||||
from datetime import date
|
from datetime import date
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any, Literal
|
||||||
|
|
||||||
from app.models.position import Position, create_position
|
from app.models.position import Position, create_position
|
||||||
|
|
||||||
|
# Type aliases for display mode and entry basis
|
||||||
|
DisplayMode = Literal["GLD", "XAU"]
|
||||||
|
EntryBasisMode = Literal["value_price", "weight"]
|
||||||
|
|
||||||
_DEFAULT_GOLD_VALUE = 215_000.0
|
_DEFAULT_GOLD_VALUE = 215_000.0
|
||||||
_DEFAULT_ENTRY_PRICE = 2_150.0
|
_DEFAULT_ENTRY_PRICE = 2_150.0
|
||||||
_LEGACY_DEFAULT_ENTRY_PRICE = 215.0
|
_LEGACY_DEFAULT_ENTRY_PRICE = 215.0
|
||||||
@@ -102,7 +106,7 @@ class PortfolioConfig:
|
|||||||
gold_value: float | None = None
|
gold_value: float | None = None
|
||||||
entry_price: float | None = _DEFAULT_ENTRY_PRICE
|
entry_price: float | None = _DEFAULT_ENTRY_PRICE
|
||||||
gold_ounces: float | None = None
|
gold_ounces: float | None = None
|
||||||
entry_basis_mode: str = "value_price"
|
entry_basis_mode: EntryBasisMode = "value_price"
|
||||||
loan_amount: float = 145000.0
|
loan_amount: float = 145000.0
|
||||||
margin_threshold: float = 0.75
|
margin_threshold: float = 0.75
|
||||||
monthly_budget: float = 8000.0
|
monthly_budget: float = 8000.0
|
||||||
@@ -117,7 +121,7 @@ class PortfolioConfig:
|
|||||||
underlying: str = "GLD"
|
underlying: str = "GLD"
|
||||||
|
|
||||||
# Display mode: how to show positions (GLD shares vs physical gold)
|
# Display mode: how to show positions (GLD shares vs physical gold)
|
||||||
display_mode: str = "XAU" # "GLD" for share view, "XAU" for physical gold view
|
display_mode: DisplayMode = "XAU" # "GLD" for share view, "XAU" for physical gold view
|
||||||
|
|
||||||
# Alert settings
|
# Alert settings
|
||||||
volatility_spike: float = 0.25
|
volatility_spike: float = 0.25
|
||||||
@@ -301,7 +305,7 @@ class PortfolioConfig:
|
|||||||
assert self.gold_ounces is not None
|
assert self.gold_ounces is not None
|
||||||
# Sync legacy fields from positions before serializing
|
# Sync legacy fields from positions before serializing
|
||||||
self._sync_legacy_fields_from_positions()
|
self._sync_legacy_fields_from_positions()
|
||||||
result = {
|
result: dict[str, Any] = {
|
||||||
"gold_value": self.gold_value,
|
"gold_value": self.gold_value,
|
||||||
"entry_price": self.entry_price,
|
"entry_price": self.entry_price,
|
||||||
"gold_ounces": self.gold_ounces,
|
"gold_ounces": self.gold_ounces,
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Mapping
|
|
||||||
|
|
||||||
import yfinance as yf
|
import yfinance as yf
|
||||||
|
|
||||||
@@ -52,15 +51,15 @@ class PriceFeed:
|
|||||||
self._cache = get_cache()
|
self._cache = get_cache()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _required_payload_value(payload: Mapping[str, object], key: str, *, context: str) -> object:
|
def _required_payload_value(payload: dict[str, object], key: str, *, context: str) -> object:
|
||||||
if key not in payload:
|
if key not in payload:
|
||||||
raise TypeError(f"{context} is missing required field: {key}")
|
raise TypeError(f"{context} is missing required field: {key}")
|
||||||
return payload[key]
|
return payload[key]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _normalize_cached_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData:
|
def _normalize_cached_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData:
|
||||||
if not isinstance(payload, Mapping):
|
if not isinstance(payload, dict):
|
||||||
raise TypeError("cached price payload must be an object")
|
raise TypeError("cached price payload must be a plain dict")
|
||||||
payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper()
|
payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper()
|
||||||
normalized_symbol = expected_symbol.strip().upper()
|
normalized_symbol = expected_symbol.strip().upper()
|
||||||
if payload_symbol != normalized_symbol:
|
if payload_symbol != normalized_symbol:
|
||||||
@@ -69,7 +68,9 @@ class PriceFeed:
|
|||||||
if not isinstance(timestamp, str) or not timestamp.strip():
|
if not isinstance(timestamp, str) or not timestamp.strip():
|
||||||
raise TypeError("cached timestamp must be a non-empty ISO string")
|
raise TypeError("cached timestamp must be a non-empty ISO string")
|
||||||
price_val = cls._required_payload_value(payload, "price", context="cached price payload")
|
price_val = cls._required_payload_value(payload, "price", context="cached price payload")
|
||||||
price = float(price_val) if isinstance(price_val, (int, float)) else 0.0
|
if not isinstance(price_val, (int, float)):
|
||||||
|
raise TypeError(f"cached price must be numeric, got {type(price_val).__name__}")
|
||||||
|
price = float(price_val)
|
||||||
return PriceData(
|
return PriceData(
|
||||||
symbol=payload_symbol,
|
symbol=payload_symbol,
|
||||||
price=price,
|
price=price,
|
||||||
@@ -80,8 +81,8 @@ class PriceFeed:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _normalize_provider_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData:
|
def _normalize_provider_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData:
|
||||||
if not isinstance(payload, Mapping):
|
if not isinstance(payload, dict):
|
||||||
raise TypeError("provider price payload must be an object")
|
raise TypeError("provider price payload must be a plain dict")
|
||||||
payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper()
|
payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper()
|
||||||
normalized_symbol = expected_symbol.strip().upper()
|
normalized_symbol = expected_symbol.strip().upper()
|
||||||
if payload_symbol != normalized_symbol:
|
if payload_symbol != normalized_symbol:
|
||||||
@@ -90,7 +91,9 @@ class PriceFeed:
|
|||||||
if not isinstance(timestamp, datetime):
|
if not isinstance(timestamp, datetime):
|
||||||
raise TypeError("provider timestamp must be a datetime")
|
raise TypeError("provider timestamp must be a datetime")
|
||||||
price_val = cls._required_payload_value(payload, "price", context="provider price payload")
|
price_val = cls._required_payload_value(payload, "price", context="provider price payload")
|
||||||
price = float(price_val) if isinstance(price_val, (int, float)) else 0.0
|
if not isinstance(price_val, (int, float)):
|
||||||
|
raise TypeError(f"provider price must be numeric, got {type(price_val).__name__}")
|
||||||
|
price = float(price_val)
|
||||||
return PriceData(
|
return PriceData(
|
||||||
symbol=payload_symbol,
|
symbol=payload_symbol,
|
||||||
price=price,
|
price=price,
|
||||||
|
|||||||
Reference in New Issue
Block a user