# vault-dash/app/main.py

"""FastAPI application entry point with NiceGUI integration."""
from __future__ import annotations
import asyncio
import contextlib
import logging
import os
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any
from fastapi import FastAPI, Request, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from nicegui import ui # type: ignore[attr-defined]
import app.pages # noqa: F401
from app.api.routes import router as api_router
from app.models.workspace import WORKSPACE_COOKIE, get_workspace_repository
from app.services.cache import CacheService
from app.services.data_service import DataService
from app.services.runtime import set_data_service
def _parse_log_level(value: str) -> int | None:
    """Translate a ``LOG_LEVEL`` value into a numeric :mod:`logging` level.

    Level names are matched case-insensitively and surrounding whitespace is
    ignored (``"debug"``, ``" Warning "``); plain decimal integers such as
    ``"15"`` are used as-is. Returns ``None`` for anything unrecognised.
    """
    text = value.strip()
    if text.isdecimal():
        return int(text)
    # getLevelName maps a registered level *name* to its number and returns a
    # "Level <name>" string for names it does not know.
    level = logging.getLevelName(text.upper())
    return level if isinstance(level, int) else None


# logging.basicConfig raises ValueError for lower-case names such as "debug",
# so normalise LOG_LEVEL first and fall back to INFO instead of failing at import.
_requested_log_level = os.getenv("LOG_LEVEL", "INFO")
_log_level = _parse_log_level(_requested_log_level)
logging.basicConfig(level=logging.INFO if _log_level is None else _log_level)
logger = logging.getLogger(__name__)
if _log_level is None:
    logger.warning("Unrecognised LOG_LEVEL %r; defaulting to INFO", _requested_log_level)
@dataclass
class Settings:
    """Application configuration resolved from environment variables.

    Field defaults are the fallbacks used when the matching variable is unset;
    call :meth:`load` rather than the constructor to pick up the environment.
    """

    app_name: str = "Vault Dashboard"
    environment: str = "development"
    cors_origins: list[str] | None = None
    redis_url: str | None = None
    cache_ttl: int = 300
    default_symbol: str = "GLD"
    websocket_interval_seconds: int = 5
    nicegui_mount_path: str = "/"
    # NOTE(review): development fallback only; set NICEGUI_STORAGE_SECRET
    # explicitly in any shared or production deployment.
    nicegui_storage_secret: str = "vault-dash-dev-secret"

    @classmethod
    def load(cls) -> Settings:
        """Return a Settings instance populated from the process environment.

        A ``.env`` file is loaded first when the ``dotenv`` package is
        importable. ``CORS_ORIGINS`` is a comma-separated list (default
        ``"*"``) whose blank entries are dropped. ``APP_ENV`` takes precedence
        over ``ENVIRONMENT``. The integer variables raise ``ValueError`` when
        they are not valid integers.
        """
        cls._load_dotenv()
        getenv = os.getenv

        origin_parts = (part.strip() for part in getenv("CORS_ORIGINS", "*").split(","))
        allowed_origins = [part for part in origin_parts if part]

        fallback_environment = getenv("ENVIRONMENT", cls.environment)

        return cls(
            app_name=getenv("APP_NAME", cls.app_name),
            environment=getenv("APP_ENV", fallback_environment),
            cors_origins=allowed_origins,
            redis_url=getenv("REDIS_URL"),
            cache_ttl=int(getenv("CACHE_TTL", cls.cache_ttl)),
            default_symbol=getenv("DEFAULT_SYMBOL", cls.default_symbol),
            websocket_interval_seconds=int(
                getenv("WEBSOCKET_INTERVAL_SECONDS", cls.websocket_interval_seconds)
            ),
            nicegui_mount_path=getenv("NICEGUI_MOUNT_PATH", cls.nicegui_mount_path),
            nicegui_storage_secret=getenv("NICEGUI_STORAGE_SECRET", cls.nicegui_storage_secret),
        )

    @staticmethod
    def _load_dotenv() -> None:
        """Load a ``.env`` file into the environment if ``dotenv`` is installed."""
        try:
            from dotenv import load_dotenv
        except ImportError:
            return  # optional dependency; plain environment variables suffice
        load_dotenv()


# Module-wide configuration, resolved once at import time.
settings = Settings.load()
class ConnectionManager:
    """Track open WebSocket clients and fan JSON messages out to them."""

    def __init__(self) -> None:
        self._connections: set[WebSocket] = set()

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the handshake and register *websocket* for broadcasts."""
        await websocket.accept()
        self._connections.add(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Forget *websocket*; sockets that are not registered are ignored."""
        self._connections.discard(websocket)

    async def broadcast_json(self, payload: dict[str, Any]) -> None:
        """Send *payload* to every registered client.

        Clients whose send raises are treated as dead and unregistered after
        the broadcast completes.
        """
        stale: list[WebSocket] = []
        # Iterate over a snapshot: every ``await`` yields to the event loop, and a
        # concurrent connect()/disconnect() (e.g. a handler's ``finally``) would
        # otherwise mutate the set mid-iteration and raise RuntimeError.
        for websocket in list(self._connections):
            if websocket not in self._connections:
                continue  # unregistered while earlier sends were in flight
            try:
                await websocket.send_json(payload)
            except Exception:
                # Best effort: a failed send means the client is gone.
                stale.append(websocket)
        for websocket in stale:
            self.disconnect(websocket)

    @property
    def count(self) -> int:
        """Number of currently registered clients."""
        return len(self._connections)
async def publish_updates(app: FastAPI) -> None:
    """Periodically broadcast a portfolio snapshot to all WebSocket clients.

    Runs until cancelled, reading the connection manager, data service and
    settings from ``app.state``. A failure in one cycle (fetching the
    portfolio or broadcasting it) is logged and retried on the next interval,
    so a single bad fetch no longer stops live updates for the rest of the
    process lifetime.

    Raises:
        asyncio.CancelledError: re-raised after logging when the task is cancelled.
    """
    state = app.state
    try:
        while True:
            try:
                payload = {
                    "type": "portfolio_update",
                    "connections": state.ws_manager.count,
                    "portfolio": await state.data_service.get_portfolio(state.settings.default_symbol),
                }
                await state.ws_manager.broadcast_json(payload)
            except Exception:
                # CancelledError derives from BaseException, so shutdown still
                # propagates past this handler.
                logger.exception("Failed to publish portfolio update; retrying next interval")
            # Sleep outside the inner try so failures also wait before retrying.
            await asyncio.sleep(state.settings.websocket_interval_seconds)
    except asyncio.CancelledError:
        logger.info("WebSocket publisher stopped")
        raise
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create shared services on startup and release them on shutdown.

    Startup stores the settings on ``app.state``, connects the cache, builds
    the DataService (also registered via ``set_data_service``), creates the
    WebSocket ConnectionManager, and starts the ``publish_updates`` task.
    Shutdown cancels that task and closes the cache.
    """
    app.state.settings = settings
    app.state.cache = CacheService(settings.redis_url, default_ttl=settings.cache_ttl)
    await app.state.cache.connect()
    app.state.data_service = DataService(app.state.cache, default_symbol=settings.default_symbol)
    set_data_service(app.state.data_service)
    app.state.ws_manager = ConnectionManager()
    app.state.publisher_task = asyncio.create_task(publish_updates(app))
    logger.info("Application startup complete")
    try:
        yield
    finally:
        publisher = app.state.publisher_task
        publisher.cancel()
        try:
            await publisher
        except asyncio.CancelledError:
            pass  # expected: we just cancelled it
        except Exception:
            # If the publisher had already died, awaiting it re-raises that error.
            # Log it instead of letting it skip the cache cleanup below.
            logger.exception("WebSocket publisher terminated with an error")
        await app.state.cache.close()
        logger.info("Application shutdown complete")
app = FastAPI(lifespan=lifespan, title=settings.app_name)

# NOTE(review): when CORS_ORIGINS is unset, "*" or empty, this allows every
# origin while also allowing credentials. Starlette's CORSMiddleware docs ask
# for explicit origin/method/header lists when credentials are enabled —
# consider listing trusted origins for production deployments.
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins if settings.cors_origins else ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.include_router(api_router)
@app.get("/health", tags=["health"])
async def health(request: Request) -> dict[str, Any]:
    """Liveness probe reporting the configured environment and cache state.

    ``redis_enabled`` mirrors the ``enabled`` flag of the cache service stored
    on ``app.state``.
    """
    state = request.app.state
    environment = state.settings.environment
    redis_enabled = state.cache.enabled
    return {"status": "ok", "environment": environment, "redis_enabled": redis_enabled}
@app.get("/workspaces/bootstrap", tags=["workspace"])
async def bootstrap_workspace() -> RedirectResponse:
    """Allocate a new workspace id, store it in a cookie, and redirect to it.

    The redirect uses 303 so the browser follows it with a GET. The cookie is
    HTTP-only, SameSite=Lax, scoped to ``/``, and kept for one year.
    """
    new_id = get_workspace_repository().create_workspace_id()
    one_year_seconds = 60 * 60 * 24 * 365
    redirect = RedirectResponse(url=f"/{new_id}", status_code=303)
    redirect.set_cookie(
        key=WORKSPACE_COOKIE,
        value=new_id,
        max_age=one_year_seconds,
        path="/",
        httponly=True,
        samesite="lax",
    )
    return redirect
@app.websocket("/ws/updates")
async def websocket_updates(websocket: WebSocket) -> None:
    """Register a client for broadcast updates and hold the socket open.

    After a greeting message, incoming text frames are read and discarded
    until ``WebSocketDisconnect`` is raised. The client is always unregistered
    on exit.
    """
    manager: ConnectionManager = websocket.app.state.ws_manager
    await manager.connect(websocket)
    try:
        with contextlib.suppress(WebSocketDisconnect):
            await websocket.send_json({"type": "connected", "message": "Real-time updates enabled"})
            while True:
                await websocket.receive_text()
    finally:
        manager.disconnect(websocket)
# Attach NiceGUI to the FastAPI app under the configured mount path, passing
# along the configured NiceGUI storage secret.
ui.run_with(
    app,
    storage_secret=settings.nicegui_storage_secret,
    mount_path=settings.nicegui_mount_path,
)
if __name__ in {"__main__", "__mp_main__"}:
    # Direct execution (``__mp_main__`` covers a multiprocessing spawn re-import).
    # The app is passed as an import string, which uvicorn needs to support reload.
    import uvicorn

    reload_enabled = settings.environment == "development"
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=reload_enabled)