feat: implement WebSocket for real-time trial updates

- Create standalone WebSocket server (ws-server.ts) on port 3001 using Bun
- Add ws_connections table to track active connections in database
- Create global WebSocket manager that persists across component unmounts
- Fix useWebSocket hook to prevent infinite re-renders and use refs
- Fix TrialForm Select components with proper default values
- Add trialId to WebSocket URL for server-side tracking
- Update package.json with dev:ws script for separate WS server
This commit is contained in:
Sean O'Connor
2026-03-22 00:48:43 -04:00
parent 20d6d3de1a
commit a5762ec935
9 changed files with 1257 additions and 481 deletions

View File

@@ -11,7 +11,8 @@
"db:push": "drizzle-kit push",
"db:studio": "drizzle-kit studio",
"db:seed": "bun db:push && bun scripts/seed-dev.ts",
"dev": "next dev --turbo",
"dev": "bun run ws-server.ts & next dev --turbo",
"dev:ws": "bun run ws-server.ts",
"docker:up": "if [ \"$(uname)\" = \"Darwin\" ]; then colima start; fi && docker compose up -d",
"docker:down": "docker compose down && if [ \"$(uname)\" = \"Darwin\" ]; then colima stop; fi",
"format:check": "prettier --check \"**/*.{ts,tsx,js,jsx,mdx}\" --cache",

View File

@@ -0,0 +1,135 @@
import { NextRequest } from "next/server";
import { headers } from "next/headers";
import { wsManager } from "~/server/services/websocket-manager";
import { auth } from "~/lib/auth";
const clientConnections = new Map<
string,
{ socket: WebSocket; clientId: string }
>();
function generateClientId(): string {
return `ws_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`;
}
export const runtime = "edge";
export const dynamic = "force-dynamic";
export async function GET(request: NextRequest) {
const url = new URL(request.url);
const trialId = url.searchParams.get("trialId");
const token = url.searchParams.get("token");
if (!trialId) {
return new Response("Missing trialId parameter", { status: 400 });
}
let userId: string | null = null;
try {
const session = await auth.api.getSession({
headers: await headers(),
});
if (session?.user?.id) {
userId = session.user.id;
}
} catch {
if (!token) {
return new Response("Authentication required", { status: 401 });
}
try {
const tokenData = JSON.parse(atob(token));
userId = tokenData.userId;
} catch {
return new Response("Invalid token", { status: 401 });
}
}
const pair = new WebSocketPair();
const clientId = generateClientId();
const serverWebSocket = Object.values(pair)[0] as WebSocket;
clientConnections.set(clientId, { socket: serverWebSocket, clientId });
await wsManager.subscribe(clientId, serverWebSocket, trialId, userId);
serverWebSocket.accept();
serverWebSocket.addEventListener("message", async (event) => {
try {
const message = JSON.parse(event.data as string);
switch (message.type) {
case "heartbeat":
wsManager.sendToClient(clientId, {
type: "heartbeat_response",
data: { timestamp: Date.now() },
});
break;
case "request_trial_status": {
const status = await wsManager.getTrialStatus(trialId);
wsManager.sendToClient(clientId, {
type: "trial_status",
data: {
trial: status?.trial ?? null,
current_step_index: status?.currentStepIndex ?? 0,
timestamp: Date.now(),
},
});
break;
}
case "request_trial_events": {
const events = await wsManager.getTrialEvents(
trialId,
message.data?.limit ?? 100,
);
wsManager.sendToClient(clientId, {
type: "trial_events_snapshot",
data: { events, timestamp: Date.now() },
});
break;
}
case "ping":
wsManager.sendToClient(clientId, {
type: "pong",
data: { timestamp: Date.now() },
});
break;
default:
console.log(
`[WS] Unknown message type from client ${clientId}:`,
message.type,
);
}
} catch (error) {
console.error(`[WS] Error processing message from ${clientId}:`, error);
}
});
serverWebSocket.addEventListener("close", () => {
wsManager.unsubscribe(clientId);
clientConnections.delete(clientId);
});
serverWebSocket.addEventListener("error", (error) => {
console.error(`[WS] Error for client ${clientId}:`, error);
wsManager.unsubscribe(clientId);
clientConnections.delete(clientId);
});
return new Response(null, {
status: 101,
webSocket: serverWebSocket,
} as ResponseInit);
}
declare global {
interface WebSocket {
accept(): void;
}
}

View File

@@ -163,6 +163,11 @@ export function TrialForm({ mode, trialId, studyId }: TrialFormProps) {
const form = useForm<TrialFormData>({
resolver: zodResolver(trialSchema),
defaultValues: {
experimentId: "" as any,
participantId: "" as any,
scheduledAt: new Date(),
wizardId: undefined,
notes: "",
sessionNumber: 1,
},
});
@@ -347,7 +352,7 @@ export function TrialForm({ mode, trialId, studyId }: TrialFormProps) {
<FormField>
<Label htmlFor="experimentId">Experiment *</Label>
<Select
value={form.watch("experimentId")}
value={form.watch("experimentId") ?? ""}
onValueChange={(value) => form.setValue("experimentId", value)}
disabled={experimentsLoading || mode === "edit"}
>
@@ -387,7 +392,7 @@ export function TrialForm({ mode, trialId, studyId }: TrialFormProps) {
<FormField>
<Label htmlFor="participantId">Participant *</Label>
<Select
value={form.watch("participantId")}
value={form.watch("participantId") ?? ""}
onValueChange={(value) => form.setValue("participantId", value)}
disabled={participantsLoading || mode === "edit"}
>

View File

@@ -32,6 +32,7 @@ import { WebcamPanel } from "./panels/WebcamPanel";
import { TrialStatusBar } from "./panels/TrialStatusBar";
import { api } from "~/trpc/react";
import { useWizardRos } from "~/hooks/useWizardRos";
import { useTrialWebSocket, type TrialEvent } from "~/hooks/useWebSocket";
import { toast } from "sonner";
import { useTour } from "~/components/onboarding/TourProvider";
@@ -252,59 +253,65 @@ export const WizardInterface = React.memo(function WizardInterface({
[setAutonomousLifeRaw],
);
// Use polling for trial status updates (no trial WebSocket server exists)
const { data: pollingData } = api.trials.get.useQuery(
{ id: trial.id },
{
refetchInterval: trial.status === "in_progress" ? 5000 : 15000,
staleTime: 2000,
refetchOnWindowFocus: false,
// Trial WebSocket for real-time updates
const {
isConnected: wsConnected,
connectionError: wsError,
trialEvents: wsTrialEvents,
currentTrialStatus,
addLocalEvent,
} = useTrialWebSocket(trial.id, {
onStatusChange: (status) => {
// Update local trial state when WebSocket reports status changes
setTrial((prev) => ({
...prev,
status: status.status,
startedAt: status.startedAt
? new Date(status.startedAt)
: prev.startedAt,
completedAt: status.completedAt
? new Date(status.completedAt)
: prev.completedAt,
}));
},
);
// Poll for trial events
const { data: fetchedEvents } = api.trials.getEvents.useQuery(
{ trialId: trial.id, limit: 100 },
{
refetchInterval: 3000,
staleTime: 1000,
onTrialEvent: (event) => {
// Optionally show toast for new events
if (event.eventType === "trial_started") {
toast.info("Trial started");
} else if (event.eventType === "trial_completed") {
toast.info("Trial completed");
} else if (event.eventType === "trial_aborted") {
toast.warning("Trial aborted");
}
},
);
});
// Update local trial state from polling only if changed
// Update trial state from WebSocket status
useEffect(() => {
if (pollingData && JSON.stringify(pollingData) !== JSON.stringify(trial)) {
// Only update if specific fields we care about have changed to avoid
// unnecessary re-renders that might cause UI flashing
if (
pollingData.status !== trial.status ||
pollingData.startedAt?.getTime() !== trial.startedAt?.getTime() ||
pollingData.completedAt?.getTime() !== trial.completedAt?.getTime()
) {
if (currentTrialStatus) {
setTrial((prev) => {
// Double check inside setter to be safe
if (
prev.status === pollingData.status &&
prev.startedAt?.getTime() === pollingData.startedAt?.getTime() &&
prev.completedAt?.getTime() === pollingData.completedAt?.getTime()
prev.status === currentTrialStatus.status &&
prev.startedAt?.getTime() ===
new Date(currentTrialStatus.startedAt ?? "").getTime() &&
prev.completedAt?.getTime() ===
new Date(currentTrialStatus.completedAt ?? "").getTime()
) {
return prev;
}
return {
...prev,
status: pollingData.status,
startedAt: pollingData.startedAt
? new Date(pollingData.startedAt)
status: currentTrialStatus.status,
startedAt: currentTrialStatus.startedAt
? new Date(currentTrialStatus.startedAt)
: prev.startedAt,
completedAt: pollingData.completedAt
? new Date(pollingData.completedAt)
completedAt: currentTrialStatus.completedAt
? new Date(currentTrialStatus.completedAt)
: prev.completedAt,
};
});
}
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [pollingData]);
}, [currentTrialStatus]);
// Auto-start trial on mount if scheduled
useEffect(() => {
@@ -313,7 +320,7 @@ export const WizardInterface = React.memo(function WizardInterface({
}
}, []); // Run once on mount
// Trial events from robot actions
// Trial events from WebSocket (and initial load)
const trialEvents = useMemo<
Array<{
type: string;
@@ -322,8 +329,8 @@ export const WizardInterface = React.memo(function WizardInterface({
message?: string;
}>
>(() => {
return (fetchedEvents ?? [])
.map((event) => {
return (wsTrialEvents ?? [])
.map((event: TrialEvent) => {
let message: string | undefined;
const eventData = event.data as any;
@@ -364,7 +371,7 @@ export const WizardInterface = React.memo(function WizardInterface({
};
})
.sort((a, b) => b.timestamp.getTime() - a.timestamp.getTime()); // Newest first
}, [fetchedEvents]);
}, [wsTrialEvents]);
// Transform experiment steps to component format
const steps: StepData[] = useMemo(

View File

@@ -1,7 +1,5 @@
"use client";
/* eslint-disable react-hooks/exhaustive-deps */
import { useSession } from "~/lib/auth-client";
import { useCallback, useEffect, useRef, useState } from "react";
@@ -56,10 +54,42 @@ interface TrialActionExecutedMessage {
interface InterventionLoggedMessage {
type: "intervention_logged";
data: {
intervention?: unknown;
timestamp: number;
} & Record<string, unknown>;
}
interface TrialEventMessage {
type: "trial_event";
data: {
event: unknown;
timestamp: number;
};
}
interface TrialEventsSnapshotMessage {
type: "trial_events_snapshot";
data: {
events: unknown[];
timestamp: number;
};
}
interface AnnotationAddedMessage {
type: "annotation_added";
data: {
annotation: unknown;
timestamp: number;
};
}
interface PongMessage {
type: "pong";
data: {
timestamp: number;
};
}
interface StepChangedMessage {
type: "step_changed";
data: {
@@ -83,6 +113,10 @@ type KnownInboundMessage =
| TrialStatusMessage
| TrialActionExecutedMessage
| InterventionLoggedMessage
| TrialEventMessage
| TrialEventsSnapshotMessage
| AnnotationAddedMessage
| PongMessage
| StepChangedMessage
| ErrorMessage;
@@ -98,18 +132,247 @@ export interface OutgoingMessage {
data: Record<string, unknown>;
}
export interface UseWebSocketOptions {
interface Subscription {
trialId: string;
onMessage?: (message: WebSocketMessage) => void;
onConnect?: () => void;
onDisconnect?: () => void;
onError?: (error: Event) => void;
reconnectAttempts?: number;
reconnectInterval?: number;
heartbeatInterval?: number;
}
export interface UseWebSocketReturn {
interface GlobalWSState {
isConnected: boolean;
isConnecting: boolean;
connectionError: string | null;
lastMessage: WebSocketMessage | null;
}
type StateListener = (state: GlobalWSState) => void;
class GlobalWebSocketManager {
private ws: WebSocket | null = null;
private subscriptions: Map<string, Subscription> = new Map();
private stateListeners: Set<StateListener> = new Set();
private sessionRef: { user: { id: string } } | null = null;
private heartbeatInterval: ReturnType<typeof setInterval> | null = null;
private reconnectTimeout: ReturnType<typeof setTimeout> | null = null;
private attemptCount = 0;
private maxAttempts = 5;
private state: GlobalWSState = {
isConnected: false,
isConnecting: false,
connectionError: null,
lastMessage: null,
};
private setState(partial: Partial<GlobalWSState>) {
this.state = { ...this.state, ...partial };
this.notifyListeners();
}
private notifyListeners() {
this.stateListeners.forEach((listener) => listener(this.state));
}
subscribe(
session: { user: { id: string } } | null,
subscription: Subscription,
) {
this.sessionRef = session;
this.subscriptions.set(subscription.trialId, subscription);
if (this.subscriptions.size === 1 && !this.ws) {
this.connect();
}
return () => {
this.subscriptions.delete(subscription.trialId);
// Don't auto-disconnect - keep global connection alive
};
}
sendMessage(message: OutgoingMessage) {
if (this.ws?.readyState === WebSocket.OPEN) {
this.ws.send(JSON.stringify(message));
}
}
connect() {
if (
this.ws?.readyState === WebSocket.CONNECTING ||
this.ws?.readyState === WebSocket.OPEN
) {
return;
}
if (!this.sessionRef?.user) {
this.setState({ connectionError: "No session", isConnecting: false });
return;
}
this.setState({ isConnecting: true, connectionError: null });
const token = btoa(JSON.stringify({ userId: this.sessionRef.user.id }));
const wsPort = process.env.NEXT_PUBLIC_WS_PORT || "3001";
// Collect all trial IDs from subscriptions
const trialIds = Array.from(this.subscriptions.keys());
const trialIdParam = trialIds.length > 0 ? `&trialId=${trialIds[0]}` : "";
const url = `ws://${typeof window !== "undefined" ? window.location.hostname : "localhost"}:${wsPort}/api/websocket?token=${token}${trialIdParam}`;
try {
this.ws = new WebSocket(url);
this.ws.onopen = () => {
console.log("[GlobalWS] Connected");
this.setState({ isConnected: true, isConnecting: false });
this.attemptCount = 0;
this.startHeartbeat();
// Subscribe to all subscribed trials
this.subscriptions.forEach((sub) => {
this.ws?.send(
JSON.stringify({
type: "subscribe",
data: { trialId: sub.trialId },
}),
);
});
this.subscriptions.forEach((sub) => sub.onConnect?.());
};
this.ws.onmessage = (event) => {
try {
const message = JSON.parse(event.data) as WebSocketMessage;
this.setState({ lastMessage: message });
if (message.type === "connection_established") {
const data = (message as ConnectionEstablishedMessage).data;
const sub = this.subscriptions.get(data.trialId);
if (sub) {
sub.onMessage?.(message);
}
} else if (
message.type === "trial_event" ||
message.type === "trial_status"
) {
const data = (message as TrialEventMessage).data;
const event = data.event as { trialId?: string };
if (event?.trialId) {
const sub = this.subscriptions.get(event.trialId);
sub?.onMessage?.(message);
}
} else {
// Broadcast to all subscriptions
this.subscriptions.forEach((sub) => sub.onMessage?.(message));
}
} catch (error) {
console.error("[GlobalWS] Failed to parse message:", error);
}
};
this.ws.onclose = (event) => {
console.log("[GlobalWS] Disconnected:", event.code);
this.setState({ isConnected: false, isConnecting: false });
this.stopHeartbeat();
this.subscriptions.forEach((sub) => sub.onDisconnect?.());
// Auto-reconnect if not intentionally closed
if (event.code !== 1000 && this.subscriptions.size > 0) {
this.scheduleReconnect();
}
};
this.ws.onerror = (error) => {
console.error("[GlobalWS] Error:", error);
this.setState({
connectionError: "Connection error",
isConnecting: false,
});
this.subscriptions.forEach((sub) =>
sub.onError?.(new Event("ws_error")),
);
};
} catch (error) {
console.error("[GlobalWS] Failed to create:", error);
this.setState({
connectionError: "Failed to create connection",
isConnecting: false,
});
}
}
disconnect() {
if (this.reconnectTimeout) {
clearTimeout(this.reconnectTimeout);
this.reconnectTimeout = null;
}
this.stopHeartbeat();
if (this.ws) {
this.ws.close(1000, "Manual disconnect");
this.ws = null;
}
this.setState({ isConnected: false, isConnecting: false });
}
private startHeartbeat() {
this.heartbeatInterval = setInterval(() => {
if (this.ws?.readyState === WebSocket.OPEN) {
this.ws.send(JSON.stringify({ type: "heartbeat", data: {} }));
}
}, 30000);
}
private stopHeartbeat() {
if (this.heartbeatInterval) {
clearInterval(this.heartbeatInterval);
this.heartbeatInterval = null;
}
}
private scheduleReconnect() {
if (this.attemptCount >= this.maxAttempts) {
this.setState({ connectionError: "Max reconnection attempts reached" });
return;
}
const delay = Math.min(30000, 1000 * Math.pow(1.5, this.attemptCount));
this.attemptCount++;
console.log(
`[GlobalWS] Reconnecting in ${delay}ms (attempt ${this.attemptCount})`,
);
this.reconnectTimeout = setTimeout(() => {
if (this.subscriptions.size > 0) {
this.connect();
}
}, delay);
}
getState(): GlobalWSState {
return this.state;
}
addListener(listener: StateListener) {
this.stateListeners.add(listener);
return () => this.stateListeners.delete(listener);
}
}
const globalWS = new GlobalWebSocketManager();
export interface UseGlobalWebSocketOptions {
trialId: string;
onMessage?: (message: WebSocketMessage) => void;
onConnect?: () => void;
onDisconnect?: () => void;
onError?: (error: Event) => void;
}
export interface UseGlobalWebSocketReturn {
isConnected: boolean;
isConnecting: boolean;
connectionError: string | null;
@@ -119,333 +382,66 @@ export interface UseWebSocketReturn {
lastMessage: WebSocketMessage | null;
}
export function useWebSocket({
export function useGlobalWebSocket({
trialId,
onMessage,
onConnect,
onDisconnect,
onError,
reconnectAttempts = 5,
reconnectInterval = 3000,
heartbeatInterval = 30000,
}: UseWebSocketOptions): UseWebSocketReturn {
}: UseGlobalWebSocketOptions): UseGlobalWebSocketReturn {
const { data: session } = useSession();
const [isConnected, setIsConnected] = useState<boolean>(false);
const [isConnecting, setIsConnecting] = useState<boolean>(false);
const [isConnected, setIsConnected] = useState(false);
const [isConnecting, setIsConnecting] = useState(false);
const [connectionError, setConnectionError] = useState<string | null>(null);
const [hasAttemptedConnection, setHasAttemptedConnection] =
useState<boolean>(false);
const [lastMessage, setLastMessage] = useState<WebSocketMessage | null>(null);
const wsRef = useRef<WebSocket | null>(null);
const reconnectTimeoutRef = useRef<NodeJS.Timeout | null>(null);
const heartbeatTimeoutRef = useRef<NodeJS.Timeout | null>(null);
const attemptCountRef = useRef<number>(0);
const mountedRef = useRef<boolean>(true);
const connectionStableTimeoutRef = useRef<NodeJS.Timeout | null>(null);
const onMessageRef = useRef(onMessage);
const onConnectRef = useRef(onConnect);
const onDisconnectRef = useRef(onDisconnect);
const onErrorRef = useRef(onError);
// Generate auth token (simplified - in production use proper JWT)
const getAuthToken = useCallback((): string | null => {
if (!session?.user) return null;
// In production, this would be a proper JWT token
return btoa(
JSON.stringify({ userId: session.user.id, timestamp: Date.now() }),
);
}, [session]);
onMessageRef.current = onMessage;
onConnectRef.current = onConnect;
onDisconnectRef.current = onDisconnect;
onErrorRef.current = onError;
const sendMessage = useCallback((message: OutgoingMessage): void => {
if (wsRef.current && wsRef.current.readyState === WebSocket.OPEN) {
wsRef.current.send(JSON.stringify(message));
} else {
console.warn("WebSocket not connected, message not sent:", message);
}
}, []);
const sendHeartbeat = useCallback((): void => {
sendMessage({ type: "heartbeat", data: {} });
}, [sendMessage]);
const scheduleHeartbeat = useCallback((): void => {
if (heartbeatTimeoutRef.current) {
clearTimeout(heartbeatTimeoutRef.current);
}
heartbeatTimeoutRef.current = setTimeout(() => {
if (isConnected && mountedRef.current) {
sendHeartbeat();
scheduleHeartbeat();
}
}, heartbeatInterval);
}, [isConnected, sendHeartbeat, heartbeatInterval]);
const handleMessage = useCallback(
(event: MessageEvent<string>): void => {
try {
const message = JSON.parse(event.data) as WebSocketMessage;
setLastMessage(message);
// Handle system messages
switch (message.type) {
case "connection_established": {
console.log(
"WebSocket connection established:",
(message as ConnectionEstablishedMessage).data,
);
useEffect(() => {
const unsubscribe = globalWS.subscribe(session, {
trialId,
onMessage: (msg) => {
setLastMessage(msg);
onMessageRef.current?.(msg);
},
onConnect: () => {
setIsConnected(true);
setIsConnecting(false);
setConnectionError(null);
attemptCountRef.current = 0;
scheduleHeartbeat();
onConnect?.();
break;
}
case "heartbeat_response":
// Heartbeat acknowledged, connection is alive
break;
case "error": {
console.error("WebSocket server error:", message);
const msg =
(message as ErrorMessage).data?.message ?? "Server error";
setConnectionError(msg);
onError?.(new Event("server_error"));
break;
}
default:
// Pass to user-defined message handler
onMessage?.(message);
break;
}
} catch (error) {
console.error("Error parsing WebSocket message:", error);
setConnectionError("Failed to parse message");
}
onConnectRef.current?.();
},
[onMessage, onConnect, onError, scheduleHeartbeat],
);
const handleClose = useCallback(
(event: CloseEvent): void => {
console.log("WebSocket connection closed:", event.code, event.reason);
onDisconnect: () => {
setIsConnected(false);
setIsConnecting(false);
if (heartbeatTimeoutRef.current) {
clearTimeout(heartbeatTimeoutRef.current);
}
onDisconnect?.();
// Attempt reconnection if not manually closed and component is still mounted
// In development, don't aggressively reconnect to prevent UI flashing
if (
event.code !== 1000 &&
mountedRef.current &&
attemptCountRef.current < reconnectAttempts &&
process.env.NODE_ENV !== "development"
) {
attemptCountRef.current++;
const delay =
reconnectInterval * Math.pow(1.5, attemptCountRef.current - 1); // Exponential backoff
console.log(
`Attempting reconnection ${attemptCountRef.current}/${reconnectAttempts} in ${delay}ms`,
);
setConnectionError(
`Connection lost. Reconnecting... (${attemptCountRef.current}/${reconnectAttempts})`,
);
reconnectTimeoutRef.current = setTimeout(() => {
if (mountedRef.current) {
attemptCountRef.current = 0;
setIsConnecting(true);
setConnectionError(null);
}
}, delay);
} else if (attemptCountRef.current >= reconnectAttempts) {
setConnectionError("Failed to reconnect after maximum attempts");
} else if (
process.env.NODE_ENV === "development" &&
event.code !== 1000
) {
// In development, set a stable error message without reconnection attempts
setConnectionError("WebSocket unavailable - using polling mode");
}
onDisconnectRef.current?.();
},
[onDisconnect, reconnectAttempts, reconnectInterval],
);
const handleError = useCallback(
(event: Event): void => {
// In development, WebSocket failures are expected with Edge Runtime
if (process.env.NODE_ENV === "development") {
// Only set error state after the first failed attempt to prevent flashing
if (!hasAttemptedConnection) {
setHasAttemptedConnection(true);
// Debounce the error state to prevent UI flashing
if (connectionStableTimeoutRef.current) {
clearTimeout(connectionStableTimeoutRef.current);
}
connectionStableTimeoutRef.current = setTimeout(() => {
setConnectionError("WebSocket unavailable - using polling mode");
setIsConnecting(false);
}, 1000);
}
} else {
console.error("WebSocket error:", event);
onError: (err) => {
setConnectionError("Connection error");
setIsConnecting(false);
}
onError?.(event);
onErrorRef.current?.(err);
},
[onError, hasAttemptedConnection],
);
});
const connectInternal = useCallback((): void => {
if (!session?.user || !trialId) {
if (!hasAttemptedConnection) {
setConnectionError("Missing authentication or trial ID");
setHasAttemptedConnection(true);
}
return;
}
return unsubscribe;
}, [trialId, session]);
if (
wsRef.current &&
(wsRef.current.readyState === WebSocket.CONNECTING ||
wsRef.current.readyState === WebSocket.OPEN)
) {
return; // Already connecting or connected
}
const token = getAuthToken();
if (!token) {
if (!hasAttemptedConnection) {
setConnectionError("Failed to generate auth token");
setHasAttemptedConnection(true);
}
return;
}
// Only show connecting state for the first attempt or if we've been stable
if (!hasAttemptedConnection || isConnected) {
setIsConnecting(true);
}
// Clear any pending error updates
if (connectionStableTimeoutRef.current) {
clearTimeout(connectionStableTimeoutRef.current);
}
setConnectionError(null);
try {
// Use appropriate WebSocket URL based on environment
const protocol = window.location.protocol === "https:" ? "wss:" : "ws:";
const wsUrl = `${protocol}//${window.location.host}/api/websocket?trialId=${trialId}&token=${token}`;
wsRef.current = new WebSocket(wsUrl);
wsRef.current.onmessage = handleMessage;
wsRef.current.onclose = handleClose;
wsRef.current.onerror = handleError;
wsRef.current.onopen = () => {
console.log("WebSocket connection opened");
// Connection establishment is handled in handleMessage
};
} catch (error) {
console.error("Failed to create WebSocket connection:", error);
if (!hasAttemptedConnection) {
setConnectionError("Failed to create connection");
setHasAttemptedConnection(true);
}
setIsConnecting(false);
}
}, [
session,
trialId,
getAuthToken,
handleMessage,
handleClose,
handleError,
hasAttemptedConnection,
isConnected,
]);
const disconnect = useCallback((): void => {
mountedRef.current = false;
if (reconnectTimeoutRef.current) {
clearTimeout(reconnectTimeoutRef.current);
}
if (heartbeatTimeoutRef.current) {
clearTimeout(heartbeatTimeoutRef.current);
}
if (connectionStableTimeoutRef.current) {
clearTimeout(connectionStableTimeoutRef.current);
}
if (wsRef.current) {
wsRef.current.close(1000, "Manual disconnect");
wsRef.current = null;
}
setIsConnected(false);
setIsConnecting(false);
setConnectionError(null);
setHasAttemptedConnection(false);
attemptCountRef.current = 0;
const sendMessage = useCallback((message: OutgoingMessage) => {
globalWS.sendMessage(message);
}, []);
const reconnect = useCallback((): void => {
disconnect();
mountedRef.current = true;
attemptCountRef.current = 0;
setHasAttemptedConnection(false);
setTimeout(() => {
if (mountedRef.current) {
void connectInternal();
}
}, 100); // Small delay to ensure cleanup
}, [disconnect, connectInternal]);
const disconnect = useCallback(() => {
globalWS.disconnect();
}, []);
// Effect to establish initial connection
useEffect(() => {
if (session?.user?.id && trialId) {
// In development, only attempt connection once to prevent flashing
if (process.env.NODE_ENV === "development" && hasAttemptedConnection) {
return;
}
// Trigger reconnection if timeout was set
if (reconnectTimeoutRef.current) {
clearTimeout(reconnectTimeoutRef.current);
reconnectTimeoutRef.current = null;
void connectInternal();
} else {
void connectInternal();
}
}
return () => {
mountedRef.current = false;
disconnect();
};
}, [session?.user?.id, trialId, hasAttemptedConnection]);
// Cleanup on unmount
useEffect(() => {
return () => {
mountedRef.current = false;
if (connectionStableTimeoutRef.current) {
clearTimeout(connectionStableTimeoutRef.current);
}
disconnect();
};
}, [disconnect]);
const reconnect = useCallback(() => {
globalWS.connect();
}, []);
return {
isConnected,
@@ -458,115 +454,180 @@ export function useWebSocket({
};
}
// Hook for trial-specific WebSocket events
export function useTrialWebSocket(trialId: string) {
const [trialEvents, setTrialEvents] = useState<WebSocketMessage[]>([]);
const [currentTrialStatus, setCurrentTrialStatus] =
useState<TrialSnapshot | null>(null);
const [wizardActions, setWizardActions] = useState<WebSocketMessage[]>([]);
// Legacy alias
export const useWebSocket = useGlobalWebSocket;
const handleMessage = useCallback((message: WebSocketMessage): void => {
// Add to events log
setTrialEvents((prev) => [...prev, message].slice(-100)); // Keep last 100 events
// Trial-specific hook
export interface TrialEvent {
id: string;
trialId: string;
eventType: string;
data: Record<string, unknown> | null;
timestamp: Date;
createdBy?: string | null;
}
export interface TrialWebSocketState {
trialEvents: TrialEvent[];
currentTrialStatus: TrialSnapshot | null;
wizardActions: WebSocketMessage[];
}
export function useTrialWebSocket(
trialId: string,
options?: {
onTrialEvent?: (event: TrialEvent) => void;
onStatusChange?: (status: TrialSnapshot) => void;
initialEvents?: TrialEvent[];
initialStatus?: TrialSnapshot | null;
},
) {
const [state, setState] = useState<TrialWebSocketState>({
trialEvents: options?.initialEvents ?? [],
currentTrialStatus: options?.initialStatus ?? null,
wizardActions: [],
});
const handleMessage = useCallback(
(message: WebSocketMessage): void => {
switch (message.type) {
case "trial_status": {
const data = (message as TrialStatusMessage).data;
setCurrentTrialStatus(data.trial);
const status = data.trial as TrialSnapshot;
setState((prev) => ({
...prev,
currentTrialStatus: status,
}));
options?.onStatusChange?.(status);
break;
}
case "trial_events_snapshot": {
const data = (message as TrialEventsSnapshotMessage).data;
const events = (
data.events as Array<{
id: string;
trialId: string;
eventType: string;
data: Record<string, unknown> | null;
timestamp: Date | string;
createdBy?: string | null;
}>
).map((e) => ({
...e,
timestamp:
typeof e.timestamp === "string"
? new Date(e.timestamp)
: e.timestamp,
}));
setState((prev) => ({
...prev,
trialEvents: events,
}));
break;
}
case "trial_event": {
const data = (message as TrialEventMessage).data;
const event = data.event as {
id: string;
trialId: string;
eventType: string;
data: Record<string, unknown> | null;
timestamp: Date | string;
createdBy?: string | null;
};
const newEvent: TrialEvent = {
...event,
timestamp:
typeof event.timestamp === "string"
? new Date(event.timestamp)
: event.timestamp,
};
setState((prev) => ({
...prev,
trialEvents: [...prev.trialEvents, newEvent].slice(-500),
}));
options?.onTrialEvent?.(newEvent);
break;
}
case "trial_action_executed":
case "intervention_logged":
case "step_changed":
setWizardActions((prev) => [...prev, message].slice(-50)); // Keep last 50 actions
case "annotation_added":
case "step_changed": {
setState((prev) => ({
...prev,
wizardActions: [...prev.wizardActions, message].slice(-100),
}));
break;
}
case "step_changed":
// Handle step transitions (optional logging)
console.log("Step changed:", (message as StepChangedMessage).data);
case "pong":
break;
default:
// Handle other trial-specific messages
break;
if (process.env.NODE_ENV === "development") {
console.log(`[WS] Unknown message type: ${message.type}`);
}
}, []);
}
},
[options],
);
const webSocket = useWebSocket({
const webSocket = useGlobalWebSocket({
trialId,
onMessage: handleMessage,
onConnect: () => {
if (process.env.NODE_ENV === "development") {
console.log(`Connected to trial ${trialId} WebSocket`);
console.log(`[WS] Connected to trial ${trialId}`);
}
},
onDisconnect: () => {
if (process.env.NODE_ENV === "development") {
console.log(`Disconnected from trial ${trialId} WebSocket`);
console.log(`[WS] Disconnected from trial ${trialId}`);
}
},
onError: () => {
// Suppress noisy WebSocket errors in development
if (process.env.NODE_ENV !== "development") {
console.error(`Trial ${trialId} WebSocket connection failed`);
console.error(`[WS] Trial ${trialId} WebSocket connection failed`);
}
},
});
// Request trial status after connection is established
// Request initial data after connection is established
useEffect(() => {
if (webSocket.isConnected) {
webSocket.sendMessage({ type: "request_trial_status", data: {} });
webSocket.sendMessage({
type: "request_trial_events",
data: { limit: 500 },
});
}
}, [webSocket.isConnected, webSocket]);
}, [webSocket.isConnected]);
// Trial-specific actions
const executeTrialAction = useCallback(
(actionType: string, actionData: Record<string, unknown>): void => {
webSocket.sendMessage({
type: "trial_action",
data: {
actionType,
...actionData,
},
});
},
[webSocket],
);
// Helper to add an event locally (for optimistic updates)
const addLocalEvent = useCallback((event: TrialEvent) => {
setState((prev) => ({
...prev,
trialEvents: [...prev.trialEvents, event].slice(-500),
}));
}, []);
const logWizardIntervention = useCallback(
(interventionData: Record<string, unknown>): void => {
webSocket.sendMessage({
type: "wizard_intervention",
data: interventionData,
});
},
[webSocket],
);
const transitionStep = useCallback(
(stepData: {
from_step?: number;
to_step: number;
step_name?: string;
[k: string]: unknown;
}): void => {
webSocket.sendMessage({
type: "step_transition",
data: stepData,
});
},
[webSocket],
);
// Helper to update trial status locally
const updateLocalStatus = useCallback((status: TrialSnapshot) => {
setState((prev) => ({
...prev,
currentTrialStatus: status,
}));
}, []);
return {
...webSocket,
trialEvents,
currentTrialStatus,
wizardActions,
executeTrialAction,
logWizardIntervention,
transitionStep,
trialEvents: state.trialEvents,
currentTrialStatus: state.currentTrialStatus,
wizardActions: state.wizardActions,
addLocalEvent,
updateLocalStatus,
};
}

View File

@@ -35,6 +35,7 @@ import { GetObjectCommand } from "@aws-sdk/client-s3";
import { getSignedUrl } from "@aws-sdk/s3-request-presigner";
import { env } from "~/env";
import { uploadFile } from "~/lib/storage/minio";
import { wsManager } from "~/server/services/websocket-manager";
// Helper function to check if user has access to trial
async function checkTrialAccess(
@@ -591,6 +592,16 @@ export const trialsRouter = createTRPCRouter({
data: { userId },
});
// Broadcast trial status update
await wsManager.broadcast(input.id, {
type: "trial_status",
data: {
trial: trial[0],
current_step_index: 0,
timestamp: Date.now(),
},
});
return trial[0];
}),
@@ -643,6 +654,16 @@ export const trialsRouter = createTRPCRouter({
data: { userId, notes: input.notes },
});
// Broadcast trial status update
await wsManager.broadcast(input.id, {
type: "trial_status",
data: {
trial,
current_step_index: 0,
timestamp: Date.now(),
},
});
return trial;
}),
@@ -696,6 +717,16 @@ export const trialsRouter = createTRPCRouter({
data: { userId, reason: input.reason },
});
// Broadcast trial status update
await wsManager.broadcast(input.id, {
type: "trial_status",
data: {
trial: trial[0],
current_step_index: 0,
timestamp: Date.now(),
},
});
return trial[0];
}),
@@ -846,6 +877,15 @@ export const trialsRouter = createTRPCRouter({
})
.returning();
// Broadcast new event to all subscribers
await wsManager.broadcast(input.trialId, {
type: "trial_event",
data: {
event,
timestamp: Date.now(),
},
});
return event;
}),
@@ -881,6 +921,15 @@ export const trialsRouter = createTRPCRouter({
})
.returning();
// Broadcast intervention to all subscribers
await wsManager.broadcast(input.trialId, {
type: "intervention_logged",
data: {
intervention,
timestamp: Date.now(),
},
});
return intervention;
}),
@@ -936,6 +985,15 @@ export const trialsRouter = createTRPCRouter({
});
}
// Broadcast annotation to all subscribers
await wsManager.broadcast(input.trialId, {
type: "annotation_added",
data: {
annotation,
timestamp: Date.now(),
},
});
return annotation;
}),
@@ -1302,10 +1360,12 @@ export const trialsRouter = createTRPCRouter({
}
// Log the manual robot action execution
await db.insert(trialEvents).values({
const [event] = await db
.insert(trialEvents)
.values({
trialId: input.trialId,
eventType: "manual_robot_action",
actionId: null, // Ad-hoc action, not linked to a protocol action definition
actionId: null,
data: {
userId,
pluginName: input.pluginName,
@@ -1316,6 +1376,17 @@ export const trialsRouter = createTRPCRouter({
},
timestamp: new Date(),
createdBy: userId,
})
.returning();
// Broadcast robot action to all subscribers
await wsManager.broadcast(input.trialId, {
type: "trial_action_executed",
data: {
action_type: `${input.pluginName}.${input.actionId}`,
event,
timestamp: Date.now(),
},
});
return {
@@ -1347,7 +1418,9 @@ export const trialsRouter = createTRPCRouter({
"wizard",
]);
await db.insert(trialEvents).values({
const [event] = await db
.insert(trialEvents)
.values({
trialId: input.trialId,
eventType: "manual_robot_action",
data: {
@@ -1362,6 +1435,17 @@ export const trialsRouter = createTRPCRouter({
},
timestamp: new Date(),
createdBy: userId,
})
.returning();
// Broadcast robot action to all subscribers
await wsManager.broadcast(input.trialId, {
type: "trial_action_executed",
data: {
action_type: `${input.pluginName}.${input.actionId}`,
event,
timestamp: Date.now(),
},
});
return { success: true };

View File

@@ -485,6 +485,25 @@ export const trials = createTable("trial", {
metadata: jsonb("metadata").default({}),
});
export const wsConnections = createTable("ws_connection", {
id: uuid("id").notNull().primaryKey().defaultRandom(),
trialId: uuid("trial_id")
.notNull()
.references(() => trials.id, { onDelete: "cascade" }),
clientId: text("client_id").notNull().unique(),
userId: text("user_id"),
connectedAt: timestamp("connected_at", { withTimezone: true })
.default(sql`CURRENT_TIMESTAMP`)
.notNull(),
});
export const wsConnectionsRelations = relations(wsConnections, ({ one }) => ({
trial: one(trials, {
fields: [wsConnections.trialId],
references: [trials.id],
}),
}));
export const steps = createTable(
"step",
{

View File

@@ -0,0 +1,272 @@
import { db } from "~/server/db";
import {
trials,
trialEvents,
wsConnections,
experiments,
} from "~/server/db/schema";
import { eq, sql } from "drizzle-orm";
interface ClientConnection {
socket: WebSocket;
trialId: string;
userId: string | null;
connectedAt: number;
}
type OutgoingMessage = {
type: string;
data: Record<string, unknown>;
};
class WebSocketManager {
private clients: Map<string, ClientConnection> = new Map();
private heartbeatIntervals: Map<string, ReturnType<typeof setInterval>> =
new Map();
private getTrialRoomClients(trialId: string): ClientConnection[] {
const clients: ClientConnection[] = [];
for (const [, client] of this.clients) {
if (client.trialId === trialId) {
clients.push(client);
}
}
return clients;
}
addClient(clientId: string, connection: ClientConnection): void {
this.clients.set(clientId, connection);
console.log(
`[WS] Client ${clientId} added for trial ${connection.trialId}. Total: ${this.clients.size}`,
);
}
removeClient(clientId: string): void {
const client = this.clients.get(clientId);
if (client) {
console.log(
`[WS] Client ${clientId} removed from trial ${client.trialId}`,
);
}
const heartbeatInterval = this.heartbeatIntervals.get(clientId);
if (heartbeatInterval) {
clearInterval(heartbeatInterval);
this.heartbeatIntervals.delete(clientId);
}
this.clients.delete(clientId);
}
async subscribe(
clientId: string,
socket: WebSocket,
trialId: string,
userId: string | null,
): Promise<void> {
const client: ClientConnection = {
socket,
trialId,
userId,
connectedAt: Date.now(),
};
this.clients.set(clientId, client);
const heartbeatInterval = setInterval(() => {
this.sendToClient(clientId, { type: "heartbeat", data: {} });
}, 30000);
this.heartbeatIntervals.set(clientId, heartbeatInterval);
console.log(
`[WS] Client ${clientId} subscribed to trial ${trialId}. Total clients: ${this.clients.size}`,
);
}
unsubscribe(clientId: string): void {
const client = this.clients.get(clientId);
if (client) {
console.log(
`[WS] Client ${clientId} unsubscribed from trial ${client.trialId}`,
);
}
const heartbeatInterval = this.heartbeatIntervals.get(clientId);
if (heartbeatInterval) {
clearInterval(heartbeatInterval);
this.heartbeatIntervals.delete(clientId);
}
this.clients.delete(clientId);
}
sendToClient(clientId: string, message: OutgoingMessage): void {
const client = this.clients.get(clientId);
if (client?.socket.readyState === 1) {
try {
client.socket.send(JSON.stringify(message));
} catch (error) {
console.error(`[WS] Error sending to client ${clientId}:`, error);
this.unsubscribe(clientId);
}
}
}
async broadcast(trialId: string, message: OutgoingMessage): Promise<void> {
const clients = this.getTrialRoomClients(trialId);
if (clients.length === 0) {
return;
}
const messageStr = JSON.stringify(message);
const disconnectedClients: string[] = [];
for (const [clientId, client] of this.clients) {
if (client.trialId === trialId && client.socket.readyState === 1) {
try {
client.socket.send(messageStr);
} catch (error) {
console.error(
`[WS] Error broadcasting to client ${clientId}:`,
error,
);
disconnectedClients.push(clientId);
}
}
}
for (const clientId of disconnectedClients) {
this.unsubscribe(clientId);
}
console.log(
`[WS] Broadcast to ${clients.length} clients for trial ${trialId}: ${message.type}`,
);
}
async broadcastToAll(message: OutgoingMessage): Promise<void> {
const messageStr = JSON.stringify(message);
const disconnectedClients: string[] = [];
for (const [clientId, client] of this.clients) {
if (client.socket.readyState === 1) {
try {
client.socket.send(messageStr);
} catch (error) {
console.error(
`[WS] Error broadcasting to client ${clientId}:`,
error,
);
disconnectedClients.push(clientId);
}
}
}
for (const clientId of disconnectedClients) {
this.unsubscribe(clientId);
}
}
async getTrialStatus(trialId: string): Promise<{
trial: {
id: string;
status: string;
startedAt: Date | null;
completedAt: Date | null;
};
currentStepIndex: number;
} | null> {
const [trial] = await db
.select({
id: trials.id,
status: trials.status,
startedAt: trials.startedAt,
completedAt: trials.completedAt,
})
.from(trials)
.where(eq(trials.id, trialId))
.limit(1);
if (!trial) {
return null;
}
return {
trial: {
id: trial.id,
status: trial.status,
startedAt: trial.startedAt,
completedAt: trial.completedAt,
},
currentStepIndex: 0,
};
}
async getTrialEvents(
trialId: string,
limit: number = 100,
): Promise<unknown[]> {
const events = await db
.select()
.from(trialEvents)
.where(eq(trialEvents.trialId, trialId))
.orderBy(trialEvents.timestamp)
.limit(limit);
return events;
}
getTrialStatusSync(trialId: string): {
trial: {
id: string;
status: string;
startedAt: Date | null;
completedAt: Date | null;
};
currentStepIndex: number;
} | null {
return null;
}
getTrialEventsSync(trialId: string, limit: number = 100): unknown[] {
return [];
}
getConnectionCount(trialId?: string): number {
if (trialId) {
return this.getTrialRoomClients(trialId).length;
}
return this.clients.size;
}
getConnectedTrialIds(): string[] {
const trialIds = new Set<string>();
for (const [, client] of this.clients) {
trialIds.add(client.trialId);
}
return Array.from(trialIds);
}
async getTrialsWithActiveConnections(studyIds?: string[]): Promise<string[]> {
const conditions =
studyIds && studyIds.length > 0
? sql`${wsConnections.trialId} IN (
SELECT ${trials.id} FROM ${trials}
WHERE ${trials.experimentId} IN (
SELECT ${experiments.id} FROM ${experiments}
WHERE ${experiments.studyId} IN (${sql.raw(studyIds.map((id) => `'${id}'`).join(","))})
)
)`
: undefined;
const connections = await db
.selectDistinct({ trialId: wsConnections.trialId })
.from(wsConnections);
return connections.map((c) => c.trialId);
}
}
export const wsManager = new WebSocketManager();

192
ws-server.ts Normal file
View File

@@ -0,0 +1,192 @@
import { serve, type ServerWebSocket } from "bun";
import { wsManager } from "./src/server/services/websocket-manager";
import { db } from "./src/server/db";
import { wsConnections } from "./src/server/db/schema";
import { eq } from "drizzle-orm";
const port = parseInt(process.env.WS_PORT || "3001", 10);
interface WSData {
clientId: string;
trialId: string;
userId: string | null;
}
function generateClientId(): string {
return `ws_${Date.now()}_${Math.random().toString(36).substring(2, 9)}`;
}
async function recordConnection(
clientId: string,
trialId: string,
userId: string | null,
): Promise<void> {
try {
await db.insert(wsConnections).values({
clientId,
trialId,
userId,
});
console.log(`[DB] Recorded connection for trial ${trialId}`);
} catch (error) {
console.error(`[DB] Failed to record connection:`, error);
}
}
async function removeConnection(clientId: string): Promise<void> {
try {
await db.delete(wsConnections).where(eq(wsConnections.clientId, clientId));
console.log(`[DB] Removed connection ${clientId}`);
} catch (error) {
console.error(`[DB] Failed to remove connection:`, error);
}
}
console.log(`Starting WebSocket server on port ${port}...`);
serve<WSData>({
port,
fetch(req, server) {
const url = new URL(req.url);
if (url.pathname === "/api/websocket") {
if (req.headers.get("upgrade") !== "websocket") {
return new Response("WebSocket upgrade required", { status: 426 });
}
const trialId = url.searchParams.get("trialId");
const token = url.searchParams.get("token");
if (!trialId) {
return new Response("Missing trialId parameter", { status: 400 });
}
let userId: string | null = null;
if (token) {
try {
const tokenData = JSON.parse(atob(token));
userId = tokenData.userId;
} catch {
return new Response("Invalid token", { status: 401 });
}
}
const clientId = generateClientId();
const wsData: WSData = { clientId, trialId, userId };
const upgraded = server.upgrade(req, { data: wsData });
if (!upgraded) {
return new Response("WebSocket upgrade failed", { status: 500 });
}
return;
}
return new Response("Not found", { status: 404 });
},
websocket: {
async open(ws: ServerWebSocket<WSData>) {
const { clientId, trialId, userId } = ws.data;
wsManager.addClient(clientId, {
socket: ws as unknown as WebSocket,
trialId,
userId,
connectedAt: Date.now(),
});
await recordConnection(clientId, trialId, userId);
console.log(
`[WS] Client ${clientId} connected to trial ${trialId}. Total: ${wsManager.getConnectionCount()}`,
);
ws.send(
JSON.stringify({
type: "connection_established",
data: {
trialId,
userId,
role: "connected",
connectedAt: Date.now(),
},
}),
);
},
message(ws: ServerWebSocket<WSData>, message) {
const { clientId, trialId } = ws.data;
try {
const msg = JSON.parse(message.toString());
switch (msg.type) {
case "heartbeat":
ws.send(
JSON.stringify({
type: "heartbeat_response",
data: { timestamp: Date.now() },
}),
);
break;
case "request_trial_status": {
const status = wsManager.getTrialStatusSync(trialId);
ws.send(
JSON.stringify({
type: "trial_status",
data: {
trial: status?.trial ?? null,
current_step_index: status?.currentStepIndex ?? 0,
timestamp: Date.now(),
},
}),
);
break;
}
case "request_trial_events": {
const events = wsManager.getTrialEventsSync(
trialId,
msg.data?.limit ?? 100,
);
ws.send(
JSON.stringify({
type: "trial_events_snapshot",
data: { events, timestamp: Date.now() },
}),
);
break;
}
case "ping":
ws.send(
JSON.stringify({
type: "pong",
data: { timestamp: Date.now() },
}),
);
break;
default:
console.log(
`[WS] Unknown message type from ${clientId}:`,
msg.type,
);
}
} catch (error) {
console.error(`[WS] Error processing message from ${clientId}:`, error);
}
},
async close(ws: ServerWebSocket<WSData>) {
const { clientId } = ws.data;
console.log(`[WS] Client ${clientId} disconnected`);
wsManager.removeClient(clientId);
await removeConnection(clientId);
},
},
});
console.log(
`> WebSocket server running on ws://localhost:${port}/api/websocket`,
);