228 lines
7.6 KiB
Python
228 lines
7.6 KiB
Python
"""FastAPI application entry point with NiceGUI integration."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import logging
|
|
import os
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from fastapi import FastAPI, Form, Request, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.responses import RedirectResponse, Response
|
|
from nicegui import ui # type: ignore[attr-defined]
|
|
|
|
import app.pages # noqa: F401
|
|
from app.api.routes import router as api_router
|
|
from app.domain.portfolio_math import resolve_collateral_spot_from_quote
|
|
from app.models.portfolio import build_default_portfolio_config
|
|
from app.models.workspace import WORKSPACE_COOKIE, get_workspace_repository
|
|
from app.services import turnstile as turnstile_service
|
|
from app.services.cache import CacheService
|
|
from app.services.data_service import DataService
|
|
from app.services.runtime import get_data_service, set_data_service
|
|
|
|
logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class Settings:
    """Application configuration; use :meth:`load` to populate it from the environment."""

    app_name: str = "Vault Dashboard"
    environment: str = "development"
    # None until loaded; load() always produces a list parsed from CORS_ORIGINS (default "*").
    cors_origins: list[str] | None = None
    redis_url: str | None = None
    cache_ttl: int = 300
    default_symbol: str = "GLD"
    websocket_interval_seconds: int = 5
    nicegui_mount_path: str = "/"
    # Development fallback only; set NICEGUI_STORAGE_SECRET for any real deployment.
    nicegui_storage_secret: str = "vault-dash-dev-secret"
    turnstile_site_key: str = ""
    turnstile_secret_key: str = ""

    @classmethod
    def load(cls) -> Settings:
        """Build settings from the process environment.

        A ``.env`` file is loaded first when the optional ``dotenv`` package is
        importable. Unset variables fall back to the class defaults; Turnstile
        keys come from ``turnstile_service.load_turnstile_settings()``.

        Returns:
            A fully populated :class:`Settings` instance.

        Raises:
            ValueError: if ``CACHE_TTL`` or ``WEBSOCKET_INTERVAL_SECONDS`` is set to
                a non-integer value.
        """
        cls._load_dotenv()
        origins = os.getenv("CORS_ORIGINS", "*")
        turnstile = turnstile_service.load_turnstile_settings()
        loaded = cls(
            app_name=os.getenv("APP_NAME", cls.app_name),
            # APP_ENV takes precedence over the legacy ENVIRONMENT variable.
            environment=os.getenv("APP_ENV", os.getenv("ENVIRONMENT", cls.environment)),
            cors_origins=[origin.strip() for origin in origins.split(",") if origin.strip()],
            redis_url=os.getenv("REDIS_URL"),
            cache_ttl=cls._env_int("CACHE_TTL", cls.cache_ttl),
            default_symbol=os.getenv("DEFAULT_SYMBOL", cls.default_symbol),
            websocket_interval_seconds=cls._env_int(
                "WEBSOCKET_INTERVAL_SECONDS", cls.websocket_interval_seconds
            ),
            nicegui_mount_path=os.getenv("NICEGUI_MOUNT_PATH", cls.nicegui_mount_path),
            nicegui_storage_secret=os.getenv("NICEGUI_STORAGE_SECRET", cls.nicegui_storage_secret),
            turnstile_site_key=turnstile.site_key,
            turnstile_secret_key=turnstile.secret_key,
        )
        if loaded.environment != "development" and loaded.nicegui_storage_secret == cls.nicegui_storage_secret:
            # The built-in secret is public (it is in source control); make its use visible.
            logging.getLogger(__name__).warning(
                "NICEGUI_STORAGE_SECRET is not set; using the built-in development secret in %r",
                loaded.environment,
            )
        return loaded

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read integer env var *name*, returning *default* when it is unset or blank.

        Raises:
            ValueError: if the variable is set to something that is not an integer.
        """
        raw = os.getenv(name, "").strip()
        if not raw:
            # An empty assignment such as ``CACHE_TTL=`` in a .env file means "use the default".
            return default
        try:
            return int(raw)
        except ValueError as exc:
            raise ValueError(f"Environment variable {name} must be an integer, got {raw!r}") from exc

    @staticmethod
    def _load_dotenv() -> None:
        """Load a ``.env`` file into the environment if ``dotenv`` is installed; otherwise do nothing."""
        try:
            from dotenv import load_dotenv
        except ImportError:
            return
        load_dotenv()
|
|
|
|
|
|
# Process-wide configuration, resolved once at import time and shared by the app,
# its lifespan handler and the NiceGUI mount.
settings: Settings = Settings.load()
|
|
|
|
|
|
class ConnectionManager:
    """Registry of open WebSocket connections with JSON fan-out."""

    def __init__(self) -> None:
        self._connections: set[WebSocket] = set()

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the handshake for *websocket* and register it for broadcasts."""
        await websocket.accept()
        self._connections.add(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Unregister *websocket*; a no-op if it is not registered."""
        self._connections.discard(websocket)

    async def broadcast_json(self, payload: dict[str, Any]) -> None:
        """Send *payload* to every registered connection, dropping those whose send fails.

        Iterates over a snapshot of the registry. Other coroutines (``connect`` /
        ``disconnect``) can run while this method is suspended in ``send_json``, and
        iterating the live set would then raise
        ``RuntimeError: Set changed size during iteration``.
        """
        for websocket in list(self._connections):
            if websocket not in self._connections:
                # Unregistered while an earlier send was in flight; don't write to it.
                continue
            try:
                await websocket.send_json(payload)
            except Exception:
                # Any send failure is treated as a dead client.
                self.disconnect(websocket)

    @property
    def count(self) -> int:
        """Number of currently registered connections."""
        return len(self._connections)
|
|
|
|
|
|
async def publish_updates(app: FastAPI) -> None:
    """Broadcast a portfolio snapshot to all WebSocket clients on a fixed interval.

    Runs until cancelled. Each cycle fetches the portfolio for
    ``app.state.settings.default_symbol`` from ``app.state.data_service``,
    broadcasts it through ``app.state.ws_manager``, then sleeps for
    ``app.state.settings.websocket_interval_seconds``.

    A failing cycle (e.g. the data service raising) is logged and the loop
    continues. Without this, a single transient error would end the task for the
    rest of the process.

    Raises:
        asyncio.CancelledError: re-raised after logging when the task is cancelled.
    """
    state = app.state
    try:
        while True:
            try:
                payload = {
                    "type": "portfolio_update",
                    "connections": state.ws_manager.count,
                    "portfolio": await state.data_service.get_portfolio(state.settings.default_symbol),
                }
                await state.ws_manager.broadcast_json(payload)
            except Exception:
                # CancelledError is a BaseException, so cancellation is not caught here.
                logger.exception("WebSocket publish cycle failed; retrying after the next interval")
            await asyncio.sleep(state.settings.websocket_interval_seconds)
    except asyncio.CancelledError:
        logger.info("WebSocket publisher stopped")
        raise
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: wire shared services onto ``app.state`` and tear them down on exit.

    Startup connects the cache, creates the data service and registers it via
    ``set_data_service``, creates the WebSocket connection manager, and starts the
    background ``publish_updates`` task.

    Shutdown cancels the publisher and always closes the cache. This holds even if
    the publisher had already died with an error, and even if startup failed after
    the cache connected.
    """
    app.state.settings = settings
    cache = CacheService(settings.redis_url, default_ttl=settings.cache_ttl)
    app.state.cache = cache
    await cache.connect()

    publisher: asyncio.Task[None] | None = None
    try:
        app.state.data_service = DataService(cache, default_underlying=settings.default_symbol)
        set_data_service(app.state.data_service)
        app.state.ws_manager = ConnectionManager()
        publisher = asyncio.create_task(publish_updates(app))
        app.state.publisher_task = publisher
        logger.info("Application startup complete")
        yield
    finally:
        if publisher is not None:
            publisher.cancel()
            try:
                await publisher
            except asyncio.CancelledError:
                pass
            except Exception:
                # Awaiting a task that already failed re-raises its exception; don't let
                # that abort shutdown before the cache is closed.
                logger.exception("WebSocket publisher exited with an error")
        await cache.close()
        logger.info("Application shutdown complete")
|
|
|
|
|
|
app = FastAPI(title=settings.app_name, lifespan=lifespan)
app.include_router(api_router)

# An empty origin list falls back to allowing every origin.
_cors_origins = settings.cors_origins or ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # The CORS spec forbids combining a wildcard origin with credentials. Starlette's
    # middleware instead reflects the caller's Origin in that case, which would let
    # any site make credentialed requests. So credentials are only allowed when
    # explicit origins are configured.
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.get("/health", tags=["health"])
async def health(request: Request) -> dict[str, Any]:
    """Liveness probe.

    Always reports ``"ok"``, together with the configured environment name and the
    cache's ``enabled`` flag, exposed as ``redis_enabled``.
    """
    state = request.app.state
    environment = state.settings.environment
    redis_enabled = state.cache.enabled
    return {"status": "ok", "environment": environment, "redis_enabled": redis_enabled}
|
|
|
|
|
|
@app.get("/workspaces/bootstrap", tags=["workspace"])
async def bootstrap_workspace_redirect() -> RedirectResponse:
    """Bounce plain GET requests to the landing page with a 303.

    Workspaces are only created by the POST handler on this same path.
    """
    landing_page = RedirectResponse(url="/", status_code=303)
    return landing_page
|
|
|
|
|
|
async def _seed_workspace_config() -> Any:
    """Return the starting portfolio config for a new workspace.

    Tries to anchor the entry price to the collateral spot resolved from a live
    quote for the data service's default symbol. Any failure along the way, or an
    unresolvable spot, falls back to the static default config. Failures are
    logged as a warning.
    """
    static_default = build_default_portfolio_config()
    try:
        service = get_data_service()
        symbol = service.default_symbol
        quote = await service.get_quote(symbol)
        spot = resolve_collateral_spot_from_quote(quote, fallback_symbol=symbol)
        if spot is None:
            return static_default
        return build_default_portfolio_config(entry_price=spot[0])
    except Exception as exc:
        logger.warning("Falling back to static default workspace seed: %s", exc)
        return static_default


@app.post("/workspaces/bootstrap", tags=["workspace"])
async def bootstrap_workspace(
    request: Request,
    turnstile_response: str = Form(alias="cf-turnstile-response", default=""),
) -> Response:
    """Create a new workspace after a Turnstile check and redirect into it.

    A failed check redirects (303) to ``/?captcha_error=1``. On success a workspace
    is created and its id is stored in the ``WORKSPACE_COOKIE`` cookie for 365
    days. The response is a 303 redirect to ``/<workspace_id>``.
    """
    client_host = request.client.host if request.client else None
    # NOTE(review): the Turnstile check is called synchronously (no ``await``); if it
    # performs network I/O it blocks the event loop while it runs -- confirm.
    if not turnstile_service.verify_turnstile_token(turnstile_response, client_host):
        return RedirectResponse(url="/?captcha_error=1", status_code=303)

    repo = get_workspace_repository()
    config = await _seed_workspace_config()
    workspace_id = repo.create_workspace_id(config=config)

    one_year_seconds = 60 * 60 * 24 * 365
    response = RedirectResponse(url=f"/{workspace_id}", status_code=303)
    response.set_cookie(
        key=WORKSPACE_COOKIE,
        value=workspace_id,
        httponly=True,
        samesite="lax",
        max_age=one_year_seconds,
        path="/",
    )
    return response
|
|
|
|
|
|
@app.websocket("/ws/updates")
async def websocket_updates(websocket: WebSocket) -> None:
    """Register a client for broadcast updates and hold the socket open until it disconnects.

    After a ``connected`` greeting, inbound text frames are read and discarded. The
    read loop exists only to notice when the client goes away. The socket is always
    unregistered on exit.
    """
    manager: ConnectionManager = websocket.app.state.ws_manager
    await manager.connect(websocket)
    try:
        # A client hang-up is the normal way out of this handler.
        with contextlib.suppress(WebSocketDisconnect):
            await websocket.send_json({"type": "connected", "message": "Real-time updates enabled"})
            while True:
                await websocket.receive_text()
    finally:
        manager.disconnect(websocket)
|
|
|
|
|
|
# Attach the NiceGUI frontend to the FastAPI app at the configured mount path.
ui.run_with(
    app,
    storage_secret=settings.nicegui_storage_secret,
    mount_path=settings.nicegui_mount_path,
)
|
|
|
|
|
|
if __name__ in {"__main__", "__mp_main__"}:
    import uvicorn

    # Bind address and port are overridable per deployment (APP_HOST / APP_PORT),
    # defaulting to all interfaces on port 8000. Auto-reload is on only in development.
    uvicorn.run(
        "app.main:app",
        host=os.getenv("APP_HOST", "0.0.0.0"),
        port=int(os.getenv("APP_PORT", "8000")),
        reload=settings.environment == "development",
    )
|