diff --git a/app/services/event_comparison_ui.py b/app/services/event_comparison_ui.py index c286eb6..16979e3 100644 --- a/app/services/event_comparison_ui.py +++ b/app/services/event_comparison_ui.py @@ -117,13 +117,18 @@ class EventComparisonPageService: ) -> BacktestScenario: if not template_slugs: raise ValueError("Select at least one strategy template.") + normalized_inputs = normalize_historical_scenario_inputs( + underlying_units=underlying_units, + loan_amount=loan_amount, + margin_call_ltv=margin_call_ltv, + ) try: scenario = self.comparison_service.preview_scenario_from_inputs( preset_slug=preset_slug, template_slugs=template_slugs, - underlying_units=underlying_units, - loan_amount=loan_amount, - margin_call_ltv=margin_call_ltv, + underlying_units=normalized_inputs.underlying_units, + loan_amount=normalized_inputs.loan_amount, + margin_call_ltv=normalized_inputs.margin_call_ltv, ) except ValueError as exc: if str(exc) == "loan_amount must be less than initial collateral value": @@ -134,9 +139,17 @@ class EventComparisonPageService: preset.window_end, ) if preview: - _validate_initial_collateral(underlying_units, preview[0].close, loan_amount) + _validate_initial_collateral( + normalized_inputs.underlying_units, + preview[0].close, + normalized_inputs.loan_amount, + ) raise - _validate_initial_collateral(underlying_units, scenario.initial_portfolio.entry_spot, loan_amount) + _validate_initial_collateral( + normalized_inputs.underlying_units, + scenario.initial_portfolio.entry_spot, + normalized_inputs.loan_amount, + ) return scenario def run_read_only_comparison( diff --git a/app/services/price_feed.py b/app/services/price_feed.py index 9a08939..132d037 100644 --- a/app/services/price_feed.py +++ b/app/services/price_feed.py @@ -52,38 +52,44 @@ class PriceFeed: self._cache = get_cache() @staticmethod - def _normalize_cached_price_payload(payload: object, *, expected_symbol: str) -> PriceData: + def _required_payload_value(payload: Mapping[str, object], key: str, *, context: str) -> object: + if key not in payload: + raise TypeError(f"{context} is missing required field: {key}") + return payload[key] + + @classmethod + def _normalize_cached_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData: if not isinstance(payload, Mapping): raise TypeError("cached price payload must be an object") payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper() normalized_symbol = expected_symbol.strip().upper() if payload_symbol != normalized_symbol: raise ValueError(f"cached symbol mismatch: {payload_symbol} != {normalized_symbol}") - timestamp = payload.get("timestamp") + timestamp = cls._required_payload_value(payload, "timestamp", context="cached price payload") if not isinstance(timestamp, str) or not timestamp.strip(): raise TypeError("cached timestamp must be a non-empty ISO string") return PriceData( symbol=payload_symbol, - price=float(payload["price"]), + price=float(cls._required_payload_value(payload, "price", context="cached price payload")), currency=str(payload.get("currency", "USD")), timestamp=datetime.fromisoformat(timestamp), source=str(payload.get("source", "yfinance")), ) - @staticmethod - def _normalize_provider_price_payload(payload: object, *, expected_symbol: str) -> PriceData: + @classmethod + def _normalize_provider_price_payload(cls, payload: object, *, expected_symbol: str) -> PriceData: if not isinstance(payload, Mapping): raise TypeError("provider price payload must be an object") payload_symbol = str(payload.get("symbol", expected_symbol)).strip().upper() normalized_symbol = expected_symbol.strip().upper() if payload_symbol != normalized_symbol: raise ValueError(f"provider symbol mismatch: {payload_symbol} != {normalized_symbol}") - timestamp = payload.get("timestamp") + timestamp = cls._required_payload_value(payload, "timestamp", context="provider price payload") if not isinstance(timestamp, datetime): raise TypeError("provider timestamp must be a datetime") return PriceData( symbol=payload_symbol, - price=float(payload["price"]), + price=float(cls._required_payload_value(payload, "price", context="provider price payload")), currency=str(payload.get("currency", "USD")), timestamp=timestamp, source=str(payload.get("source", "yfinance")), diff --git a/tests/test_event_comparison_ui.py b/tests/test_event_comparison_ui.py index d45309b..4c98736 100644 --- a/tests/test_event_comparison_ui.py +++ b/tests/test_event_comparison_ui.py @@ -11,6 +11,13 @@ from app.services.event_comparison_ui import EventComparisonFixtureHistoricalPri def test_event_comparison_page_service_accepts_string_and_decimal_boundary_values() -> None: service = EventComparisonPageService() + preview = service.preview_scenario( + preset_slug="gld-jan-2024-selloff", + template_slugs=("protective-put-atm-12m",), + underlying_units="1000.0", + loan_amount=Decimal("68000.0"), + margin_call_ltv="0.75", + ) report = service.run_read_only_comparison( preset_slug="gld-jan-2024-selloff", template_slugs=("protective-put-atm-12m", "protective-put-95pct-12m"), @@ -19,6 +26,9 @@ def test_event_comparison_page_service_accepts_string_and_decimal_boundary_value margin_call_ltv="0.75", ) + assert preview.initial_portfolio.underlying_units == 1000.0 + assert preview.initial_portfolio.loan_amount == 68000.0 + assert preview.initial_portfolio.margin_call_ltv == 0.75 assert report.scenario.initial_portfolio.underlying_units == 1000.0 assert report.scenario.initial_portfolio.loan_amount == 68000.0 assert report.scenario.initial_portfolio.margin_call_ltv == 0.75 diff --git a/tests/test_price_feed.py b/tests/test_price_feed.py index cb24f3f..d2e3aee 100644 --- a/tests/test_price_feed.py +++ b/tests/test_price_feed.py @@ -75,6 +75,35 @@ async def test_price_feed_discards_malformed_cached_payload_and_refetches(monkey assert feed._cache.writes[0][0] == "price:GLD" +@pytest.mark.asyncio +async def test_price_feed_discards_cached_payload_missing_required_price_and_refetches( + monkeypatch: pytest.MonkeyPatch, +) -> None: + feed = PriceFeed() + feed._cache = _CacheStub({"price:GLD": {"symbol": "GLD", "timestamp": "2026-03-26T12:00:00+00:00"}}) + + async def fake_fetch(symbol: str): + return { + "symbol": symbol, + "price": 205.0, + "currency": "USD", + "timestamp": datetime(2026, 3, 26, 12, 3, tzinfo=timezone.utc), + "source": "yfinance", + } + + monkeypatch.setattr(feed, "_fetch_yfinance", fake_fetch) + + data = await feed.get_price("GLD") + + assert data == PriceData( + symbol="GLD", + price=205.0, + currency="USD", + timestamp=datetime(2026, 3, 26, 12, 3, tzinfo=timezone.utc), + source="yfinance", + ) + + @pytest.mark.asyncio async def test_price_feed_rejects_invalid_provider_payload() -> None: feed = PriceFeed()