- Fix return type annotation for get_default_premium_for_product - Add type narrowing for Weight|Money union using _as_money helper - Add isinstance checks before float() calls for object types - Add type guard for Decimal.exponent comparison - Use _unit_typed and _currency_typed properties for type narrowing - Cast option_type to OptionType Literal after validation - Fix provider type hierarchy in backtesting services - Add types-requests to dev dependencies - Remove '|| true' from CI type-check job All 36 mypy errors resolved across 15 files.
528 lines
18 KiB
Python
528 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from datetime import date, timedelta
|
|
from math import isfinite
|
|
from typing import Protocol, cast
|
|
|
|
from app.models.backtest import ProviderRef
|
|
|
|
try:
|
|
import yfinance as yf
|
|
except ImportError: # pragma: no cover - optional in tests
|
|
yf = None
|
|
|
|
from app.core.pricing.black_scholes import BlackScholesInputs, OptionType, black_scholes_price_and_greeks
|
|
from app.models.strategy_template import TemplateLeg
|
|
|
|
|
|
@dataclass(frozen=True)
class DailyClosePoint:
    """One daily closing price of an underlying.

    Attributes:
        date: Trading date of the close.
        close: Closing price; must be positive and finite.

    Raises:
        ValueError: On construction, if ``close`` is not positive or not finite.
    """

    date: date
    close: float

    def __post_init__(self) -> None:
        if self.close <= 0:
            raise ValueError("close must be positive")
        # NaN compares False against 0 and +inf is "positive", so neither is
        # caught by the check above; reject them explicitly.
        if not isfinite(self.close):
            raise ValueError("close must be finite")
|
|
|
|
|
|
@dataclass(frozen=True)
class SyntheticOptionQuote:
    """A model-priced quote for one option leg on a valuation date.

    Attributes:
        position_id: Non-empty identifier of the owning position.
        leg_id: Non-empty identifier of the template leg.
        spot: Underlying price used for pricing (> 0).
        strike: Option strike (> 0).
        expiry: Option expiry date.
        quantity: Number of contracts (> 0).
        mark: Per-unit option price (>= 0).

    Raises:
        ValueError: If an identifier is missing/empty or a sign constraint fails.
        TypeError: If a numeric field is not a finite int/float (bools rejected).
    """

    position_id: str
    leg_id: str
    spot: float
    strike: float
    expiry: date
    quantity: float
    mark: float

    def __post_init__(self) -> None:
        for name in ("position_id", "leg_id"):
            text = getattr(self, name)
            if not (isinstance(text, str) and text):
                raise ValueError(f"{name} is required")

        # Type/finiteness is checked for every numeric field before any sign rule.
        for name in ("spot", "strike", "quantity", "mark"):
            number = getattr(self, name)
            if isinstance(number, bool) or not isinstance(number, (int, float)) or not isfinite(float(number)):
                raise TypeError(f"{name} must be a finite number")

        for name in ("spot", "strike", "quantity"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive")
        if self.mark < 0:
            raise ValueError("mark must be non-negative")
|
|
|
|
|
|
@dataclass(frozen=True)
class DailyOptionSnapshot:
    """One option contract's mid price as observed on a snapshot date.

    Attributes:
        contract_key: Non-empty identifier of the contract.
        symbol: Non-empty underlying symbol.
        snapshot_date: Date the quote was observed.
        expiry: Contract expiry date.
        option_type: ``"put"`` or ``"call"``.
        strike: Contract strike (finite, > 0).
        mid: Mid price (finite, >= 0).

    Raises:
        ValueError: If a text field is missing/empty/non-string, ``option_type``
            is unsupported, or a sign constraint fails.
        TypeError: If ``strike`` or ``mid`` is not a finite int/float (bools rejected).
    """

    contract_key: str
    symbol: str
    snapshot_date: date
    expiry: date
    option_type: str
    strike: float
    mid: float

    def __post_init__(self) -> None:
        for field_name in ("contract_key", "symbol"):
            value = getattr(self, field_name)
            if not isinstance(value, str) or not value:
                raise ValueError(f"{field_name} is required")
        if self.option_type not in {"put", "call"}:
            raise ValueError("unsupported option_type")
        # NaN would slip through the sign checks below (comparisons with NaN are
        # False). Validate the same way HistoricalOptionPosition/HistoricalOptionMark
        # do, since those types are built directly from snapshot strike/mid values.
        for field_name in ("strike", "mid"):
            value = getattr(self, field_name)
            if not isinstance(value, (int, float)) or isinstance(value, bool) or not isfinite(float(value)):
                raise TypeError(f"{field_name} must be a finite number")
        if self.strike <= 0:
            raise ValueError("strike must be positive")
        if self.mid < 0:
            raise ValueError("mid must be non-negative")
|
|
|
|
|
|
@dataclass
class HistoricalOptionPosition:
    """An option leg held during a backtest, with its entry price and latest mark.

    Unlike the frozen value types in this module, this dataclass is mutable:
    providers overwrite ``current_mark`` and ``last_mark_date`` when they re-mark
    the position. Validation runs only at construction time.

    Raises:
        ValueError: If an identifier is missing/empty, ``option_type`` is not
            ``"put"``/``"call"``, or a sign constraint fails.
        TypeError: If a numeric field is not a finite int/float (bools rejected).
    """

    position_id: str
    leg_id: str
    contract_key: str
    option_type: str
    strike: float
    expiry: date
    quantity: float
    entry_price: float
    current_mark: float
    last_mark_date: date
    source_snapshot_date: date

    def __post_init__(self) -> None:
        for name in ("position_id", "leg_id", "contract_key"):
            text = getattr(self, name)
            if not (isinstance(text, str) and text):
                raise ValueError(f"{name} is required")

        if self.option_type not in {"put", "call"}:
            raise ValueError("unsupported option_type")

        # All numeric fields are type/finiteness-checked before any sign rule runs.
        for name in ("strike", "quantity", "entry_price", "current_mark"):
            number = getattr(self, name)
            is_plain_number = isinstance(number, (int, float)) and not isinstance(number, bool)
            if not is_plain_number or not isfinite(float(number)):
                raise TypeError(f"{name} must be a finite number")

        for name in ("strike", "quantity"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive")
        for name in ("entry_price", "current_mark"):
            if getattr(self, name) < 0:
                raise ValueError(f"{name} must be non-negative")
|
|
|
|
|
|
@dataclass(frozen=True)
class HistoricalOptionMark:
    """Result of marking a position on one date.

    Attributes:
        contract_key: Non-empty contract identifier.
        mark: Per-unit mark (finite, >= 0).
        source: Label for how the mark was obtained.
        is_active: Whether the position remains open after this mark.
        realized_cashflow: Cash realized on this date (finite, >= 0; default 0.0).
        warning: Optional human-readable warning attached to the mark.

    Raises:
        ValueError: If ``contract_key`` is empty or a sign constraint fails.
        TypeError: If ``mark``/``realized_cashflow`` is not a finite int/float.
    """

    contract_key: str
    mark: float
    source: str
    is_active: bool
    realized_cashflow: float = 0.0
    warning: str | None = None

    def __post_init__(self) -> None:
        if not self.contract_key:
            raise ValueError("contract_key is required")

        numeric_fields = ("mark", "realized_cashflow")
        for name in numeric_fields:
            number = getattr(self, name)
            if isinstance(number, bool) or not isinstance(number, (int, float)) or not isfinite(float(number)):
                raise TypeError(f"{name} must be a finite number")
        for name in numeric_fields:
            if getattr(self, name) < 0:
                raise ValueError(f"{name} must be non-negative")
|
|
|
|
|
|
class HistoricalPriceSource(Protocol):
    """Structural interface for anything that supplies daily underlying closes."""

    def load_daily_closes(self, symbol: str, start_date: date, end_date: date) -> list[DailyClosePoint]:
        """Return daily closes for ``symbol`` over the requested date range."""
        raise NotImplementedError()
|
|
|
|
|
|
class OptionSnapshotSource(Protocol):
    """Structural interface for anything that supplies a day's option chain."""

    def load_option_chain(self, symbol: str, snapshot_date: date) -> list[DailyOptionSnapshot]:
        """Return the option snapshots for ``symbol`` observed on ``snapshot_date``."""
        raise NotImplementedError()
|
|
|
|
|
|
class BacktestHistoricalProvider(Protocol):
    """Structural interface a backtest engine uses to load prices and open/mark option legs."""

    # Identifier and pricing-mode labels checked by validate_provider_ref.
    provider_id: str
    pricing_mode: str

    def load_history(self, symbol: str, start_date: date, end_date: date) -> list[DailyClosePoint]:
        """Return daily closes for ``symbol`` over the requested date range."""
        raise NotImplementedError()

    def validate_provider_ref(self, provider_ref: ProviderRef) -> None:
        """Raise if ``provider_ref`` does not match this provider."""
        raise NotImplementedError()

    def open_position(
        self,
        *,
        symbol: str,
        leg: TemplateLeg,
        position_id: str,
        quantity: float,
        as_of_date: date,
        spot: float,
        trading_days: list[DailyClosePoint],
    ) -> HistoricalOptionPosition:
        """Open a position for ``leg`` on ``as_of_date`` at underlying price ``spot``."""
        raise NotImplementedError()

    def mark_position(
        self,
        position: HistoricalOptionPosition,
        *,
        symbol: str,
        as_of_date: date,
        spot: float,
    ) -> HistoricalOptionMark:
        """Mark ``position`` on ``as_of_date`` at underlying price ``spot``."""
        raise NotImplementedError()
|
|
|
|
|
|
class YFinanceHistoricalPriceSource:
    """Daily-close source backed by the optional ``yfinance`` package."""

    @staticmethod
    def _normalize_daily_close_row(*, row_date: object, close: object) -> DailyClosePoint | None:
        """Convert one raw history row into a :class:`DailyClosePoint`.

        Args:
            row_date: The row's index value; must expose ``.date()``.
            close: The raw close value. ``None`` means "no close" and the row is skipped.

        Returns:
            The normalized point, or ``None`` when ``close`` is ``None``.

        Raises:
            TypeError: If ``row_date`` has no ``date`` attribute, or ``close`` is not
                a real number (``bool`` is rejected even though it subclasses ``int``).
            ValueError: If ``close`` is NaN or infinite.
        """
        import numbers  # local import keeps the module's import block unchanged

        if close is None:
            return None
        if not hasattr(row_date, "date"):
            raise TypeError(f"historical row date must support .date(), got {type(row_date)!r}")
        # numbers.Real covers int/float and also NumPy scalar types such as float32/int64
        # (NumPy registers them with the numeric ABCs) that are not float/int subclasses.
        # bool is an int subclass but is never a meaningful price.
        if isinstance(close, bool) or not isinstance(close, numbers.Real):
            raise TypeError(f"close must be numeric, got {type(close)!r}")
        normalized_close = float(close)
        if not isfinite(normalized_close):
            raise ValueError("historical close must be finite")
        return DailyClosePoint(date=row_date.date(), close=normalized_close)

    def load_daily_closes(self, symbol: str, start_date: date, end_date: date) -> list[DailyClosePoint]:
        """Download daily bars for ``symbol`` via yfinance and return their closes.

        ``end_date`` is pushed out one day (``inclusive_end_date``) so the requested
        end date is included in the download. Rows whose close is ``None`` are skipped.

        Raises:
            RuntimeError: If ``yfinance`` is not installed.
            TypeError, ValueError: Propagated from row normalization.
        """
        if yf is None:
            raise RuntimeError("yfinance is required to load historical backtest prices")
        ticker = yf.Ticker(symbol)
        inclusive_end_date = end_date + timedelta(days=1)
        history = ticker.history(start=start_date.isoformat(), end=inclusive_end_date.isoformat(), interval="1d")
        rows: list[DailyClosePoint] = []
        for index, row in history.iterrows():
            point = self._normalize_daily_close_row(row_date=index, close=row.get("Close"))
            if point is not None:
                rows.append(point)
        return rows
|
|
|
|
|
|
class SyntheticHistoricalProvider:
    """Backtest provider that prices option legs with Black-Scholes instead of market quotes.

    Underlying closes come from a :class:`HistoricalPriceSource`. Every option mark is
    a model price using one constant implied volatility and risk-free rate supplied
    at construction.
    """

    provider_id = "synthetic_v1"
    pricing_mode = "synthetic_bs_mid"

    def __init__(
        self,
        source: HistoricalPriceSource | None = None,
        implied_volatility: float = 0.16,
        risk_free_rate: float = 0.045,
    ) -> None:
        """Configure the provider.

        Args:
            source: Supplier of daily closes. When ``None``, a
                :class:`YFinanceHistoricalPriceSource` is created.
            implied_volatility: Constant volatility fed to Black-Scholes; must be
                positive and finite.
            risk_free_rate: Constant rate fed to Black-Scholes; must be finite.

        Raises:
            ValueError: If ``implied_volatility`` is not positive, or either numeric
                argument is NaN/infinite.
        """
        if implied_volatility <= 0:
            raise ValueError("implied_volatility must be positive")
        # NaN compares False against 0, so the check above cannot catch it.
        if not isfinite(implied_volatility):
            raise ValueError("implied_volatility must be finite")
        if not isfinite(risk_free_rate):
            raise ValueError("risk_free_rate must be finite")
        # Explicit `is None` instead of `or`: a caller's source object that happens to be
        # falsy (e.g. an empty container-backed fixture) must not be swapped for yfinance.
        self.source: HistoricalPriceSource = source if source is not None else YFinanceHistoricalPriceSource()
        self.implied_volatility = implied_volatility
        self.risk_free_rate = risk_free_rate

    def load_history(self, symbol: str, start_date: date, end_date: date) -> list[DailyClosePoint]:
        """Return the source's closes within ``[start_date, end_date]`` (inclusive), oldest first."""
        rows = self.source.load_daily_closes(symbol, start_date, end_date)
        # Filter and sort here rather than trusting the source's range/ordering.
        filtered = [row for row in rows if start_date <= row.date <= end_date]
        return sorted(filtered, key=lambda row: row.date)

    def validate_provider_ref(self, provider_ref: ProviderRef) -> None:
        """Raise ``ValueError`` unless ``provider_ref`` names this provider id and pricing mode."""
        if provider_ref.provider_id != self.provider_id or provider_ref.pricing_mode != self.pricing_mode:
            raise ValueError(
                "Unsupported provider/pricing combination for synthetic MVP engine: "
                f"{provider_ref.provider_id}/{provider_ref.pricing_mode}"
            )

    def resolve_expiry(self, trading_days: list[DailyClosePoint], as_of_date: date, target_expiry_days: int) -> date:
        """Pick the expiry date for a new position.

        Returns the first entry in ``trading_days`` dated on or after
        ``as_of_date + target_expiry_days``; if none qualifies, returns that calendar
        target date itself. Assumes ``trading_days`` is sorted ascending (as
        :meth:`load_history` returns it) — otherwise the first match is not the earliest.
        """
        target_date = date.fromordinal(as_of_date.toordinal() + target_expiry_days)
        for day in trading_days:
            if day.date >= target_date:
                return day.date
        return target_date

    def open_position(
        self,
        *,
        symbol: str,
        leg: TemplateLeg,
        position_id: str,
        quantity: float,
        as_of_date: date,
        spot: float,
        trading_days: list[DailyClosePoint],
    ) -> HistoricalOptionPosition:
        """Open a synthetic position for ``leg`` on ``as_of_date``.

        The strike is ``spot * leg.strike_rule.value``, the expiry comes from
        :meth:`resolve_expiry`, and the entry price is the Black-Scholes mark on
        ``as_of_date``. The contract key encodes symbol, expiry, type and strike.
        """
        expiry = self.resolve_expiry(trading_days, as_of_date, leg.target_expiry_days)
        strike = spot * leg.strike_rule.value
        quote = self.price_option(
            position_id=position_id,
            leg=leg,
            spot=spot,
            strike=strike,
            expiry=expiry,
            quantity=quantity,
            valuation_date=as_of_date,
        )
        return HistoricalOptionPosition(
            position_id=position_id,
            leg_id=leg.leg_id,
            contract_key=f"{symbol}-{expiry.isoformat()}-{leg.option_type}-{strike:.4f}",
            option_type=leg.option_type,
            strike=strike,
            expiry=expiry,
            quantity=quantity,
            entry_price=quote.mark,
            current_mark=quote.mark,
            last_mark_date=as_of_date,
            source_snapshot_date=as_of_date,
        )

    def mark_position(
        self,
        position: HistoricalOptionPosition,
        *,
        symbol: str,
        as_of_date: date,
        spot: float,
    ) -> HistoricalOptionMark:
        """Mark ``position`` on ``as_of_date``.

        On or after expiry the position is settled: mark 0.0, inactive, and
        ``realized_cashflow = intrinsic value * quantity``. Before expiry it is
        re-priced with Black-Scholes and ``position.current_mark`` /
        ``position.last_mark_date`` are updated in place. ``symbol`` is accepted for
        interface compatibility and not used here.
        """
        if as_of_date >= position.expiry:
            intrinsic = self.intrinsic_value(option_type=position.option_type, spot=spot, strike=position.strike)
            return HistoricalOptionMark(
                contract_key=position.contract_key,
                mark=0.0,
                source="intrinsic_expiry",
                is_active=False,
                realized_cashflow=intrinsic * position.quantity,
            )

        quote = self.price_option_by_type(
            position_id=position.position_id,
            leg_id=position.leg_id,
            option_type=position.option_type,
            spot=spot,
            strike=position.strike,
            expiry=position.expiry,
            quantity=position.quantity,
            valuation_date=as_of_date,
        )
        position.current_mark = quote.mark
        position.last_mark_date = as_of_date
        return HistoricalOptionMark(
            contract_key=position.contract_key,
            mark=quote.mark,
            source="synthetic_bs_mid",
            is_active=True,
        )

    def price_option(
        self,
        *,
        position_id: str,
        leg: TemplateLeg,
        spot: float,
        strike: float,
        expiry: date,
        quantity: float,
        valuation_date: date,
    ) -> SyntheticOptionQuote:
        """Price ``leg`` by delegating to :meth:`price_option_by_type` with the leg's id and type."""
        return self.price_option_by_type(
            position_id=position_id,
            leg_id=leg.leg_id,
            option_type=leg.option_type,
            spot=spot,
            strike=strike,
            expiry=expiry,
            quantity=quantity,
            valuation_date=valuation_date,
        )

    def price_option_by_type(
        self,
        *,
        position_id: str,
        leg_id: str,
        option_type: str,
        spot: float,
        strike: float,
        expiry: date,
        quantity: float,
        valuation_date: date,
    ) -> SyntheticOptionQuote:
        """Return a Black-Scholes quote using this provider's volatility and rate.

        Time to expiry is the calendar-day gap divided by 365.0, floored at one day.
        """
        # Floor at one day so valuations on/after expiry still get a positive time to expiry.
        remaining_days = max(1, expiry.toordinal() - valuation_date.toordinal())
        inputs = BlackScholesInputs(
            spot=spot,
            strike=strike,
            time_to_expiry=remaining_days / 365.0,
            risk_free_rate=self.risk_free_rate,
            volatility=self.implied_volatility,
            # NOTE(review): cast only, no runtime check here — relies on callers passing
            # "put"/"call" (HistoricalOptionPosition validates this for its positions).
            option_type=cast(OptionType, option_type),
            valuation_date=valuation_date,
        )
        mark = black_scholes_price_and_greeks(inputs).price
        return SyntheticOptionQuote(
            position_id=position_id,
            leg_id=leg_id,
            spot=spot,
            strike=strike,
            expiry=expiry,
            quantity=quantity,
            mark=mark,
        )

    @staticmethod
    def intrinsic_value(*, option_type: str, spot: float, strike: float) -> float:
        """Return the per-unit payoff at expiry: ``max(K - S, 0)`` for puts, ``max(S - K, 0)`` for calls.

        Raises:
            ValueError: For any other ``option_type``.
        """
        if option_type == "put":
            return max(strike - spot, 0.0)
        if option_type == "call":
            return max(spot - strike, 0.0)
        raise ValueError(f"Unsupported option type: {option_type}")
|
|
|
|
|
|
class EmptyOptionSnapshotSource:
    """Null-object snapshot source: every chain request yields no contracts."""

    def load_option_chain(self, symbol: str, snapshot_date: date) -> list[DailyOptionSnapshot]:
        """Return a new, empty list regardless of ``symbol`` and ``snapshot_date``."""
        no_snapshots: list[DailyOptionSnapshot] = list()
        return no_snapshots
|
|
|
|
|
|
class DailyOptionsSnapshotProvider:
    """Backtest provider that opens and marks positions from daily option-chain snapshots.

    Entry contracts are chosen from the entry-day chain only. Positions are marked at
    the snapshot mid for their contract key, carrying the previous mark forward (with
    a warning) on days where that contract is missing.
    """

    provider_id = "daily_snapshots_v1"
    pricing_mode = "snapshot_mid"

    def __init__(
        self,
        price_source: HistoricalPriceSource | None = None,
        snapshot_source: OptionSnapshotSource | None = None,
    ) -> None:
        """Configure the provider.

        Args:
            price_source: Supplier of daily closes; defaults (when ``None``) to
                :class:`YFinanceHistoricalPriceSource`.
            snapshot_source: Supplier of option chains; defaults (when ``None``) to
                :class:`EmptyOptionSnapshotSource`.
        """
        # Explicit `is None` instead of `or`: a caller-supplied source that happens to be
        # falsy (e.g. an empty dict/list-backed fixture filled later) must not be replaced.
        self.price_source: HistoricalPriceSource = (
            price_source if price_source is not None else YFinanceHistoricalPriceSource()
        )
        self.snapshot_source: OptionSnapshotSource = (
            snapshot_source if snapshot_source is not None else EmptyOptionSnapshotSource()
        )

    def load_history(self, symbol: str, start_date: date, end_date: date) -> list[DailyClosePoint]:
        """Return the price source's closes within ``[start_date, end_date]`` (inclusive), oldest first."""
        rows = self.price_source.load_daily_closes(symbol, start_date, end_date)
        filtered = [row for row in rows if start_date <= row.date <= end_date]
        return sorted(filtered, key=lambda row: row.date)

    def validate_provider_ref(self, provider_ref: ProviderRef) -> None:
        """Raise ``ValueError`` unless ``provider_ref`` names this provider id and pricing mode."""
        if provider_ref.provider_id != self.provider_id or provider_ref.pricing_mode != self.pricing_mode:
            raise ValueError(
                "Unsupported provider/pricing combination for historical snapshot engine: "
                f"{provider_ref.provider_id}/{provider_ref.pricing_mode}"
            )

    def open_position(
        self,
        *,
        symbol: str,
        leg: TemplateLeg,
        position_id: str,
        quantity: float,
        as_of_date: date,
        spot: float,
        trading_days: list[DailyClosePoint],
    ) -> HistoricalOptionPosition:
        """Open a position on the entry-day snapshot that best matches ``leg``.

        Entry price and initial mark are the selected snapshot's mid.

        Raises:
            ValueError: If no eligible snapshot exists (see :meth:`_select_entry_snapshot`).
        """
        del trading_days  # selection must use only the entry-day snapshot, not future state
        selected_snapshot = self._select_entry_snapshot(symbol=symbol, leg=leg, as_of_date=as_of_date, spot=spot)
        return HistoricalOptionPosition(
            position_id=position_id,
            leg_id=leg.leg_id,
            contract_key=selected_snapshot.contract_key,
            option_type=selected_snapshot.option_type,
            strike=selected_snapshot.strike,
            expiry=selected_snapshot.expiry,
            quantity=quantity,
            entry_price=selected_snapshot.mid,
            current_mark=selected_snapshot.mid,
            last_mark_date=as_of_date,
            source_snapshot_date=as_of_date,
        )

    def mark_position(
        self,
        position: HistoricalOptionPosition,
        *,
        symbol: str,
        as_of_date: date,
        spot: float,
    ) -> HistoricalOptionMark:
        """Mark ``position`` on ``as_of_date``.

        On or after expiry: settle at intrinsic value (mark 0.0, inactive,
        ``realized_cashflow = intrinsic * quantity``). Otherwise use the mid of the
        snapshot on ``as_of_date`` whose contract key matches, updating
        ``position.current_mark``/``last_mark_date`` in place. If no such snapshot
        exists, the previous mark is carried forward with a warning.
        """
        if as_of_date >= position.expiry:
            intrinsic = SyntheticHistoricalProvider.intrinsic_value(
                option_type=position.option_type,
                spot=spot,
                strike=position.strike,
            )
            return HistoricalOptionMark(
                contract_key=position.contract_key,
                mark=0.0,
                source="intrinsic_expiry",
                is_active=False,
                realized_cashflow=intrinsic * position.quantity,
            )

        exact_snapshot = next(
            (
                snapshot
                for snapshot in self.snapshot_source.load_option_chain(symbol, as_of_date)
                if snapshot.contract_key == position.contract_key
            ),
            None,
        )
        if exact_snapshot is not None:
            position.current_mark = exact_snapshot.mid
            position.last_mark_date = as_of_date
            return HistoricalOptionMark(
                contract_key=position.contract_key,
                mark=exact_snapshot.mid,
                source="snapshot_mid",
                is_active=True,
            )

        # HistoricalOptionPosition validates current_mark >= 0 at construction, so this
        # only trips if the (mutable) position was later given a negative mark.
        if position.current_mark < 0:
            raise ValueError(f"Missing historical mark for {position.contract_key} on {as_of_date.isoformat()}")
        return HistoricalOptionMark(
            contract_key=position.contract_key,
            mark=position.current_mark,
            source="carry_forward",
            is_active=True,
            warning=(
                f"Missing historical mark for {position.contract_key} on {as_of_date.isoformat()}; "
                f"carrying forward prior mark from {position.last_mark_date.isoformat()}."
            ),
        )

    def _select_entry_snapshot(
        self,
        *,
        symbol: str,
        leg: TemplateLeg,
        as_of_date: date,
        spot: float,
    ) -> DailyOptionSnapshot:
        """Choose the entry contract from the ``as_of_date`` chain.

        Only snapshots matching ``symbol`` (case/whitespace-insensitive) and
        ``leg.option_type`` with expiry on or after ``as_of_date + leg.target_expiry_days``
        are eligible. The earliest eligible expiry wins. Within it, the strike closest to
        ``spot * leg.strike_rule.value`` wins; ties go to the higher strike for puts and
        the lower strike for calls.

        Raises:
            ValueError: If no snapshot is eligible.
        """
        target_expiry = date.fromordinal(as_of_date.toordinal() + leg.target_expiry_days)
        target_strike = spot * leg.strike_rule.value
        wanted_symbol = symbol.strip().upper()  # hoisted: invariant across the chain
        chain = [
            snapshot
            for snapshot in self.snapshot_source.load_option_chain(symbol, as_of_date)
            if snapshot.symbol.strip().upper() == wanted_symbol and snapshot.option_type == leg.option_type
        ]
        eligible_expiries = [snapshot for snapshot in chain if snapshot.expiry >= target_expiry]
        if not eligible_expiries:
            raise ValueError(
                f"No eligible historical option snapshots found for {symbol} on {as_of_date.isoformat()} "
                f"at or beyond target expiry {target_expiry.isoformat()}"
            )
        # Every eligible expiry is >= target_expiry, so the earliest one is also the closest.
        selected_expiry = min(snapshot.expiry for snapshot in eligible_expiries)
        expiry_matches = [snapshot for snapshot in eligible_expiries if snapshot.expiry == selected_expiry]
        return min(
            expiry_matches, key=lambda snapshot: self._strike_sort_key(snapshot.strike, target_strike, leg.option_type)
        )

    @staticmethod
    def _strike_sort_key(strike: float, target_strike: float, option_type: str) -> tuple[float, float]:
        """Sort key: distance to target first, then prefer higher strikes for puts, lower otherwise."""
        if option_type == "put":
            return (abs(strike - target_strike), -strike)
        return (abs(strike - target_strike), strike)
|