"""Tests for position model and multi-position portfolio support.""" from __future__ import annotations import json from datetime import date from decimal import Decimal from uuid import uuid4 import pytest from app.models.portfolio import PortfolioConfig, PortfolioRepository from app.models.position import Position, create_position from app.models.workspace import WorkspaceRepository class TestPositionModel: """Test Position model creation and validation.""" def test_create_position_with_defaults(self) -> None: """Test creating a position with default values.""" pos = create_position() assert pos.underlying == "GLD" assert pos.quantity == Decimal("100") assert pos.unit == "oz" assert pos.entry_price == Decimal("2150") assert pos.entry_date == date.today() assert pos.entry_basis_mode == "weight" assert pos.notes == "" def test_create_position_with_custom_values(self) -> None: """Test creating a position with custom values.""" custom_date = date(2025, 6, 15) pos = create_position( underlying="XAU", quantity=Decimal("50"), unit="oz", entry_price=Decimal("2000"), entry_date=custom_date, notes="Test position", ) assert pos.underlying == "XAU" assert pos.quantity == Decimal("50") assert pos.entry_price == Decimal("2000") assert pos.entry_date == custom_date assert pos.notes == "Test position" def test_position_entry_value(self) -> None: """Test position entry value calculation.""" pos = create_position( quantity=Decimal("100"), entry_price=Decimal("2150"), ) assert pos.entry_value == Decimal("215000") def test_position_serialization(self) -> None: """Test position to_dict serialization.""" pos = create_position( underlying="GLD", quantity=Decimal("100"), entry_price=Decimal("2150"), ) data = pos.to_dict() assert data["underlying"] == "GLD" assert data["quantity"] == "100" assert data["entry_price"] == "2150" assert data["unit"] == "oz" assert "id" in data assert "created_at" in data def test_position_deserialization(self) -> None: """Test position from_dict deserialization.""" pos = create_position( underlying="GLD", quantity=Decimal("100"), entry_price=Decimal("2150"), ) data = pos.to_dict() restored = Position.from_dict(data) assert restored.underlying == pos.underlying assert restored.quantity == pos.quantity assert restored.entry_price == pos.entry_price assert restored.id == pos.id def test_position_validates_positive_quantity(self) -> None: """Test that position rejects non-positive quantity.""" with pytest.raises(ValueError, match="quantity must be positive"): create_position(quantity=Decimal("0")) with pytest.raises(ValueError, match="quantity must be positive"): create_position(quantity=Decimal("-10")) def test_position_validates_positive_entry_price(self) -> None: """Test that position rejects non-positive entry price.""" with pytest.raises(ValueError, match="entry_price must be positive"): create_position(entry_price=Decimal("0")) with pytest.raises(ValueError, match="entry_price must be positive"): create_position(entry_price=Decimal("-100")) def test_position_validates_entry_basis_mode(self) -> None: """Test that position validates entry_basis_mode.""" with pytest.raises(ValueError, match="entry_basis_mode must be"): Position( id=uuid4(), underlying="GLD", quantity=Decimal("100"), unit="oz", entry_price=Decimal("2150"), entry_date=date.today(), entry_basis_mode="invalid", ) class TestPortfolioConfigWithPositions: """Test PortfolioConfig integration with positions.""" def test_portfolio_config_with_empty_positions(self) -> None: """Test PortfolioConfig with no positions (legacy mode).""" config = PortfolioConfig( gold_ounces=100.0, entry_price=2150.0, gold_value=215000.0, ) # Need to call migration explicitly config.migrate_to_positions_if_needed() # Should have migrated to one position assert len(config.positions) == 1 assert config.positions[0].quantity == Decimal("100") assert config.positions[0].entry_price == Decimal("2150") def test_portfolio_config_with_multiple_positions(self) -> None: """Test PortfolioConfig with multiple positions.""" pos1 = create_position( underlying="GLD", quantity=Decimal("50"), entry_price=Decimal("2100"), ) pos2 = create_position( underlying="GLD", quantity=Decimal("50"), entry_price=Decimal("2200"), ) config = PortfolioConfig( gold_ounces=100.0, entry_price=2150.0, gold_value=215000.0, positions=[pos1, pos2], ) assert len(config.positions) == 2 assert config.positions[0].id == pos1.id assert config.positions[1].id == pos2.id def test_portfolio_config_serializes_positions(self) -> None: """Test that PortfolioConfig.to_dict includes positions.""" pos = create_position( underlying="GLD", quantity=Decimal("100"), entry_price=Decimal("2150"), ) config = PortfolioConfig( gold_ounces=100.0, entry_price=2150.0, gold_value=215000.0, positions=[pos], ) data = config.to_dict() assert "positions" in data assert len(data["positions"]) == 1 assert data["positions"][0]["underlying"] == "GLD" assert data["positions"][0]["quantity"] == "100" def test_portfolio_config_deserializes_positions(self) -> None: """Test that PortfolioConfig.from_dict restores positions.""" pos = create_position( underlying="GLD", quantity=Decimal("100"), entry_price=Decimal("2150"), ) data = { "gold_value": 215000.0, "entry_price": 2150.0, "gold_ounces": 100.0, "positions": [pos.to_dict()], } config = PortfolioConfig.from_dict(data) assert len(config.positions) == 1 assert config.positions[0].underlying == "GLD" assert config.positions[0].quantity == Decimal("100") def test_portfolio_config_syncs_legacy_fields_from_positions(self) -> None: """Test that legacy fields are computed from positions.""" pos1 = create_position( underlying="GLD", quantity=Decimal("50"), entry_price=Decimal("2100"), ) pos2 = create_position( underlying="GLD", quantity=Decimal("50"), entry_price=Decimal("2200"), ) config = PortfolioConfig( gold_ounces=100.0, entry_price=2150.0, gold_value=215000.0, positions=[pos1, pos2], ) # Trigger sync config._sync_legacy_fields_from_positions() # Total should be 100 oz assert config.gold_ounces == 100.0 # Weighted average entry price: (50*2100 + 50*2200) / 100 = 2150 assert config.entry_price == 2150.0 # Total value: 50*2100 + 50*2200 = 215000 assert config.gold_value == 215000.0 class TestPortfolioRepositoryWithPositions: """Test PortfolioRepository persistence with positions.""" def test_repository_saves_positions(self, tmp_path) -> None: """Test that repository persists positions to disk.""" repo = PortfolioRepository(config_path=tmp_path / "portfolio_config.json") pos = create_position( underlying="GLD", quantity=Decimal("100"), entry_price=Decimal("2150"), ) config = PortfolioConfig( gold_ounces=100.0, entry_price=2150.0, gold_value=215000.0, positions=[pos], ) repo.save(config) # Read raw JSON to verify structure payload = json.loads((tmp_path / "portfolio_config.json").read_text()) assert "positions" in payload["portfolio"] assert len(payload["portfolio"]["positions"]) == 1 assert payload["portfolio"]["positions"][0]["underlying"] == "GLD" def test_repository_loads_positions(self, tmp_path) -> None: """Test that repository loads positions from disk.""" config_path = tmp_path / "portfolio_config.json" pos = create_position( underlying="GLD", quantity=Decimal("100"), entry_price=Decimal("2150"), ) # Write raw JSON with positions config_path.write_text( json.dumps( { "schema_version": 2, "portfolio": { "gold_value": {"value": "215000.0", "currency": "USD"}, "entry_price": { "value": "2150.0", "currency": "USD", "per_weight_unit": "ozt", }, "gold_ounces": {"value": "100.0", "unit": "ozt"}, "entry_basis_mode": "weight", "loan_amount": {"value": "145000.0", "currency": "USD"}, "margin_threshold": {"value": "0.75", "unit": "ratio"}, "monthly_budget": {"value": "8000.0", "currency": "USD"}, "ltv_warning": {"value": "0.70", "unit": "ratio"}, "primary_source": "yfinance", "fallback_source": "yfinance", "refresh_interval": {"value": 5, "unit": "seconds"}, "underlying": "GLD", "volatility_spike": {"value": "0.25", "unit": "ratio"}, "spot_drawdown": {"value": "7.5", "unit": "percent"}, "email_alerts": False, "positions": [pos.to_dict()], }, } ) ) config = PortfolioRepository(config_path=config_path).load() assert len(config.positions) == 1 assert config.positions[0].underlying == "GLD" assert config.positions[0].quantity == Decimal("100") def test_repository_round_trips_positions(self, tmp_path) -> None: """Test that positions survive save/load round-trip.""" repo = PortfolioRepository(config_path=tmp_path / "portfolio_config.json") pos = create_position( underlying="XAU", quantity=Decimal("50"), entry_price=Decimal("2000"), notes="Physical gold", ) config = PortfolioConfig( gold_ounces=50.0, entry_price=2000.0, gold_value=100000.0, loan_amount=50000.0, # Must be < gold_value positions=[pos], ) repo.save(config) loaded = repo.load() assert len(loaded.positions) == 1 assert loaded.positions[0].underlying == "XAU" assert loaded.positions[0].quantity == Decimal("50") assert loaded.positions[0].notes == "Physical gold" class TestWorkspaceRepositoryPositionCRUD: """Test WorkspaceRepository position CRUD operations.""" def test_add_position(self, tmp_path) -> None: """Test adding a position to workspace.""" repo = WorkspaceRepository(base_path=tmp_path / "workspaces") workspace_id = repo.create_workspace_id() # Workspace starts with one auto-migrated position from default config initial_positions = repo.list_positions(workspace_id) assert len(initial_positions) == 1 pos = create_position( underlying="GLD", quantity=Decimal("50"), entry_price=Decimal("2150"), ) repo.add_position(workspace_id, pos) positions = repo.list_positions(workspace_id) assert len(positions) == 2 assert positions[1].id == pos.id assert positions[1].underlying == "GLD" def test_remove_position(self, tmp_path) -> None: """Test removing a position from workspace.""" repo = WorkspaceRepository(base_path=tmp_path / "workspaces") workspace_id = repo.create_workspace_id() # Workspace starts with one auto-migrated position initial_count = len(repo.list_positions(workspace_id)) assert initial_count == 1 pos1 = create_position(underlying="GLD", quantity=Decimal("50")) pos2 = create_position(underlying="XAU", quantity=Decimal("50")) repo.add_position(workspace_id, pos1) repo.add_position(workspace_id, pos2) positions = repo.list_positions(workspace_id) assert len(positions) == initial_count + 2 repo.remove_position(workspace_id, pos1.id) positions = repo.list_positions(workspace_id) assert len(positions) == initial_count + 1 # pos2 should still be there assert any(p.id == pos2.id for p in positions) # pos1 should be gone assert not any(p.id == pos1.id for p in positions) def test_update_position(self, tmp_path) -> None: """Test updating a position.""" repo = WorkspaceRepository(base_path=tmp_path / "workspaces") workspace_id = repo.create_workspace_id() # Add a new position (not the auto-migrated one) pos = create_position( underlying="GLD", quantity=Decimal("100"), entry_price=Decimal("2150"), notes="Original note", ) repo.add_position(workspace_id, pos) repo.update_position( workspace_id, pos.id, {"notes": "Updated note", "quantity": Decimal("150")}, ) updated = repo.get_position(workspace_id, pos.id) assert updated is not None assert updated.notes == "Updated note" assert updated.quantity == Decimal("150") def test_get_position(self, tmp_path) -> None: """Test getting a specific position by ID.""" repo = WorkspaceRepository(base_path=tmp_path / "workspaces") workspace_id = repo.create_workspace_id() pos = create_position(underlying="GLD") repo.add_position(workspace_id, pos) retrieved = repo.get_position(workspace_id, pos.id) assert retrieved is not None assert retrieved.id == pos.id # Non-existent position returns None not_found = repo.get_position(workspace_id, uuid4()) assert not_found is None def test_list_positions(self, tmp_path) -> None: """Test listing all positions.""" repo = WorkspaceRepository(base_path=tmp_path / "workspaces") workspace_id = repo.create_workspace_id() # Workspace starts with one auto-migrated position initial_positions = repo.list_positions(workspace_id) initial_count = len(initial_positions) assert initial_count == 1 pos1 = create_position(underlying="GLD", quantity=Decimal("50")) pos2 = create_position(underlying="XAU", quantity=Decimal("50")) repo.add_position(workspace_id, pos1) repo.add_position(workspace_id, pos2) positions = repo.list_positions(workspace_id) assert len(positions) == initial_count + 2 # Should contain the initial position plus the two new ones assert any(p.id == pos1.id for p in positions) assert any(p.id == pos2.id for p in positions) class TestLegacyMigration: """Test backward migration from legacy single-entry to multi-position.""" def test_legacy_config_migrates_to_single_position(self) -> None: """Test that legacy config without positions creates one position.""" config = PortfolioConfig( gold_ounces=100.0, entry_price=2150.0, gold_value=215000.0, underlying="GLD", ) # Need to call migration explicitly config.migrate_to_positions_if_needed() # Should have migrated to one position assert len(config.positions) == 1 pos = config.positions[0] assert pos.underlying == "GLD" assert pos.quantity == Decimal("100") assert pos.entry_price == Decimal("2150") assert pos.unit == "oz" def test_repository_loads_legacy_and_migrates(self, tmp_path) -> None: """Test that loading legacy config migrates to positions.""" config_path = tmp_path / "portfolio_config.json" # Write legacy config without positions config_path.write_text( json.dumps( { "schema_version": 2, "portfolio": { "gold_value": {"value": "215000.0", "currency": "USD"}, "entry_price": { "value": "2150.0", "currency": "USD", "per_weight_unit": "ozt", }, "gold_ounces": {"value": "100.0", "unit": "ozt"}, "entry_basis_mode": "weight", "loan_amount": {"value": "145000.0", "currency": "USD"}, "margin_threshold": {"value": "0.75", "unit": "ratio"}, "monthly_budget": {"value": "8000.0", "currency": "USD"}, "ltv_warning": {"value": "0.70", "unit": "ratio"}, "primary_source": "yfinance", "fallback_source": "yfinance", "refresh_interval": {"value": 5, "unit": "seconds"}, "underlying": "GLD", "volatility_spike": {"value": "0.25", "unit": "ratio"}, "spot_drawdown": {"value": "7.5", "unit": "percent"}, "email_alerts": False, # No positions field - legacy format }, } ) ) config = PortfolioRepository(config_path=config_path).load() # Should have migrated to one position assert len(config.positions) == 1 assert config.positions[0].quantity == Decimal("100") assert config.positions[0].entry_price == Decimal("2150") def test_multiple_positions_aggregate_correctly(self) -> None: """Test that multiple positions aggregate to correct totals.""" pos1 = create_position( underlying="GLD", quantity=Decimal("60"), entry_price=Decimal("2000"), ) pos2 = create_position( underlying="GLD", quantity=Decimal("40"), entry_price=Decimal("2300"), ) config = PortfolioConfig( gold_ounces=100.0, entry_price=2150.0, gold_value=215000.0, positions=[pos1, pos2], ) config._sync_legacy_fields_from_positions() # Total quantity: 60 + 40 = 100 assert config.gold_ounces == 100.0 # Weighted avg price: (60*2000 + 40*2300) / 100 = 2120 assert config.entry_price == 2120.0 # Total value: 60*2000 + 40*2300 = 212000 assert config.gold_value == 212000.0