feat(PORTFOLIO-001): add position-level portfolio entries

This commit is contained in:
Bu5hm4nn
2026-03-28 21:29:30 +01:00
parent 447f4bbd0d
commit 1a39956757
6 changed files with 1041 additions and 7 deletions

View File

@@ -3,6 +3,7 @@
from .event_preset import EventPreset, EventScenarioOverrides
from .option import Greeks, OptionContract, OptionMoneyness
from .portfolio import LombardPortfolio
from .position import Position, create_position
from .strategy import HedgingStrategy, ScenarioResult, StrategyType
from .strategy_template import EntryPolicy, RollPolicy, StrategyTemplate, TemplateLeg
@@ -14,10 +15,12 @@ __all__ = [
"LombardPortfolio",
"OptionContract",
"OptionMoneyness",
"Position",
"ScenarioResult",
"StrategyType",
"StrategyTemplate",
"TemplateLeg",
"RollPolicy",
"EntryPolicy",
"create_position",
]

View File

@@ -4,11 +4,14 @@ from __future__ import annotations
import json
import os
from dataclasses import dataclass
from dataclasses import dataclass, field
from datetime import date
from decimal import Decimal
from pathlib import Path
from typing import Any
from app.models.position import Position, create_position
_DEFAULT_GOLD_VALUE = 215_000.0
_DEFAULT_ENTRY_PRICE = 2_150.0
_LEGACY_DEFAULT_ENTRY_PRICE = 215.0
@@ -93,6 +96,7 @@ class PortfolioConfig:
margin_threshold: LTV threshold for margin call (default 0.75)
monthly_budget: Approved monthly hedge budget
ltv_warning: LTV warning level for alerts (default 0.70)
positions: List of position entries (multi-position support)
"""
gold_value: float | None = None
@@ -117,11 +121,40 @@ class PortfolioConfig:
spot_drawdown: float = 7.5
email_alerts: bool = False
# Multi-position support
positions: list[Position] = field(default_factory=list)
def __post_init__(self) -> None:
"""Normalize entry basis fields and validate configuration."""
self._normalize_entry_basis()
self.validate()
def migrate_to_positions_if_needed(self) -> None:
"""Migrate legacy single-entry portfolios to multi-position format.
Call this after loading from persistence to migrate legacy configs.
If positions list is empty but gold_ounces exists, create one Position
representing the legacy single entry.
"""
if self.positions:
# Already has positions, no migration needed
return
if self.gold_ounces is None or self.entry_price is None:
return
# Create a single position from legacy fields
position = create_position(
underlying=self.underlying,
quantity=Decimal(str(self.gold_ounces)),
unit="oz",
entry_price=Decimal(str(self.entry_price)),
entry_date=date.today(),
entry_basis_mode=self.entry_basis_mode,
)
# PortfolioConfig is not frozen, so we can set directly
self.positions = [position]
def _normalize_entry_basis(self) -> None:
"""Resolve user input into canonical weight + entry price representation."""
if self.entry_basis_mode not in {"value_price", "weight"}:
@@ -157,6 +190,55 @@ class PortfolioConfig:
raise ValueError("Gold value and weight contradict each other")
self.gold_value = derived_gold_value
def _migrate_legacy_to_positions(self) -> None:
"""Migrate legacy single-entry portfolios to multi-position format.
If positions list is empty but gold_ounces exists, create one Position
representing the legacy single entry.
"""
if self.positions:
# Already has positions, no migration needed
return
if self.gold_ounces is None or self.entry_price is None:
return
# Create a single position from legacy fields
position = create_position(
underlying=self.underlying,
quantity=Decimal(str(self.gold_ounces)),
unit="oz",
entry_price=Decimal(str(self.entry_price)),
entry_date=date.today(),
entry_basis_mode=self.entry_basis_mode,
)
# PortfolioConfig is not frozen, so we can set directly
self.positions = [position]
def _sync_legacy_fields_from_positions(self) -> None:
"""Sync legacy gold_ounces, entry_price, gold_value from positions.
For backward compatibility, compute aggregate values from positions list.
"""
if not self.positions:
return
# For now, assume homogeneous positions (same underlying and unit)
# Sum quantities and compute weighted average entry price
total_quantity = Decimal("0")
total_value = Decimal("0")
for pos in self.positions:
if pos.unit == "oz":
total_quantity += pos.quantity
total_value += pos.entry_value
if total_quantity > 0:
avg_entry_price = total_value / total_quantity
self.gold_ounces = float(total_quantity)
self.entry_price = float(avg_entry_price)
self.gold_value = float(total_value)
def validate(self) -> None:
"""Validate configuration values."""
assert self.gold_value is not None
@@ -214,7 +296,9 @@ class PortfolioConfig:
assert self.gold_value is not None
assert self.entry_price is not None
assert self.gold_ounces is not None
return {
# Sync legacy fields from positions before serializing
self._sync_legacy_fields_from_positions()
result = {
"gold_value": self.gold_value,
"entry_price": self.entry_price,
"gold_ounces": self.gold_ounces,
@@ -231,11 +315,31 @@ class PortfolioConfig:
"spot_drawdown": self.spot_drawdown,
"email_alerts": self.email_alerts,
}
# Include positions if any exist
if self.positions:
result["positions"] = [pos.to_dict() for pos in self.positions]
return result
@classmethod
def from_dict(cls, data: dict[str, Any]) -> PortfolioConfig:
"""Create configuration from dictionary."""
return cls(**{k: v for k, v in data.items() if k in cls.__dataclass_fields__})
# Extract positions if present (may already be Position objects from deserialization)
positions_data = data.pop("positions", None)
config_data = {k: v for k, v in data.items() if k in cls.__dataclass_fields__}
# Create config without positions first (will be set in __post_init__)
config = cls(**config_data)
# Set positions after initialization
if positions_data:
if positions_data and isinstance(positions_data[0], Position):
# Already deserialized by _deserialize_value
positions = positions_data
else:
positions = [Position.from_dict(p) for p in positions_data]
config.positions = positions
return config
def _coerce_persisted_decimal(value: Any) -> Decimal:
@@ -293,6 +397,7 @@ class PortfolioRepository:
"volatility_spike",
"spot_drawdown",
"email_alerts",
"positions", # multi-position support
}
def __init__(self, config_path: Path | None = None) -> None:
@@ -329,11 +434,42 @@ class PortfolioRepository:
@classmethod
def _to_persistence_payload(cls, config: PortfolioConfig) -> dict[str, Any]:
# Serialize positions separately before calling to_dict
positions_data = [pos.to_dict() for pos in config.positions] if config.positions else []
config_dict = config.to_dict()
# Remove positions from config_dict since we handle it separately
config_dict.pop("positions", None)
return {
"schema_version": cls.SCHEMA_VERSION,
"portfolio": {key: cls._serialize_value(key, value) for key, value in config.to_dict().items()},
"portfolio": {
**{key: cls._serialize_value(key, value) for key, value in config_dict.items()},
**({"positions": positions_data} if positions_data else {}),
},
}
@classmethod
def _serialize_value(cls, key: str, value: Any) -> Any:
if key in cls._MONEY_FIELDS:
return {"value": cls._decimal_to_string(value), "currency": cls.PERSISTENCE_CURRENCY}
if key in cls._WEIGHT_FIELDS:
return {"value": cls._decimal_to_string(value), "unit": cls.PERSISTENCE_WEIGHT_UNIT}
if key in cls._PRICE_PER_WEIGHT_FIELDS:
return {
"value": cls._decimal_to_string(value),
"currency": cls.PERSISTENCE_CURRENCY,
"per_weight_unit": cls.PERSISTENCE_WEIGHT_UNIT,
}
if key in cls._RATIO_FIELDS:
return {"value": cls._decimal_to_string(value), "unit": "ratio"}
if key in cls._PERCENT_FIELDS:
return {"value": cls._decimal_to_string(value), "unit": "percent"}
if key in cls._INTEGER_FIELDS:
return cls._serialize_integer(value, unit="seconds")
if key == "positions" and isinstance(value, list):
# Already serialized as dicts from _to_persistence_payload
return value
return value
@classmethod
def _config_from_payload(cls, data: dict[str, Any]) -> PortfolioConfig:
if not isinstance(data, dict):
@@ -347,11 +483,15 @@ class PortfolioRepository:
cls._validate_portfolio_fields(portfolio)
deserialized = cls._deserialize_portfolio_payload(portfolio)
upgraded = cls._upgrade_legacy_default_workspace(deserialized)
return PortfolioConfig.from_dict(upgraded)
config = PortfolioConfig.from_dict(upgraded)
# Migrate legacy configs without positions to single position
config.migrate_to_positions_if_needed()
return config
# Fields that must be present in persisted payloads
# (underlying is optional with default "GLD")
_REQUIRED_FIELDS = _PERSISTED_FIELDS - {"underlying"}
# (positions is optional - legacy configs won't have it)
_REQUIRED_FIELDS = (_PERSISTED_FIELDS - {"underlying"}) - {"positions"}
@classmethod
def _validate_portfolio_fields(cls, payload: dict[str, Any]) -> None:
@@ -421,6 +561,9 @@ class PortfolioRepository:
return {"value": cls._decimal_to_string(value), "unit": "percent"}
if key in cls._INTEGER_FIELDS:
return cls._serialize_integer(value, unit="seconds")
if key == "positions" and isinstance(value, list):
# Already serialized as dicts from _to_persistence_payload
return value
return value
@classmethod
@@ -437,6 +580,8 @@ class PortfolioRepository:
return float(cls._deserialize_percent(value))
if key in cls._INTEGER_FIELDS:
return cls._deserialize_integer(value, expected_unit="seconds")
if key == "positions" and isinstance(value, list):
return [Position.from_dict(p) for p in value]
return value
@classmethod

118
app/models/position.py Normal file
View File

@@ -0,0 +1,118 @@
"""Position model for multi-position portfolio entries."""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import UTC, date, datetime
from decimal import Decimal
from typing import Any
from uuid import UUID, uuid4
@dataclass(frozen=True)
class Position:
"""A single position entry in a portfolio.
Attributes:
id: Unique identifier for this position
underlying: Underlying instrument symbol (e.g., "GLD", "GC=F", "XAU")
quantity: Number of units held (shares, contracts, grams, or oz)
unit: Unit of quantity (e.g., "shares", "contracts", "g", "oz")
entry_price: Price per unit at purchase (in USD)
entry_date: Date of position entry (for historical conversion lookups)
entry_basis_mode: Entry basis mode ("weight" or "value_price")
notes: Optional notes about this position
created_at: Timestamp when position was created
"""
id: UUID
underlying: str
quantity: Decimal
unit: str
entry_price: Decimal
entry_date: date
entry_basis_mode: str = "weight"
notes: str = ""
created_at: datetime = field(default_factory=lambda: datetime.now(UTC))
def __post_init__(self) -> None:
"""Validate position fields."""
if not self.underlying:
raise ValueError("underlying must be non-empty")
# Use object.__getattribute__ because Decimal comparison with frozen dataclass
quantity = object.__getattribute__(self, "quantity")
entry_price = object.__getattribute__(self, "entry_price")
if quantity <= 0:
raise ValueError("quantity must be positive")
if not self.unit:
raise ValueError("unit must be non-empty")
if entry_price <= 0:
raise ValueError("entry_price must be positive")
if self.entry_basis_mode not in {"weight", "value_price"}:
raise ValueError("entry_basis_mode must be 'weight' or 'value_price'")
@property
def entry_value(self) -> Decimal:
"""Calculate total entry value (quantity × entry_price)."""
return self.quantity * self.entry_price
def to_dict(self) -> dict[str, Any]:
"""Convert position to dictionary for serialization."""
return {
"id": str(self.id),
"underlying": self.underlying,
"quantity": str(self.quantity),
"unit": self.unit,
"entry_price": str(self.entry_price),
"entry_date": self.entry_date.isoformat(),
"entry_basis_mode": self.entry_basis_mode,
"notes": self.notes,
"created_at": self.created_at.isoformat(),
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Position:
"""Create position from dictionary."""
return cls(
id=UUID(data["id"]) if isinstance(data["id"], str) else data["id"],
underlying=data["underlying"],
quantity=Decimal(data["quantity"]),
unit=data["unit"],
entry_price=Decimal(data["entry_price"]),
entry_date=date.fromisoformat(data["entry_date"]),
entry_basis_mode=data.get("entry_basis_mode", "weight"),
notes=data.get("notes", ""),
created_at=datetime.fromisoformat(data["created_at"]) if "created_at" in data else datetime.now(UTC),
)
def create_position(
underlying: str = "GLD",
quantity: Decimal | None = None,
unit: str = "oz",
entry_price: Decimal | None = None,
entry_date: date | None = None,
entry_basis_mode: str = "weight",
notes: str = "",
) -> Position:
"""Create a new position with sensible defaults.
Args:
underlying: Underlying instrument (default: "GLD")
quantity: Position quantity (default: Decimal("100"))
unit: Unit of quantity (default: "oz")
entry_price: Entry price per unit (default: Decimal("2150"))
entry_date: Entry date (default: today)
entry_basis_mode: Entry basis mode (default: "weight")
notes: Optional notes
"""
return Position(
id=uuid4(),
underlying=underlying,
quantity=quantity if quantity is not None else Decimal("100"),
unit=unit,
entry_price=entry_price if entry_price is not None else Decimal("2150"),
entry_date=entry_date or date.today(),
entry_basis_mode=entry_basis_mode,
notes=notes,
)

View File

@@ -2,9 +2,10 @@ from __future__ import annotations
import re
from pathlib import Path
from uuid import uuid4
from uuid import UUID, uuid4
from app.models.portfolio import PortfolioConfig, PortfolioRepository, build_default_portfolio_config
from app.models.position import Position
WORKSPACE_COOKIE = "workspace_id"
_WORKSPACE_ID_RE = re.compile(
@@ -63,6 +64,69 @@ class WorkspaceRepository:
raise ValueError("workspace_id must be a UUID4 string")
PortfolioRepository(self._portfolio_path(workspace_id)).save(config)
def add_position(self, workspace_id: str, position: Position) -> None:
"""Add a position to the workspace portfolio."""
if not self.is_valid_workspace_id(workspace_id):
raise ValueError("workspace_id must be a UUID4 string")
config = self.load_portfolio_config(workspace_id)
# Use object.__setattr__ because positions is in a frozen dataclass
object.__setattr__(config, "positions", list(config.positions) + [position])
self.save_portfolio_config(workspace_id, config)
def remove_position(self, workspace_id: str, position_id: UUID) -> None:
"""Remove a position from the workspace portfolio."""
if not self.is_valid_workspace_id(workspace_id):
raise ValueError("workspace_id must be a UUID4 string")
config = self.load_portfolio_config(workspace_id)
updated_positions = [p for p in config.positions if p.id != position_id]
object.__setattr__(config, "positions", updated_positions)
self.save_portfolio_config(workspace_id, config)
def update_position(
self,
workspace_id: str,
position_id: UUID,
updates: dict[str, object],
) -> None:
"""Update a position's fields."""
if not self.is_valid_workspace_id(workspace_id):
raise ValueError("workspace_id must be a UUID4 string")
config = self.load_portfolio_config(workspace_id)
updated_positions = []
for pos in config.positions:
if pos.id == position_id:
# Create updated position (Position is frozen, so create new instance)
update_kwargs: dict[str, object] = {}
for key, value in updates.items():
if key in {"id", "created_at"}:
continue # Skip immutable fields
update_kwargs[key] = value
# Use dataclass replace-like pattern
pos_dict = pos.to_dict()
pos_dict.update(update_kwargs)
updated_positions.append(Position.from_dict(pos_dict))
else:
updated_positions.append(pos)
object.__setattr__(config, "positions", updated_positions)
self.save_portfolio_config(workspace_id, config)
def get_position(self, workspace_id: str, position_id: UUID) -> Position | None:
"""Get a specific position by ID."""
if not self.is_valid_workspace_id(workspace_id):
raise ValueError("workspace_id must be a UUID4 string")
config = self.load_portfolio_config(workspace_id)
for pos in config.positions:
if pos.id == position_id:
return pos
return None
def list_positions(self, workspace_id: str) -> list[Position]:
"""List all positions in the workspace portfolio."""
if not self.is_valid_workspace_id(workspace_id):
raise ValueError("workspace_id must be a UUID4 string")
config = self.load_portfolio_config(workspace_id)
return list(config.positions)
def _portfolio_path(self, workspace_id: str) -> Path:
return self.base_path / workspace_id / "portfolio_config.json"

View File

@@ -1,11 +1,15 @@
from __future__ import annotations
import logging
from datetime import date
from decimal import Decimal
from uuid import uuid4
from fastapi.responses import RedirectResponse
from nicegui import ui
from app.models.portfolio import PortfolioConfig
from app.models.position import Position
from app.models.workspace import get_workspace_repository
from app.pages.common import dashboard_page, split_page_panes
from app.services.alerts import AlertService, build_portfolio_alert_context
@@ -270,6 +274,154 @@ def settings_page(workspace_id: str) -> None:
step=1,
).classes("w-full")
# Position Management Card
with ui.card().classes(
"w-full rounded-2xl border border-slate-200 bg-white shadow-sm dark:border-slate-800 dark:bg-slate-900"
):
ui.label("Portfolio Positions").classes("text-lg font-semibold text-slate-900 dark:text-slate-100")
ui.label(
"Manage individual position entries. Each position tracks its own entry date and price."
).classes("text-sm text-slate-500 dark:text-slate-400")
# Position list container
position_list_container = ui.column().classes("w-full gap-2 mt-3")
# Add position form (hidden by default)
with (
ui.dialog() as add_position_dialog,
ui.card().classes(
"w-full max-w-md rounded-2xl border border-slate-200 bg-white p-6 shadow-lg dark:border-slate-800 dark:bg-slate-900"
),
):
ui.label("Add New Position").classes(
"text-lg font-semibold text-slate-900 dark:text-slate-100 mb-4"
)
pos_underlying = ui.select(
{
"GLD": "SPDR Gold Shares ETF",
"XAU": "Physical Gold (oz)",
"GC=F": "Gold Futures",
},
value="GLD",
label="Underlying",
).classes("w-full")
pos_quantity = ui.number(
"Quantity",
value=100.0,
min=0.0001,
step=0.01,
).classes("w-full")
pos_unit = ui.select(
{"oz": "Troy Ounces", "shares": "Shares", "g": "Grams", "contracts": "Contracts"},
value="oz",
label="Unit",
).classes("w-full")
pos_entry_price = ui.number(
"Entry Price ($/unit)",
value=2150.0,
min=0.01,
step=0.01,
).classes("w-full")
with ui.row().classes("w-full items-center gap-2"):
ui.label("Entry Date").classes("text-sm font-medium")
pos_entry_date = (
ui.date(
value=date.today().isoformat(),
)
.classes("w-full")
.props("stack-label")
)
pos_notes = ui.textarea(
label="Notes (optional)",
placeholder="Add notes about this position...",
).classes("w-full")
with ui.row().classes("w-full gap-3 mt-4"):
ui.button("Cancel", on_click=lambda: add_position_dialog.close()).props("outline")
ui.button("Add Position", on_click=lambda: add_position_from_form()).props("color=primary")
def add_position_from_form() -> None:
"""Add a new position from the form."""
try:
new_position = Position(
id=uuid4(),
underlying=str(pos_underlying.value),
quantity=Decimal(str(pos_quantity.value)),
unit=str(pos_unit.value),
entry_price=Decimal(str(pos_entry_price.value)),
entry_date=date.fromisoformat(str(pos_entry_date.value)),
entry_basis_mode="weight",
notes=str(pos_notes.value or ""),
)
workspace_repo.add_position(workspace_id, new_position)
add_position_dialog.close()
render_positions()
ui.notify("Position added successfully", color="positive")
except Exception as e:
logger.exception("Failed to add position")
ui.notify(f"Failed to add position: {e}", color="negative")
def render_positions() -> None:
"""Render the list of positions."""
position_list_container.clear()
positions = workspace_repo.list_positions(workspace_id)
if not positions:
with position_list_container:
ui.label("No positions yet. Click 'Add Position' to create one.").classes(
"text-sm text-slate-500 dark:text-slate-400 italic"
)
return
for pos in positions:
with ui.card().classes(
"w-full rounded-lg border border-slate-200 bg-slate-50 p-3 dark:border-slate-700 dark:bg-slate-800"
):
with ui.row().classes("w-full items-start justify-between gap-3"):
with ui.column().classes("gap-1"):
ui.label(f"{pos.underlying} · {float(pos.quantity):,.4f} {pos.unit}").classes(
"text-sm font-medium text-slate-900 dark:text-slate-100"
)
ui.label(
f"Entry: ${float(pos.entry_price):,.2f}/{pos.unit} · Date: {pos.entry_date}"
).classes("text-xs text-slate-500 dark:text-slate-400")
if pos.notes:
ui.label(pos.notes).classes("text-xs text-slate-500 dark:text-slate-400 italic")
ui.label(f"Value: ${float(pos.entry_value):,.2f}").classes(
"text-xs font-semibold text-emerald-600 dark:text-emerald-400"
)
with ui.row().classes("gap-1"):
ui.button(
icon="delete",
on_click=lambda p=pos: remove_position(p.id),
).props(
"flat dense color=negative size=sm"
).classes("self-start")
def remove_position(position_id) -> None:
"""Remove a position."""
try:
workspace_repo.remove_position(workspace_id, position_id)
render_positions()
ui.notify("Position removed", color="positive")
except Exception as e:
logger.exception("Failed to remove position")
ui.notify(f"Failed to remove position: {e}", color="negative")
with ui.row().classes("w-full mt-3"):
ui.button("Add Position", icon="add", on_click=lambda: add_position_dialog.open()).props(
"color=primary"
)
# Initial render
render_positions()
with ui.card().classes(
"w-full rounded-2xl border border-slate-200 bg-white shadow-sm dark:border-slate-800 dark:bg-slate-900"
):

552
tests/test_position.py Normal file
View File

@@ -0,0 +1,552 @@
"""Tests for position model and multi-position portfolio support."""
from __future__ import annotations
import json
from datetime import date
from decimal import Decimal
from uuid import uuid4
import pytest
from app.models.portfolio import PortfolioConfig, PortfolioRepository
from app.models.position import Position, create_position
from app.models.workspace import WorkspaceRepository
class TestPositionModel:
"""Test Position model creation and validation."""
def test_create_position_with_defaults(self) -> None:
"""Test creating a position with default values."""
pos = create_position()
assert pos.underlying == "GLD"
assert pos.quantity == Decimal("100")
assert pos.unit == "oz"
assert pos.entry_price == Decimal("2150")
assert pos.entry_date == date.today()
assert pos.entry_basis_mode == "weight"
assert pos.notes == ""
def test_create_position_with_custom_values(self) -> None:
"""Test creating a position with custom values."""
custom_date = date(2025, 6, 15)
pos = create_position(
underlying="XAU",
quantity=Decimal("50"),
unit="oz",
entry_price=Decimal("2000"),
entry_date=custom_date,
notes="Test position",
)
assert pos.underlying == "XAU"
assert pos.quantity == Decimal("50")
assert pos.entry_price == Decimal("2000")
assert pos.entry_date == custom_date
assert pos.notes == "Test position"
def test_position_entry_value(self) -> None:
"""Test position entry value calculation."""
pos = create_position(
quantity=Decimal("100"),
entry_price=Decimal("2150"),
)
assert pos.entry_value == Decimal("215000")
def test_position_serialization(self) -> None:
"""Test position to_dict serialization."""
pos = create_position(
underlying="GLD",
quantity=Decimal("100"),
entry_price=Decimal("2150"),
)
data = pos.to_dict()
assert data["underlying"] == "GLD"
assert data["quantity"] == "100"
assert data["entry_price"] == "2150"
assert data["unit"] == "oz"
assert "id" in data
assert "created_at" in data
def test_position_deserialization(self) -> None:
"""Test position from_dict deserialization."""
pos = create_position(
underlying="GLD",
quantity=Decimal("100"),
entry_price=Decimal("2150"),
)
data = pos.to_dict()
restored = Position.from_dict(data)
assert restored.underlying == pos.underlying
assert restored.quantity == pos.quantity
assert restored.entry_price == pos.entry_price
assert restored.id == pos.id
def test_position_validates_positive_quantity(self) -> None:
"""Test that position rejects non-positive quantity."""
with pytest.raises(ValueError, match="quantity must be positive"):
create_position(quantity=Decimal("0"))
with pytest.raises(ValueError, match="quantity must be positive"):
create_position(quantity=Decimal("-10"))
def test_position_validates_positive_entry_price(self) -> None:
"""Test that position rejects non-positive entry price."""
with pytest.raises(ValueError, match="entry_price must be positive"):
create_position(entry_price=Decimal("0"))
with pytest.raises(ValueError, match="entry_price must be positive"):
create_position(entry_price=Decimal("-100"))
def test_position_validates_entry_basis_mode(self) -> None:
"""Test that position validates entry_basis_mode."""
with pytest.raises(ValueError, match="entry_basis_mode must be"):
Position(
id=uuid4(),
underlying="GLD",
quantity=Decimal("100"),
unit="oz",
entry_price=Decimal("2150"),
entry_date=date.today(),
entry_basis_mode="invalid",
)
class TestPortfolioConfigWithPositions:
"""Test PortfolioConfig integration with positions."""
def test_portfolio_config_with_empty_positions(self) -> None:
"""Test PortfolioConfig with no positions (legacy mode)."""
config = PortfolioConfig(
gold_ounces=100.0,
entry_price=2150.0,
gold_value=215000.0,
)
# Need to call migration explicitly
config.migrate_to_positions_if_needed()
# Should have migrated to one position
assert len(config.positions) == 1
assert config.positions[0].quantity == Decimal("100")
assert config.positions[0].entry_price == Decimal("2150")
def test_portfolio_config_with_multiple_positions(self) -> None:
"""Test PortfolioConfig with multiple positions."""
pos1 = create_position(
underlying="GLD",
quantity=Decimal("50"),
entry_price=Decimal("2100"),
)
pos2 = create_position(
underlying="GLD",
quantity=Decimal("50"),
entry_price=Decimal("2200"),
)
config = PortfolioConfig(
gold_ounces=100.0,
entry_price=2150.0,
gold_value=215000.0,
positions=[pos1, pos2],
)
assert len(config.positions) == 2
assert config.positions[0].id == pos1.id
assert config.positions[1].id == pos2.id
def test_portfolio_config_serializes_positions(self) -> None:
"""Test that PortfolioConfig.to_dict includes positions."""
pos = create_position(
underlying="GLD",
quantity=Decimal("100"),
entry_price=Decimal("2150"),
)
config = PortfolioConfig(
gold_ounces=100.0,
entry_price=2150.0,
gold_value=215000.0,
positions=[pos],
)
data = config.to_dict()
assert "positions" in data
assert len(data["positions"]) == 1
assert data["positions"][0]["underlying"] == "GLD"
assert data["positions"][0]["quantity"] == "100"
def test_portfolio_config_deserializes_positions(self) -> None:
"""Test that PortfolioConfig.from_dict restores positions."""
pos = create_position(
underlying="GLD",
quantity=Decimal("100"),
entry_price=Decimal("2150"),
)
data = {
"gold_value": 215000.0,
"entry_price": 2150.0,
"gold_ounces": 100.0,
"positions": [pos.to_dict()],
}
config = PortfolioConfig.from_dict(data)
assert len(config.positions) == 1
assert config.positions[0].underlying == "GLD"
assert config.positions[0].quantity == Decimal("100")
def test_portfolio_config_syncs_legacy_fields_from_positions(self) -> None:
"""Test that legacy fields are computed from positions."""
pos1 = create_position(
underlying="GLD",
quantity=Decimal("50"),
entry_price=Decimal("2100"),
)
pos2 = create_position(
underlying="GLD",
quantity=Decimal("50"),
entry_price=Decimal("2200"),
)
config = PortfolioConfig(
gold_ounces=100.0,
entry_price=2150.0,
gold_value=215000.0,
positions=[pos1, pos2],
)
# Trigger sync
config._sync_legacy_fields_from_positions()
# Total should be 100 oz
assert config.gold_ounces == 100.0
# Weighted average entry price: (50*2100 + 50*2200) / 100 = 2150
assert config.entry_price == 2150.0
# Total value: 50*2100 + 50*2200 = 215000
assert config.gold_value == 215000.0
class TestPortfolioRepositoryWithPositions:
"""Test PortfolioRepository persistence with positions."""
def test_repository_saves_positions(self, tmp_path) -> None:
"""Test that repository persists positions to disk."""
repo = PortfolioRepository(config_path=tmp_path / "portfolio_config.json")
pos = create_position(
underlying="GLD",
quantity=Decimal("100"),
entry_price=Decimal("2150"),
)
config = PortfolioConfig(
gold_ounces=100.0,
entry_price=2150.0,
gold_value=215000.0,
positions=[pos],
)
repo.save(config)
# Read raw JSON to verify structure
payload = json.loads((tmp_path / "portfolio_config.json").read_text())
assert "positions" in payload["portfolio"]
assert len(payload["portfolio"]["positions"]) == 1
assert payload["portfolio"]["positions"][0]["underlying"] == "GLD"
def test_repository_loads_positions(self, tmp_path) -> None:
"""Test that repository loads positions from disk."""
config_path = tmp_path / "portfolio_config.json"
pos = create_position(
underlying="GLD",
quantity=Decimal("100"),
entry_price=Decimal("2150"),
)
# Write raw JSON with positions
config_path.write_text(
json.dumps(
{
"schema_version": 2,
"portfolio": {
"gold_value": {"value": "215000.0", "currency": "USD"},
"entry_price": {
"value": "2150.0",
"currency": "USD",
"per_weight_unit": "ozt",
},
"gold_ounces": {"value": "100.0", "unit": "ozt"},
"entry_basis_mode": "weight",
"loan_amount": {"value": "145000.0", "currency": "USD"},
"margin_threshold": {"value": "0.75", "unit": "ratio"},
"monthly_budget": {"value": "8000.0", "currency": "USD"},
"ltv_warning": {"value": "0.70", "unit": "ratio"},
"primary_source": "yfinance",
"fallback_source": "yfinance",
"refresh_interval": {"value": 5, "unit": "seconds"},
"underlying": "GLD",
"volatility_spike": {"value": "0.25", "unit": "ratio"},
"spot_drawdown": {"value": "7.5", "unit": "percent"},
"email_alerts": False,
"positions": [pos.to_dict()],
},
}
)
)
config = PortfolioRepository(config_path=config_path).load()
assert len(config.positions) == 1
assert config.positions[0].underlying == "GLD"
assert config.positions[0].quantity == Decimal("100")
def test_repository_round_trips_positions(self, tmp_path) -> None:
"""Test that positions survive save/load round-trip."""
repo = PortfolioRepository(config_path=tmp_path / "portfolio_config.json")
pos = create_position(
underlying="XAU",
quantity=Decimal("50"),
entry_price=Decimal("2000"),
notes="Physical gold",
)
config = PortfolioConfig(
gold_ounces=50.0,
entry_price=2000.0,
gold_value=100000.0,
loan_amount=50000.0, # Must be < gold_value
positions=[pos],
)
repo.save(config)
loaded = repo.load()
assert len(loaded.positions) == 1
assert loaded.positions[0].underlying == "XAU"
assert loaded.positions[0].quantity == Decimal("50")
assert loaded.positions[0].notes == "Physical gold"
class TestWorkspaceRepositoryPositionCRUD:
"""Test WorkspaceRepository position CRUD operations."""
def test_add_position(self, tmp_path) -> None:
"""Test adding a position to workspace."""
repo = WorkspaceRepository(base_path=tmp_path / "workspaces")
workspace_id = repo.create_workspace_id()
# Workspace starts with one auto-migrated position from default config
initial_positions = repo.list_positions(workspace_id)
assert len(initial_positions) == 1
pos = create_position(
underlying="GLD",
quantity=Decimal("50"),
entry_price=Decimal("2150"),
)
repo.add_position(workspace_id, pos)
positions = repo.list_positions(workspace_id)
assert len(positions) == 2
assert positions[1].id == pos.id
assert positions[1].underlying == "GLD"
def test_remove_position(self, tmp_path) -> None:
"""Test removing a position from workspace."""
repo = WorkspaceRepository(base_path=tmp_path / "workspaces")
workspace_id = repo.create_workspace_id()
# Workspace starts with one auto-migrated position
initial_count = len(repo.list_positions(workspace_id))
assert initial_count == 1
pos1 = create_position(underlying="GLD", quantity=Decimal("50"))
pos2 = create_position(underlying="XAU", quantity=Decimal("50"))
repo.add_position(workspace_id, pos1)
repo.add_position(workspace_id, pos2)
positions = repo.list_positions(workspace_id)
assert len(positions) == initial_count + 2
repo.remove_position(workspace_id, pos1.id)
positions = repo.list_positions(workspace_id)
assert len(positions) == initial_count + 1
# pos2 should still be there
assert any(p.id == pos2.id for p in positions)
# pos1 should be gone
assert not any(p.id == pos1.id for p in positions)
def test_update_position(self, tmp_path) -> None:
"""Test updating a position."""
repo = WorkspaceRepository(base_path=tmp_path / "workspaces")
workspace_id = repo.create_workspace_id()
# Add a new position (not the auto-migrated one)
pos = create_position(
underlying="GLD",
quantity=Decimal("100"),
entry_price=Decimal("2150"),
notes="Original note",
)
repo.add_position(workspace_id, pos)
repo.update_position(
workspace_id,
pos.id,
{"notes": "Updated note", "quantity": Decimal("150")},
)
updated = repo.get_position(workspace_id, pos.id)
assert updated is not None
assert updated.notes == "Updated note"
assert updated.quantity == Decimal("150")
def test_get_position(self, tmp_path) -> None:
"""Test getting a specific position by ID."""
repo = WorkspaceRepository(base_path=tmp_path / "workspaces")
workspace_id = repo.create_workspace_id()
pos = create_position(underlying="GLD")
repo.add_position(workspace_id, pos)
retrieved = repo.get_position(workspace_id, pos.id)
assert retrieved is not None
assert retrieved.id == pos.id
# Non-existent position returns None
not_found = repo.get_position(workspace_id, uuid4())
assert not_found is None
def test_list_positions(self, tmp_path) -> None:
"""Test listing all positions."""
repo = WorkspaceRepository(base_path=tmp_path / "workspaces")
workspace_id = repo.create_workspace_id()
# Workspace starts with one auto-migrated position
initial_positions = repo.list_positions(workspace_id)
initial_count = len(initial_positions)
assert initial_count == 1
pos1 = create_position(underlying="GLD", quantity=Decimal("50"))
pos2 = create_position(underlying="XAU", quantity=Decimal("50"))
repo.add_position(workspace_id, pos1)
repo.add_position(workspace_id, pos2)
positions = repo.list_positions(workspace_id)
assert len(positions) == initial_count + 2
# Should contain the initial position plus the two new ones
assert any(p.id == pos1.id for p in positions)
assert any(p.id == pos2.id for p in positions)
class TestLegacyMigration:
"""Test backward migration from legacy single-entry to multi-position."""
def test_legacy_config_migrates_to_single_position(self) -> None:
"""Test that legacy config without positions creates one position."""
config = PortfolioConfig(
gold_ounces=100.0,
entry_price=2150.0,
gold_value=215000.0,
underlying="GLD",
)
# Need to call migration explicitly
config.migrate_to_positions_if_needed()
# Should have migrated to one position
assert len(config.positions) == 1
pos = config.positions[0]
assert pos.underlying == "GLD"
assert pos.quantity == Decimal("100")
assert pos.entry_price == Decimal("2150")
assert pos.unit == "oz"
def test_repository_loads_legacy_and_migrates(self, tmp_path) -> None:
"""Test that loading legacy config migrates to positions."""
config_path = tmp_path / "portfolio_config.json"
# Write legacy config without positions
config_path.write_text(
json.dumps(
{
"schema_version": 2,
"portfolio": {
"gold_value": {"value": "215000.0", "currency": "USD"},
"entry_price": {
"value": "2150.0",
"currency": "USD",
"per_weight_unit": "ozt",
},
"gold_ounces": {"value": "100.0", "unit": "ozt"},
"entry_basis_mode": "weight",
"loan_amount": {"value": "145000.0", "currency": "USD"},
"margin_threshold": {"value": "0.75", "unit": "ratio"},
"monthly_budget": {"value": "8000.0", "currency": "USD"},
"ltv_warning": {"value": "0.70", "unit": "ratio"},
"primary_source": "yfinance",
"fallback_source": "yfinance",
"refresh_interval": {"value": 5, "unit": "seconds"},
"underlying": "GLD",
"volatility_spike": {"value": "0.25", "unit": "ratio"},
"spot_drawdown": {"value": "7.5", "unit": "percent"},
"email_alerts": False,
# No positions field - legacy format
},
}
)
)
config = PortfolioRepository(config_path=config_path).load()
# Should have migrated to one position
assert len(config.positions) == 1
assert config.positions[0].quantity == Decimal("100")
assert config.positions[0].entry_price == Decimal("2150")
def test_multiple_positions_aggregate_correctly(self) -> None:
"""Test that multiple positions aggregate to correct totals."""
pos1 = create_position(
underlying="GLD",
quantity=Decimal("60"),
entry_price=Decimal("2000"),
)
pos2 = create_position(
underlying="GLD",
quantity=Decimal("40"),
entry_price=Decimal("2300"),
)
config = PortfolioConfig(
gold_ounces=100.0,
entry_price=2150.0,
gold_value=215000.0,
positions=[pos1, pos2],
)
config._sync_legacy_fields_from_positions()
# Total quantity: 60 + 40 = 100
assert config.gold_ounces == 100.0
# Weighted avg price: (60*2000 + 40*2300) / 100 = 2120
assert config.entry_price == 2120.0
# Total value: 60*2000 + 40*2300 = 212000
assert config.gold_value == 212000.0