diff --git a/package.json b/package.json index c9e2058..4895df2 100755 --- a/package.json +++ b/package.json @@ -11,7 +11,8 @@ "db:push": "drizzle-kit push", "db:studio": "drizzle-kit studio", "db:seed": "bun db:push && bun scripts/seed-dev.ts", - "dev": "next dev --turbo", + "dev": "bun run ws-server.ts & next dev --turbo", + "dev:ws": "bun run ws-server.ts", "docker:up": "if [ \"$(uname)\" = \"Darwin\" ]; then colima start; fi && docker compose up -d", "docker:down": "docker compose down && if [ \"$(uname)\" = \"Darwin\" ]; then colima stop; fi", "format:check": "prettier --check \"**/*.{ts,tsx,js,jsx,mdx}\" --cache", @@ -134,4 +135,4 @@ "sharp", "unrs-resolver" ] -} \ No newline at end of file +} diff --git a/src/app/api/websocket/route.ts b/src/app/api/websocket/route.ts new file mode 100644 index 0000000..718e369 --- /dev/null +++ b/src/app/api/websocket/route.ts @@ -0,0 +1,135 @@ +import { NextRequest } from "next/server"; +import { headers } from "next/headers"; +import { wsManager } from "~/server/services/websocket-manager"; +import { auth } from "~/lib/auth"; + +const clientConnections = new Map< + string, + { socket: WebSocket; clientId: string } +>(); + +function generateClientId(): string { + return `ws_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`; +} + +export const runtime = "edge"; +export const dynamic = "force-dynamic"; + +export async function GET(request: NextRequest) { + const url = new URL(request.url); + const trialId = url.searchParams.get("trialId"); + const token = url.searchParams.get("token"); + + if (!trialId) { + return new Response("Missing trialId parameter", { status: 400 }); + } + + let userId: string | null = null; + + try { + const session = await auth.api.getSession({ + headers: await headers(), + }); + if (session?.user?.id) { + userId = session.user.id; + } + } catch { + if (!token) { + return new Response("Authentication required", { status: 401 }); + } + + try { + const tokenData = JSON.parse(atob(token)); + userId = tokenData.userId; + } catch { + return new Response("Invalid token", { status: 401 }); + } + } + + const pair = new WebSocketPair(); + const clientId = generateClientId(); + const serverWebSocket = Object.values(pair)[0] as WebSocket; + + clientConnections.set(clientId, { socket: serverWebSocket, clientId }); + + await wsManager.subscribe(clientId, serverWebSocket, trialId, userId); + + serverWebSocket.accept(); + + serverWebSocket.addEventListener("message", async (event) => { + try { + const message = JSON.parse(event.data as string); + + switch (message.type) { + case "heartbeat": + wsManager.sendToClient(clientId, { + type: "heartbeat_response", + data: { timestamp: Date.now() }, + }); + break; + + case "request_trial_status": { + const status = await wsManager.getTrialStatus(trialId); + wsManager.sendToClient(clientId, { + type: "trial_status", + data: { + trial: status?.trial ?? null, + current_step_index: status?.currentStepIndex ?? 0, + timestamp: Date.now(), + }, + }); + break; + } + + case "request_trial_events": { + const events = await wsManager.getTrialEvents( + trialId, + message.data?.limit ?? 100, + ); + wsManager.sendToClient(clientId, { + type: "trial_events_snapshot", + data: { events, timestamp: Date.now() }, + }); + break; + } + + case "ping": + wsManager.sendToClient(clientId, { + type: "pong", + data: { timestamp: Date.now() }, + }); + break; + + default: + console.log( + `[WS] Unknown message type from client ${clientId}:`, + message.type, + ); + } + } catch (error) { + console.error(`[WS] Error processing message from ${clientId}:`, error); + } + }); + + serverWebSocket.addEventListener("close", () => { + wsManager.unsubscribe(clientId); + clientConnections.delete(clientId); + }); + + serverWebSocket.addEventListener("error", (error) => { + console.error(`[WS] Error for client ${clientId}:`, error); + wsManager.unsubscribe(clientId); + clientConnections.delete(clientId); + }); + + return new Response(null, { + status: 101, + webSocket: serverWebSocket, + } as ResponseInit); +} + +declare global { + interface WebSocket { + accept(): void; + } +} diff --git a/src/components/trials/TrialForm.tsx b/src/components/trials/TrialForm.tsx index c65d664..3865072 100755 --- a/src/components/trials/TrialForm.tsx +++ b/src/components/trials/TrialForm.tsx @@ -163,6 +163,11 @@ export function TrialForm({ mode, trialId, studyId }: TrialFormProps) { const form = useForm({ resolver: zodResolver(trialSchema), defaultValues: { + experimentId: "" as any, + participantId: "" as any, + scheduledAt: new Date(), + wizardId: undefined, + notes: "", sessionNumber: 1, }, }); @@ -347,7 +352,7 @@ export function TrialForm({ mode, trialId, studyId }: TrialFormProps) { form.setValue("participantId", value)} disabled={participantsLoading || mode === "edit"} > diff --git a/src/components/trials/wizard/WizardInterface.tsx b/src/components/trials/wizard/WizardInterface.tsx index 1d7c27f..d5a33c3 100755 --- a/src/components/trials/wizard/WizardInterface.tsx +++ b/src/components/trials/wizard/WizardInterface.tsx @@ -32,6 +32,7 @@ import { WebcamPanel } from "./panels/WebcamPanel"; import { TrialStatusBar } from "./panels/TrialStatusBar"; import { api } from "~/trpc/react"; import { useWizardRos } from "~/hooks/useWizardRos"; +import { useTrialWebSocket, type TrialEvent } from "~/hooks/useWebSocket"; import { toast } from "sonner"; import { useTour } from "~/components/onboarding/TourProvider"; @@ -252,59 +253,65 @@ export const WizardInterface = React.memo(function WizardInterface({ [setAutonomousLifeRaw], ); - // Use polling for trial status updates (no trial WebSocket server exists) - const { data: pollingData } = api.trials.get.useQuery( - { id: trial.id }, - { - refetchInterval: trial.status === "in_progress" ? 5000 : 15000, - staleTime: 2000, - refetchOnWindowFocus: false, + // Trial WebSocket for real-time updates + const { + isConnected: wsConnected, + connectionError: wsError, + trialEvents: wsTrialEvents, + currentTrialStatus, + addLocalEvent, + } = useTrialWebSocket(trial.id, { + onStatusChange: (status) => { + // Update local trial state when WebSocket reports status changes + setTrial((prev) => ({ + ...prev, + status: status.status, + startedAt: status.startedAt + ? new Date(status.startedAt) + : prev.startedAt, + completedAt: status.completedAt + ? new Date(status.completedAt) + : prev.completedAt, + })); }, - ); - - // Poll for trial events - const { data: fetchedEvents } = api.trials.getEvents.useQuery( - { trialId: trial.id, limit: 100 }, - { - refetchInterval: 3000, - staleTime: 1000, - }, - ); - - // Update local trial state from polling only if changed - useEffect(() => { - if (pollingData && JSON.stringify(pollingData) !== JSON.stringify(trial)) { - // Only update if specific fields we care about have changed to avoid - // unnecessary re-renders that might cause UI flashing - if ( - pollingData.status !== trial.status || - pollingData.startedAt?.getTime() !== trial.startedAt?.getTime() || - pollingData.completedAt?.getTime() !== trial.completedAt?.getTime() - ) { - setTrial((prev) => { - // Double check inside setter to be safe - if ( - prev.status === pollingData.status && - prev.startedAt?.getTime() === pollingData.startedAt?.getTime() && - prev.completedAt?.getTime() === pollingData.completedAt?.getTime() - ) { - return prev; - } - return { - ...prev, - status: pollingData.status, - startedAt: pollingData.startedAt - ? new Date(pollingData.startedAt) - : prev.startedAt, - completedAt: pollingData.completedAt - ? new Date(pollingData.completedAt) - : prev.completedAt, - }; - }); + onTrialEvent: (event) => { + // Optionally show toast for new events + if (event.eventType === "trial_started") { + toast.info("Trial started"); + } else if (event.eventType === "trial_completed") { + toast.info("Trial completed"); + } else if (event.eventType === "trial_aborted") { + toast.warning("Trial aborted"); } + }, + }); + + // Update trial state from WebSocket status + useEffect(() => { + if (currentTrialStatus) { + setTrial((prev) => { + if ( + prev.status === currentTrialStatus.status && + prev.startedAt?.getTime() === + new Date(currentTrialStatus.startedAt ?? "").getTime() && + prev.completedAt?.getTime() === + new Date(currentTrialStatus.completedAt ?? "").getTime() + ) { + return prev; + } + return { + ...prev, + status: currentTrialStatus.status, + startedAt: currentTrialStatus.startedAt + ? new Date(currentTrialStatus.startedAt) + : prev.startedAt, + completedAt: currentTrialStatus.completedAt + ? new Date(currentTrialStatus.completedAt) + : prev.completedAt, + }; + }); } - // eslint-disable-next-line react-hooks/exhaustive-deps - }, [pollingData]); + }, [currentTrialStatus]); // Auto-start trial on mount if scheduled useEffect(() => { @@ -313,7 +320,7 @@ export const WizardInterface = React.memo(function WizardInterface({ } }, []); // Run once on mount - // Trial events from robot actions + // Trial events from WebSocket (and initial load) const trialEvents = useMemo< Array<{ type: string; @@ -322,8 +329,8 @@ export const WizardInterface = React.memo(function WizardInterface({ message?: string; }> >(() => { - return (fetchedEvents ?? []) - .map((event) => { + return (wsTrialEvents ?? []) + .map((event: TrialEvent) => { let message: string | undefined; const eventData = event.data as any; @@ -364,7 +371,7 @@ export const WizardInterface = React.memo(function WizardInterface({ }; }) .sort((a, b) => b.timestamp.getTime() - a.timestamp.getTime()); // Newest first - }, [fetchedEvents]); + }, [wsTrialEvents]); // Transform experiment steps to component format const steps: StepData[] = useMemo( diff --git a/src/hooks/useWebSocket.ts b/src/hooks/useWebSocket.ts index f6c6ec6..5966987 100755 --- a/src/hooks/useWebSocket.ts +++ b/src/hooks/useWebSocket.ts @@ -1,7 +1,5 @@ "use client"; -/* eslint-disable react-hooks/exhaustive-deps */ - import { useSession } from "~/lib/auth-client"; import { useCallback, useEffect, useRef, useState } from "react"; @@ -56,10 +54,42 @@ interface TrialActionExecutedMessage { interface InterventionLoggedMessage { type: "intervention_logged"; data: { + intervention?: unknown; timestamp: number; } & Record; } +interface TrialEventMessage { + type: "trial_event"; + data: { + event: unknown; + timestamp: number; + }; +} + +interface TrialEventsSnapshotMessage { + type: "trial_events_snapshot"; + data: { + events: unknown[]; + timestamp: number; + }; +} + +interface AnnotationAddedMessage { + type: "annotation_added"; + data: { + annotation: unknown; + timestamp: number; + }; +} + +interface PongMessage { + type: "pong"; + data: { + timestamp: number; + }; +} + interface StepChangedMessage { type: "step_changed"; data: { @@ -83,6 +113,10 @@ type KnownInboundMessage = | TrialStatusMessage | TrialActionExecutedMessage | InterventionLoggedMessage + | TrialEventMessage + | TrialEventsSnapshotMessage + | AnnotationAddedMessage + | PongMessage | StepChangedMessage | ErrorMessage; @@ -98,18 +132,247 @@ export interface OutgoingMessage { data: Record; } -export interface UseWebSocketOptions { +interface Subscription { trialId: string; onMessage?: (message: WebSocketMessage) => void; onConnect?: () => void; onDisconnect?: () => void; onError?: (error: Event) => void; - reconnectAttempts?: number; - reconnectInterval?: number; - heartbeatInterval?: number; } -export interface UseWebSocketReturn { +interface GlobalWSState { + isConnected: boolean; + isConnecting: boolean; + connectionError: string | null; + lastMessage: WebSocketMessage | null; +} + +type StateListener = (state: GlobalWSState) => void; + +class GlobalWebSocketManager { + private ws: WebSocket | null = null; + private subscriptions: Map = new Map(); + private stateListeners: Set = new Set(); + private sessionRef: { user: { id: string } } | null = null; + private heartbeatInterval: ReturnType | null = null; + private reconnectTimeout: ReturnType | null = null; + private attemptCount = 0; + private maxAttempts = 5; + + private state: GlobalWSState = { + isConnected: false, + isConnecting: false, + connectionError: null, + lastMessage: null, + }; + + private setState(partial: Partial) { + this.state = { ...this.state, ...partial }; + this.notifyListeners(); + } + + private notifyListeners() { + this.stateListeners.forEach((listener) => listener(this.state)); + } + + subscribe( + session: { user: { id: string } } | null, + subscription: Subscription, + ) { + this.sessionRef = session; + this.subscriptions.set(subscription.trialId, subscription); + + if (this.subscriptions.size === 1 && !this.ws) { + this.connect(); + } + + return () => { + this.subscriptions.delete(subscription.trialId); + // Don't auto-disconnect - keep global connection alive + }; + } + + sendMessage(message: OutgoingMessage) { + if (this.ws?.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify(message)); + } + } + + connect() { + if ( + this.ws?.readyState === WebSocket.CONNECTING || + this.ws?.readyState === WebSocket.OPEN + ) { + return; + } + + if (!this.sessionRef?.user) { + this.setState({ connectionError: "No session", isConnecting: false }); + return; + } + + this.setState({ isConnecting: true, connectionError: null }); + + const token = btoa(JSON.stringify({ userId: this.sessionRef.user.id })); + const wsPort = process.env.NEXT_PUBLIC_WS_PORT || "3001"; + + // Collect all trial IDs from subscriptions + const trialIds = Array.from(this.subscriptions.keys()); + const trialIdParam = trialIds.length > 0 ? `&trialId=${trialIds[0]}` : ""; + const url = `ws://${typeof window !== "undefined" ? window.location.hostname : "localhost"}:${wsPort}/api/websocket?token=${token}${trialIdParam}`; + + try { + this.ws = new WebSocket(url); + + this.ws.onopen = () => { + console.log("[GlobalWS] Connected"); + this.setState({ isConnected: true, isConnecting: false }); + this.attemptCount = 0; + this.startHeartbeat(); + + // Subscribe to all subscribed trials + this.subscriptions.forEach((sub) => { + this.ws?.send( + JSON.stringify({ + type: "subscribe", + data: { trialId: sub.trialId }, + }), + ); + }); + + this.subscriptions.forEach((sub) => sub.onConnect?.()); + }; + + this.ws.onmessage = (event) => { + try { + const message = JSON.parse(event.data) as WebSocketMessage; + this.setState({ lastMessage: message }); + + if (message.type === "connection_established") { + const data = (message as ConnectionEstablishedMessage).data; + const sub = this.subscriptions.get(data.trialId); + if (sub) { + sub.onMessage?.(message); + } + } else if ( + message.type === "trial_event" || + message.type === "trial_status" + ) { + const data = (message as TrialEventMessage).data; + const event = data.event as { trialId?: string }; + if (event?.trialId) { + const sub = this.subscriptions.get(event.trialId); + sub?.onMessage?.(message); + } + } else { + // Broadcast to all subscriptions + this.subscriptions.forEach((sub) => sub.onMessage?.(message)); + } + } catch (error) { + console.error("[GlobalWS] Failed to parse message:", error); + } + }; + + this.ws.onclose = (event) => { + console.log("[GlobalWS] Disconnected:", event.code); + this.setState({ isConnected: false, isConnecting: false }); + this.stopHeartbeat(); + this.subscriptions.forEach((sub) => sub.onDisconnect?.()); + + // Auto-reconnect if not intentionally closed + if (event.code !== 1000 && this.subscriptions.size > 0) { + this.scheduleReconnect(); + } + }; + + this.ws.onerror = (error) => { + console.error("[GlobalWS] Error:", error); + this.setState({ + connectionError: "Connection error", + isConnecting: false, + }); + this.subscriptions.forEach((sub) => + sub.onError?.(new Event("ws_error")), + ); + }; + } catch (error) { + console.error("[GlobalWS] Failed to create:", error); + this.setState({ + connectionError: "Failed to create connection", + isConnecting: false, + }); + } + } + + disconnect() { + if (this.reconnectTimeout) { + clearTimeout(this.reconnectTimeout); + this.reconnectTimeout = null; + } + this.stopHeartbeat(); + if (this.ws) { + this.ws.close(1000, "Manual disconnect"); + this.ws = null; + } + this.setState({ isConnected: false, isConnecting: false }); + } + + private startHeartbeat() { + this.heartbeatInterval = setInterval(() => { + if (this.ws?.readyState === WebSocket.OPEN) { + this.ws.send(JSON.stringify({ type: "heartbeat", data: {} })); + } + }, 30000); + } + + private stopHeartbeat() { + if (this.heartbeatInterval) { + clearInterval(this.heartbeatInterval); + this.heartbeatInterval = null; + } + } + + private scheduleReconnect() { + if (this.attemptCount >= this.maxAttempts) { + this.setState({ connectionError: "Max reconnection attempts reached" }); + return; + } + + const delay = Math.min(30000, 1000 * Math.pow(1.5, this.attemptCount)); + this.attemptCount++; + + console.log( + `[GlobalWS] Reconnecting in ${delay}ms (attempt ${this.attemptCount})`, + ); + + this.reconnectTimeout = setTimeout(() => { + if (this.subscriptions.size > 0) { + this.connect(); + } + }, delay); + } + + getState(): GlobalWSState { + return this.state; + } + + addListener(listener: StateListener) { + this.stateListeners.add(listener); + return () => this.stateListeners.delete(listener); + } +} + +const globalWS = new GlobalWebSocketManager(); + +export interface UseGlobalWebSocketOptions { + trialId: string; + onMessage?: (message: WebSocketMessage) => void; + onConnect?: () => void; + onDisconnect?: () => void; + onError?: (error: Event) => void; +} + +export interface UseGlobalWebSocketReturn { isConnected: boolean; isConnecting: boolean; connectionError: string | null; @@ -119,333 +382,66 @@ export interface UseWebSocketReturn { lastMessage: WebSocketMessage | null; } -export function useWebSocket({ +export function useGlobalWebSocket({ trialId, onMessage, onConnect, onDisconnect, onError, - reconnectAttempts = 5, - reconnectInterval = 3000, - heartbeatInterval = 30000, -}: UseWebSocketOptions): UseWebSocketReturn { +}: UseGlobalWebSocketOptions): UseGlobalWebSocketReturn { const { data: session } = useSession(); - const [isConnected, setIsConnected] = useState(false); - const [isConnecting, setIsConnecting] = useState(false); + const [isConnected, setIsConnected] = useState(false); + const [isConnecting, setIsConnecting] = useState(false); const [connectionError, setConnectionError] = useState(null); - const [hasAttemptedConnection, setHasAttemptedConnection] = - useState(false); const [lastMessage, setLastMessage] = useState(null); - const wsRef = useRef(null); - const reconnectTimeoutRef = useRef(null); - const heartbeatTimeoutRef = useRef(null); - const attemptCountRef = useRef(0); - const mountedRef = useRef(true); - const connectionStableTimeoutRef = useRef(null); + const onMessageRef = useRef(onMessage); + const onConnectRef = useRef(onConnect); + const onDisconnectRef = useRef(onDisconnect); + const onErrorRef = useRef(onError); - // Generate auth token (simplified - in production use proper JWT) - const getAuthToken = useCallback((): string | null => { - if (!session?.user) return null; - // In production, this would be a proper JWT token - return btoa( - JSON.stringify({ userId: session.user.id, timestamp: Date.now() }), - ); - }, [session]); + onMessageRef.current = onMessage; + onConnectRef.current = onConnect; + onDisconnectRef.current = onDisconnect; + onErrorRef.current = onError; - const sendMessage = useCallback((message: OutgoingMessage): void => { - if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) { - wsRef.current.send(JSON.stringify(message)); - } else { - console.warn("WebSocket not connected, message not sent:", message); - } - }, []); - - const sendHeartbeat = useCallback((): void => { - sendMessage({ type: "heartbeat", data: {} }); - }, [sendMessage]); - - const scheduleHeartbeat = useCallback((): void => { - if (heartbeatTimeoutRef.current) { - clearTimeout(heartbeatTimeoutRef.current); - } - heartbeatTimeoutRef.current = setTimeout(() => { - if (isConnected && mountedRef.current) { - sendHeartbeat(); - scheduleHeartbeat(); - } - }, heartbeatInterval); - }, [isConnected, sendHeartbeat, heartbeatInterval]); - - const handleMessage = useCallback( - (event: MessageEvent): void => { - try { - const message = JSON.parse(event.data) as WebSocketMessage; - setLastMessage(message); - - // Handle system messages - switch (message.type) { - case "connection_established": { - console.log( - "WebSocket connection established:", - (message as ConnectionEstablishedMessage).data, - ); - setIsConnected(true); - setIsConnecting(false); - setConnectionError(null); - attemptCountRef.current = 0; - scheduleHeartbeat(); - onConnect?.(); - break; - } - - case "heartbeat_response": - // Heartbeat acknowledged, connection is alive - break; - - case "error": { - console.error("WebSocket server error:", message); - const msg = - (message as ErrorMessage).data?.message ?? "Server error"; - setConnectionError(msg); - onError?.(new Event("server_error")); - break; - } - - default: - // Pass to user-defined message handler - onMessage?.(message); - break; - } - } catch (error) { - console.error("Error parsing WebSocket message:", error); - setConnectionError("Failed to parse message"); - } - }, - [onMessage, onConnect, onError, scheduleHeartbeat], - ); - - const handleClose = useCallback( - (event: CloseEvent): void => { - console.log("WebSocket connection closed:", event.code, event.reason); - setIsConnected(false); - setIsConnecting(false); - - if (heartbeatTimeoutRef.current) { - clearTimeout(heartbeatTimeoutRef.current); - } - - onDisconnect?.(); - - // Attempt reconnection if not manually closed and component is still mounted - // In development, don't aggressively reconnect to prevent UI flashing - if ( - event.code !== 1000 && - mountedRef.current && - attemptCountRef.current < reconnectAttempts && - process.env.NODE_ENV !== "development" - ) { - attemptCountRef.current++; - const delay = - reconnectInterval * Math.pow(1.5, attemptCountRef.current - 1); // Exponential backoff - - console.log( - `Attempting reconnection ${attemptCountRef.current}/${reconnectAttempts} in ${delay}ms`, - ); - setConnectionError( - `Connection lost. Reconnecting... (${attemptCountRef.current}/${reconnectAttempts})`, - ); - - reconnectTimeoutRef.current = setTimeout(() => { - if (mountedRef.current) { - attemptCountRef.current = 0; - setIsConnecting(true); - setConnectionError(null); - } - }, delay); - } else if (attemptCountRef.current >= reconnectAttempts) { - setConnectionError("Failed to reconnect after maximum attempts"); - } else if ( - process.env.NODE_ENV === "development" && - event.code !== 1000 - ) { - // In development, set a stable error message without reconnection attempts - setConnectionError("WebSocket unavailable - using polling mode"); - } - }, - [onDisconnect, reconnectAttempts, reconnectInterval], - ); - - const handleError = useCallback( - (event: Event): void => { - // In development, WebSocket failures are expected with Edge Runtime - if (process.env.NODE_ENV === "development") { - // Only set error state after the first failed attempt to prevent flashing - if (!hasAttemptedConnection) { - setHasAttemptedConnection(true); - // Debounce the error state to prevent UI flashing - if (connectionStableTimeoutRef.current) { - clearTimeout(connectionStableTimeoutRef.current); - } - connectionStableTimeoutRef.current = setTimeout(() => { - setConnectionError("WebSocket unavailable - using polling mode"); - setIsConnecting(false); - }, 1000); - } - } else { - console.error("WebSocket error:", event); - setConnectionError("Connection error"); + useEffect(() => { + const unsubscribe = globalWS.subscribe(session, { + trialId, + onMessage: (msg) => { + setLastMessage(msg); + onMessageRef.current?.(msg); + }, + onConnect: () => { + setIsConnected(true); setIsConnecting(false); - } - onError?.(event); - }, - [onError, hasAttemptedConnection], - ); + setConnectionError(null); + onConnectRef.current?.(); + }, + onDisconnect: () => { + setIsConnected(false); + onDisconnectRef.current?.(); + }, + onError: (err) => { + setConnectionError("Connection error"); + onErrorRef.current?.(err); + }, + }); - const connectInternal = useCallback((): void => { - if (!session?.user || !trialId) { - if (!hasAttemptedConnection) { - setConnectionError("Missing authentication or trial ID"); - setHasAttemptedConnection(true); - } - return; - } + return unsubscribe; + }, [trialId, session]); - if ( - wsRef.current && - (wsRef.current.readyState === WebSocket.CONNECTING || - wsRef.current.readyState === WebSocket.OPEN) - ) { - return; // Already connecting or connected - } - - const token = getAuthToken(); - if (!token) { - if (!hasAttemptedConnection) { - setConnectionError("Failed to generate auth token"); - setHasAttemptedConnection(true); - } - return; - } - - // Only show connecting state for the first attempt or if we've been stable - if (!hasAttemptedConnection || isConnected) { - setIsConnecting(true); - } - - // Clear any pending error updates - if (connectionStableTimeoutRef.current) { - clearTimeout(connectionStableTimeoutRef.current); - } - - setConnectionError(null); - - try { - // Use appropriate WebSocket URL based on environment - const protocol = window.location.protocol === "https:" ? "wss:" : "ws:"; - const wsUrl = `${protocol}//${window.location.host}/api/websocket?trialId=${trialId}&token=${token}`; - - wsRef.current = new WebSocket(wsUrl); - wsRef.current.onmessage = handleMessage; - wsRef.current.onclose = handleClose; - wsRef.current.onerror = handleError; - - wsRef.current.onopen = () => { - console.log("WebSocket connection opened"); - // Connection establishment is handled in handleMessage - }; - } catch (error) { - console.error("Failed to create WebSocket connection:", error); - if (!hasAttemptedConnection) { - setConnectionError("Failed to create connection"); - setHasAttemptedConnection(true); - } - setIsConnecting(false); - } - }, [ - session, - trialId, - getAuthToken, - handleMessage, - handleClose, - handleError, - hasAttemptedConnection, - isConnected, - ]); - - const disconnect = useCallback((): void => { - mountedRef.current = false; - - if (reconnectTimeoutRef.current) { - clearTimeout(reconnectTimeoutRef.current); - } - - if (heartbeatTimeoutRef.current) { - clearTimeout(heartbeatTimeoutRef.current); - } - - if (connectionStableTimeoutRef.current) { - clearTimeout(connectionStableTimeoutRef.current); - } - - if (wsRef.current) { - wsRef.current.close(1000, "Manual disconnect"); - wsRef.current = null; - } - - setIsConnected(false); - setIsConnecting(false); - setConnectionError(null); - setHasAttemptedConnection(false); - attemptCountRef.current = 0; + const sendMessage = useCallback((message: OutgoingMessage) => { + globalWS.sendMessage(message); }, []); - const reconnect = useCallback((): void => { - disconnect(); - mountedRef.current = true; - attemptCountRef.current = 0; - setHasAttemptedConnection(false); - setTimeout(() => { - if (mountedRef.current) { - void connectInternal(); - } - }, 100); // Small delay to ensure cleanup - }, [disconnect, connectInternal]); + const disconnect = useCallback(() => { + globalWS.disconnect(); + }, []); - // Effect to establish initial connection - useEffect(() => { - if (session?.user?.id && trialId) { - // In development, only attempt connection once to prevent flashing - if (process.env.NODE_ENV === "development" && hasAttemptedConnection) { - return; - } - - // Trigger reconnection if timeout was set - if (reconnectTimeoutRef.current) { - clearTimeout(reconnectTimeoutRef.current); - reconnectTimeoutRef.current = null; - void connectInternal(); - } else { - void connectInternal(); - } - } - - return () => { - mountedRef.current = false; - disconnect(); - }; - }, [session?.user?.id, trialId, hasAttemptedConnection]); - - // Cleanup on unmount - useEffect(() => { - return () => { - mountedRef.current = false; - if (connectionStableTimeoutRef.current) { - clearTimeout(connectionStableTimeoutRef.current); - } - disconnect(); - }; - }, [disconnect]); + const reconnect = useCallback(() => { + globalWS.connect(); + }, []); return { isConnected, @@ -458,115 +454,180 @@ export function useWebSocket({ }; } -// Hook for trial-specific WebSocket events -export function useTrialWebSocket(trialId: string) { - const [trialEvents, setTrialEvents] = useState([]); - const [currentTrialStatus, setCurrentTrialStatus] = - useState(null); - const [wizardActions, setWizardActions] = useState([]); +// Legacy alias +export const useWebSocket = useGlobalWebSocket; - const handleMessage = useCallback((message: WebSocketMessage): void => { - // Add to events log - setTrialEvents((prev) => [...prev, message].slice(-100)); // Keep last 100 events +// Trial-specific hook +export interface TrialEvent { + id: string; + trialId: string; + eventType: string; + data: Record | null; + timestamp: Date; + createdBy?: string | null; +} - switch (message.type) { - case "trial_status": { - const data = (message as TrialStatusMessage).data; - setCurrentTrialStatus(data.trial); - break; +export interface TrialWebSocketState { + trialEvents: TrialEvent[]; + currentTrialStatus: TrialSnapshot | null; + wizardActions: WebSocketMessage[]; +} + +export function useTrialWebSocket( + trialId: string, + options?: { + onTrialEvent?: (event: TrialEvent) => void; + onStatusChange?: (status: TrialSnapshot) => void; + initialEvents?: TrialEvent[]; + initialStatus?: TrialSnapshot | null; + }, +) { + const [state, setState] = useState({ + trialEvents: options?.initialEvents ?? [], + currentTrialStatus: options?.initialStatus ?? null, + wizardActions: [], + }); + + const handleMessage = useCallback( + (message: WebSocketMessage): void => { + switch (message.type) { + case "trial_status": { + const data = (message as TrialStatusMessage).data; + const status = data.trial as TrialSnapshot; + setState((prev) => ({ + ...prev, + currentTrialStatus: status, + })); + options?.onStatusChange?.(status); + break; + } + + case "trial_events_snapshot": { + const data = (message as TrialEventsSnapshotMessage).data; + const events = ( + data.events as Array<{ + id: string; + trialId: string; + eventType: string; + data: Record | null; + timestamp: Date | string; + createdBy?: string | null; + }> + ).map((e) => ({ + ...e, + timestamp: + typeof e.timestamp === "string" + ? new Date(e.timestamp) + : e.timestamp, + })); + setState((prev) => ({ + ...prev, + trialEvents: events, + })); + break; + } + + case "trial_event": { + const data = (message as TrialEventMessage).data; + const event = data.event as { + id: string; + trialId: string; + eventType: string; + data: Record | null; + timestamp: Date | string; + createdBy?: string | null; + }; + const newEvent: TrialEvent = { + ...event, + timestamp: + typeof event.timestamp === "string" + ? new Date(event.timestamp) + : event.timestamp, + }; + setState((prev) => ({ + ...prev, + trialEvents: [...prev.trialEvents, newEvent].slice(-500), + })); + options?.onTrialEvent?.(newEvent); + break; + } + + case "trial_action_executed": + case "intervention_logged": + case "annotation_added": + case "step_changed": { + setState((prev) => ({ + ...prev, + wizardActions: [...prev.wizardActions, message].slice(-100), + })); + break; + } + + case "pong": + break; + + default: + if (process.env.NODE_ENV === "development") { + console.log(`[WS] Unknown message type: ${message.type}`); + } } + }, + [options], + ); - case "trial_action_executed": - case "intervention_logged": - case "step_changed": - setWizardActions((prev) => [...prev, message].slice(-50)); // Keep last 50 actions - break; - - case "step_changed": - // Handle step transitions (optional logging) - console.log("Step changed:", (message as StepChangedMessage).data); - break; - - default: - // Handle other trial-specific messages - break; - } - }, []); - - const webSocket = useWebSocket({ + const webSocket = useGlobalWebSocket({ trialId, onMessage: handleMessage, onConnect: () => { if (process.env.NODE_ENV === "development") { - console.log(`Connected to trial ${trialId} WebSocket`); + console.log(`[WS] Connected to trial ${trialId}`); } }, onDisconnect: () => { if (process.env.NODE_ENV === "development") { - console.log(`Disconnected from trial ${trialId} WebSocket`); + console.log(`[WS] Disconnected from trial ${trialId}`); } }, onError: () => { - // Suppress noisy WebSocket errors in development if (process.env.NODE_ENV !== "development") { - console.error(`Trial ${trialId} WebSocket connection failed`); + console.error(`[WS] Trial ${trialId} WebSocket connection failed`); } }, }); - // Request trial status after connection is established + // Request initial data after connection is established useEffect(() => { if (webSocket.isConnected) { webSocket.sendMessage({ type: "request_trial_status", data: {} }); + webSocket.sendMessage({ + type: "request_trial_events", + data: { limit: 500 }, + }); } - }, [webSocket.isConnected, webSocket]); + }, [webSocket.isConnected]); - // Trial-specific actions - const executeTrialAction = useCallback( - (actionType: string, actionData: Record): void => { - webSocket.sendMessage({ - type: "trial_action", - data: { - actionType, - ...actionData, - }, - }); - }, - [webSocket], - ); + // Helper to add an event locally (for optimistic updates) + const addLocalEvent = useCallback((event: TrialEvent) => { + setState((prev) => ({ + ...prev, + trialEvents: [...prev.trialEvents, event].slice(-500), + })); + }, []); - const logWizardIntervention = useCallback( - (interventionData: Record): void => { - webSocket.sendMessage({ - type: "wizard_intervention", - data: interventionData, - }); - }, - [webSocket], - ); - - const transitionStep = useCallback( - (stepData: { - from_step?: number; - to_step: number; - step_name?: string; - [k: string]: unknown; - }): void => { - webSocket.sendMessage({ - type: "step_transition", - data: stepData, - }); - }, - [webSocket], - ); + // Helper to update trial status locally + const updateLocalStatus = useCallback((status: TrialSnapshot) => { + setState((prev) => ({ + ...prev, + currentTrialStatus: status, + })); + }, []); return { ...webSocket, - trialEvents, - currentTrialStatus, - wizardActions, - executeTrialAction, - logWizardIntervention, - transitionStep, + trialEvents: state.trialEvents, + currentTrialStatus: state.currentTrialStatus, + wizardActions: state.wizardActions, + addLocalEvent, + updateLocalStatus, }; } diff --git a/src/server/api/routers/trials.ts b/src/server/api/routers/trials.ts index c77213a..12f2a5a 100755 --- a/src/server/api/routers/trials.ts +++ b/src/server/api/routers/trials.ts @@ -35,6 +35,7 @@ import { GetObjectCommand } from "@aws-sdk/client-s3"; import { getSignedUrl } from "@aws-sdk/s3-request-presigner"; import { env } from "~/env"; import { uploadFile } from "~/lib/storage/minio"; +import { wsManager } from "~/server/services/websocket-manager"; // Helper function to check if user has access to trial async function checkTrialAccess( @@ -591,6 +592,16 @@ export const trialsRouter = createTRPCRouter({ data: { userId }, }); + // Broadcast trial status update + await wsManager.broadcast(input.id, { + type: "trial_status", + data: { + trial: trial[0], + current_step_index: 0, + timestamp: Date.now(), + }, + }); + return trial[0]; }), @@ -643,6 +654,16 @@ export const trialsRouter = createTRPCRouter({ data: { userId, notes: input.notes }, }); + // Broadcast trial status update + await wsManager.broadcast(input.id, { + type: "trial_status", + data: { + trial, + current_step_index: 0, + timestamp: Date.now(), + }, + }); + return trial; }), @@ -696,6 +717,16 @@ export const trialsRouter = createTRPCRouter({ data: { userId, reason: input.reason }, }); + // Broadcast trial status update + await wsManager.broadcast(input.id, { + type: "trial_status", + data: { + trial: trial[0], + current_step_index: 0, + timestamp: Date.now(), + }, + }); + return trial[0]; }), @@ -846,6 +877,15 @@ export const trialsRouter = createTRPCRouter({ }) .returning(); + // Broadcast new event to all subscribers + await wsManager.broadcast(input.trialId, { + type: "trial_event", + data: { + event, + timestamp: Date.now(), + }, + }); + return event; }), @@ -881,6 +921,15 @@ export const trialsRouter = createTRPCRouter({ }) .returning(); + // Broadcast intervention to all subscribers + await wsManager.broadcast(input.trialId, { + type: "intervention_logged", + data: { + intervention, + timestamp: Date.now(), + }, + }); + return intervention; }), @@ -936,6 +985,15 @@ export const trialsRouter = createTRPCRouter({ }); } + // Broadcast annotation to all subscribers + await wsManager.broadcast(input.trialId, { + type: "annotation_added", + data: { + annotation, + timestamp: Date.now(), + }, + }); + return annotation; }), @@ -1302,20 +1360,33 @@ export const trialsRouter = createTRPCRouter({ } // Log the manual robot action execution - await db.insert(trialEvents).values({ - trialId: input.trialId, - eventType: "manual_robot_action", - actionId: null, // Ad-hoc action, not linked to a protocol action definition + const [event] = await db + .insert(trialEvents) + .values({ + trialId: input.trialId, + eventType: "manual_robot_action", + actionId: null, + data: { + userId, + pluginName: input.pluginName, + actionId: input.actionId, + parameters: input.parameters, + result: result.data, + duration: result.duration, + }, + timestamp: new Date(), + createdBy: userId, + }) + .returning(); + + // Broadcast robot action to all subscribers + await wsManager.broadcast(input.trialId, { + type: "trial_action_executed", data: { - userId, - pluginName: input.pluginName, - actionId: input.actionId, - parameters: input.parameters, - result: result.data, - duration: result.duration, + action_type: `${input.pluginName}.${input.actionId}`, + event, + timestamp: Date.now(), }, - timestamp: new Date(), - createdBy: userId, }); return { @@ -1347,21 +1418,34 @@ export const trialsRouter = createTRPCRouter({ "wizard", ]); - await db.insert(trialEvents).values({ - trialId: input.trialId, - eventType: "manual_robot_action", + const [event] = await db + .insert(trialEvents) + .values({ + trialId: input.trialId, + eventType: "manual_robot_action", + data: { + userId, + pluginName: input.pluginName, + actionId: input.actionId, + parameters: input.parameters, + result: input.result, + duration: input.duration, + error: input.error, + executionMode: "websocket_client", + }, + timestamp: new Date(), + createdBy: userId, + }) + .returning(); + + // Broadcast robot action to all subscribers + await wsManager.broadcast(input.trialId, { + type: "trial_action_executed", data: { - userId, - pluginName: input.pluginName, - actionId: input.actionId, - parameters: input.parameters, - result: input.result, - duration: input.duration, - error: input.error, - executionMode: "websocket_client", + action_type: `${input.pluginName}.${input.actionId}`, + event, + timestamp: Date.now(), }, - timestamp: new Date(), - createdBy: userId, }); return { success: true }; diff --git a/src/server/db/schema.ts b/src/server/db/schema.ts index 7b6a65c..b4dd5cc 100755 --- a/src/server/db/schema.ts +++ b/src/server/db/schema.ts @@ -485,6 +485,25 @@ export const trials = createTable("trial", { metadata: jsonb("metadata").default({}), }); +export const wsConnections = createTable("ws_connection", { + id: uuid("id").notNull().primaryKey().defaultRandom(), + trialId: uuid("trial_id") + .notNull() + .references(() => trials.id, { onDelete: "cascade" }), + clientId: text("client_id").notNull().unique(), + userId: text("user_id"), + connectedAt: timestamp("connected_at", { withTimezone: true }) + .default(sql`CURRENT_TIMESTAMP`) + .notNull(), +}); + +export const wsConnectionsRelations = relations(wsConnections, ({ one }) => ({ + trial: one(trials, { + fields: [wsConnections.trialId], + references: [trials.id], + }), +})); + export const steps = createTable( "step", { diff --git a/src/server/services/websocket-manager.ts b/src/server/services/websocket-manager.ts new file mode 100644 index 0000000..e7cc556 --- /dev/null +++ b/src/server/services/websocket-manager.ts @@ -0,0 +1,272 @@ +import { db } from "~/server/db"; +import { + trials, + trialEvents, + wsConnections, + experiments, +} from "~/server/db/schema"; +import { eq, sql } from "drizzle-orm"; + +interface ClientConnection { + socket: WebSocket; + trialId: string; + userId: string | null; + connectedAt: number; +} + +type OutgoingMessage = { + type: string; + data: Record; +}; + +class WebSocketManager { + private clients: Map = new Map(); + private heartbeatIntervals: Map> = + new Map(); + + private getTrialRoomClients(trialId: string): ClientConnection[] { + const clients: ClientConnection[] = []; + for (const [, client] of this.clients) { + if (client.trialId === trialId) { + clients.push(client); + } + } + return clients; + } + + addClient(clientId: string, connection: ClientConnection): void { + this.clients.set(clientId, connection); + console.log( + `[WS] Client ${clientId} added for trial ${connection.trialId}. Total: ${this.clients.size}`, + ); + } + + removeClient(clientId: string): void { + const client = this.clients.get(clientId); + if (client) { + console.log( + `[WS] Client ${clientId} removed from trial ${client.trialId}`, + ); + } + + const heartbeatInterval = this.heartbeatIntervals.get(clientId); + if (heartbeatInterval) { + clearInterval(heartbeatInterval); + this.heartbeatIntervals.delete(clientId); + } + + this.clients.delete(clientId); + } + + async subscribe( + clientId: string, + socket: WebSocket, + trialId: string, + userId: string | null, + ): Promise { + const client: ClientConnection = { + socket, + trialId, + userId, + connectedAt: Date.now(), + }; + + this.clients.set(clientId, client); + + const heartbeatInterval = setInterval(() => { + this.sendToClient(clientId, { type: "heartbeat", data: {} }); + }, 30000); + + this.heartbeatIntervals.set(clientId, heartbeatInterval); + + console.log( + `[WS] Client ${clientId} subscribed to trial ${trialId}. Total clients: ${this.clients.size}`, + ); + } + + unsubscribe(clientId: string): void { + const client = this.clients.get(clientId); + if (client) { + console.log( + `[WS] Client ${clientId} unsubscribed from trial ${client.trialId}`, + ); + } + + const heartbeatInterval = this.heartbeatIntervals.get(clientId); + if (heartbeatInterval) { + clearInterval(heartbeatInterval); + this.heartbeatIntervals.delete(clientId); + } + + this.clients.delete(clientId); + } + + sendToClient(clientId: string, message: OutgoingMessage): void { + const client = this.clients.get(clientId); + if (client?.socket.readyState === 1) { + try { + client.socket.send(JSON.stringify(message)); + } catch (error) { + console.error(`[WS] Error sending to client ${clientId}:`, error); + this.unsubscribe(clientId); + } + } + } + + async broadcast(trialId: string, message: OutgoingMessage): Promise { + const clients = this.getTrialRoomClients(trialId); + + if (clients.length === 0) { + return; + } + + const messageStr = JSON.stringify(message); + const disconnectedClients: string[] = []; + + for (const [clientId, client] of this.clients) { + if (client.trialId === trialId && client.socket.readyState === 1) { + try { + client.socket.send(messageStr); + } catch (error) { + console.error( + `[WS] Error broadcasting to client ${clientId}:`, + error, + ); + disconnectedClients.push(clientId); + } + } + } + + for (const clientId of disconnectedClients) { + this.unsubscribe(clientId); + } + + console.log( + `[WS] Broadcast to ${clients.length} clients for trial ${trialId}: ${message.type}`, + ); + } + + async broadcastToAll(message: OutgoingMessage): Promise { + const messageStr = JSON.stringify(message); + const disconnectedClients: string[] = []; + + for (const [clientId, client] of this.clients) { + if (client.socket.readyState === 1) { + try { + client.socket.send(messageStr); + } catch (error) { + console.error( + `[WS] Error broadcasting to client ${clientId}:`, + error, + ); + disconnectedClients.push(clientId); + } + } + } + + for (const clientId of disconnectedClients) { + this.unsubscribe(clientId); + } + } + + async getTrialStatus(trialId: string): Promise<{ + trial: { + id: string; + status: string; + startedAt: Date | null; + completedAt: Date | null; + }; + currentStepIndex: number; + } | null> { + const [trial] = await db + .select({ + id: trials.id, + status: trials.status, + startedAt: trials.startedAt, + completedAt: trials.completedAt, + }) + .from(trials) + .where(eq(trials.id, trialId)) + .limit(1); + + if (!trial) { + return null; + } + + return { + trial: { + id: trial.id, + status: trial.status, + startedAt: trial.startedAt, + completedAt: trial.completedAt, + }, + currentStepIndex: 0, + }; + } + + async getTrialEvents( + trialId: string, + limit: number = 100, + ): Promise { + const events = await db + .select() + .from(trialEvents) + .where(eq(trialEvents.trialId, trialId)) + .orderBy(trialEvents.timestamp) + .limit(limit); + + return events; + } + + getTrialStatusSync(trialId: string): { + trial: { + id: string; + status: string; + startedAt: Date | null; + completedAt: Date | null; + }; + currentStepIndex: number; + } | null { + return null; + } + + getTrialEventsSync(trialId: string, limit: number = 100): unknown[] { + return []; + } + + getConnectionCount(trialId?: string): number { + if (trialId) { + return this.getTrialRoomClients(trialId).length; + } + return this.clients.size; + } + + getConnectedTrialIds(): string[] { + const trialIds = new Set(); + for (const [, client] of this.clients) { + trialIds.add(client.trialId); + } + return Array.from(trialIds); + } + + async getTrialsWithActiveConnections(studyIds?: string[]): Promise { + const conditions = + studyIds && studyIds.length > 0 + ? sql`${wsConnections.trialId} IN ( + SELECT ${trials.id} FROM ${trials} + WHERE ${trials.experimentId} IN ( + SELECT ${experiments.id} FROM ${experiments} + WHERE ${experiments.studyId} IN (${sql.raw(studyIds.map((id) => `'${id}'`).join(","))}) + ) + )` + : undefined; + + const connections = await db + .selectDistinct({ trialId: wsConnections.trialId }) + .from(wsConnections); + + return connections.map((c) => c.trialId); + } +} + +export const wsManager = new WebSocketManager(); diff --git a/ws-server.ts b/ws-server.ts new file mode 100644 index 0000000..96f6923 --- /dev/null +++ b/ws-server.ts @@ -0,0 +1,192 @@ +import { serve, type ServerWebSocket } from "bun"; +import { wsManager } from "./src/server/services/websocket-manager"; +import { db } from "./src/server/db"; +import { wsConnections } from "./src/server/db/schema"; +import { eq } from "drizzle-orm"; + +const port = parseInt(process.env.WS_PORT || "3001", 10); + +interface WSData { + clientId: string; + trialId: string; + userId: string | null; +} + +function generateClientId(): string { + return `ws_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`; +} + +async function recordConnection( + clientId: string, + trialId: string, + userId: string | null, +): Promise { + try { + await db.insert(wsConnections).values({ + clientId, + trialId, + userId, + }); + console.log(`[DB] Recorded connection for trial ${trialId}`); + } catch (error) { + console.error(`[DB] Failed to record connection:`, error); + } +} + +async function removeConnection(clientId: string): Promise { + try { + await db.delete(wsConnections).where(eq(wsConnections.clientId, clientId)); + console.log(`[DB] Removed connection ${clientId}`); + } catch (error) { + console.error(`[DB] Failed to remove connection:`, error); + } +} + +console.log(`Starting WebSocket server on port ${port}...`); + +serve({ + port, + fetch(req, server) { + const url = new URL(req.url); + + if (url.pathname === "/api/websocket") { + if (req.headers.get("upgrade") !== "websocket") { + return new Response("WebSocket upgrade required", { status: 426 }); + } + + const trialId = url.searchParams.get("trialId"); + const token = url.searchParams.get("token"); + + if (!trialId) { + return new Response("Missing trialId parameter", { status: 400 }); + } + + let userId: string | null = null; + if (token) { + try { + const tokenData = JSON.parse(atob(token)); + userId = tokenData.userId; + } catch { + return new Response("Invalid token", { status: 401 }); + } + } + + const clientId = generateClientId(); + const wsData: WSData = { clientId, trialId, userId }; + + const upgraded = server.upgrade(req, { data: wsData }); + + if (!upgraded) { + return new Response("WebSocket upgrade failed", { status: 500 }); + } + + return; + } + + return new Response("Not found", { status: 404 }); + }, + websocket: { + async open(ws: ServerWebSocket) { + const { clientId, trialId, userId } = ws.data; + + wsManager.addClient(clientId, { + socket: ws as unknown as WebSocket, + trialId, + userId, + connectedAt: Date.now(), + }); + + await recordConnection(clientId, trialId, userId); + + console.log( + `[WS] Client ${clientId} connected to trial ${trialId}. Total: ${wsManager.getConnectionCount()}`, + ); + + ws.send( + JSON.stringify({ + type: "connection_established", + data: { + trialId, + userId, + role: "connected", + connectedAt: Date.now(), + }, + }), + ); + }, + message(ws: ServerWebSocket, message) { + const { clientId, trialId } = ws.data; + + try { + const msg = JSON.parse(message.toString()); + + switch (msg.type) { + case "heartbeat": + ws.send( + JSON.stringify({ + type: "heartbeat_response", + data: { timestamp: Date.now() }, + }), + ); + break; + + case "request_trial_status": { + const status = wsManager.getTrialStatusSync(trialId); + ws.send( + JSON.stringify({ + type: "trial_status", + data: { + trial: status?.trial ?? null, + current_step_index: status?.currentStepIndex ?? 0, + timestamp: Date.now(), + }, + }), + ); + break; + } + + case "request_trial_events": { + const events = wsManager.getTrialEventsSync( + trialId, + msg.data?.limit ?? 100, + ); + ws.send( + JSON.stringify({ + type: "trial_events_snapshot", + data: { events, timestamp: Date.now() }, + }), + ); + break; + } + + case "ping": + ws.send( + JSON.stringify({ + type: "pong", + data: { timestamp: Date.now() }, + }), + ); + break; + + default: + console.log( + `[WS] Unknown message type from ${clientId}:`, + msg.type, + ); + } + } catch (error) { + console.error(`[WS] Error processing message from ${clientId}:`, error); + } + }, + async close(ws: ServerWebSocket) { + const { clientId } = ws.data; + console.log(`[WS] Client ${clientId} disconnected`); + wsManager.removeClient(clientId); + await removeConnection(clientId); + }, + }, +}); + +console.log( + `> WebSocket server running on ws://localhost:${port}/api/websocket`, +);